Browse Source

feat: 支持配置proxy的能力 (#67)

二丫讲梵 2 years ago
parent
commit
6d4d1df808
7 changed files with 39 additions and 18 deletions
  1. 11 4
      README.md
  2. 1 0
      config.dev.json
  3. 6 4
      config/config.go
  4. 4 4
      main.go
  5. 14 3
      pkg/chatgpt/chatgpt.go
  6. 1 1
      pkg/chatgpt/chatgpt_test.go
  7. 2 2
      pkg/chatgpt/context_test.go

File diff suppressed because it is too large
+ 11 - 4
README.md


+ 1 - 0
config.dev.json

@@ -1,5 +1,6 @@
 {
 {
     "api_key": "xxxxxxxxx",
     "api_key": "xxxxxxxxx",
     "session_timeout": 600,
     "session_timeout": 600,
+    "http_proxy": "",
     "default_mode": "单聊"
     "default_mode": "单聊"
 }
 }

+ 6 - 4
config/config.go

@@ -19,6 +19,8 @@ type Configuration struct {
 	SessionTimeout time.Duration `json:"session_timeout"`
 	SessionTimeout time.Duration `json:"session_timeout"`
 	// 默认对话模式
 	// 默认对话模式
 	DefaultMode string `json:"default_mode"`
 	DefaultMode string `json:"default_mode"`
+	// 代理地址
+	HttpProxy string `json:"http_proxy"`
 }
 }
 
 
 var config *Configuration
 var config *Configuration
@@ -45,10 +47,7 @@ func LoadConfig() *Configuration {
 		ApiKey := os.Getenv("APIKEY")
 		ApiKey := os.Getenv("APIKEY")
 		SessionTimeout := os.Getenv("SESSION_TIMEOUT")
 		SessionTimeout := os.Getenv("SESSION_TIMEOUT")
 		defaultMode := os.Getenv("DEFAULT_MODE")
 		defaultMode := os.Getenv("DEFAULT_MODE")
-		// Model := os.Getenv("MODEL")
-		// MaxTokens := os.Getenv("MAX_TOKENS")
-		// Temperature := os.Getenv("TEMPREATURE")
-		// SessionClearToken := os.Getenv("SESSION_CLEAR_TOKEN")
+		httpProxy := os.Getenv("HTTP_PROXY")
 		if ApiKey != "" {
 		if ApiKey != "" {
 			config.ApiKey = ApiKey
 			config.ApiKey = ApiKey
 		}
 		}
@@ -65,6 +64,9 @@ func LoadConfig() *Configuration {
 		if defaultMode != "" {
 		if defaultMode != "" {
 			config.DefaultMode = defaultMode
 			config.DefaultMode = defaultMode
 		}
 		}
+		if httpProxy != "" {
+			config.HttpProxy = httpProxy
+		}
 	})
 	})
 	if config.DefaultMode == "" {
 	if config.DefaultMode == "" {
 		config.DefaultMode = "单聊"
 		config.DefaultMode = "单聊"

+ 4 - 4
main.go

@@ -22,7 +22,7 @@ func main() {
 
 
 var Welcome string = `Commands:
 var Welcome string = `Commands:
 =================================
 =================================
-🙋 单聊 👉 单独聊天,缺省
+🙋 单聊 👉 单独聊天
 📣 串聊 👉 带上下文聊天
 📣 串聊 👉 带上下文聊天
 🔃 重置 👉 重置带上下文聊天
 🔃 重置 👉 重置带上下文聊天
 🚀 帮助 👉 显示帮助信息
 🚀 帮助 👉 显示帮助信息
@@ -122,7 +122,7 @@ func Do(mode string, rmsg public.ReceiveMsg) error {
 	public.UserService.SetUserMode(rmsg.SenderStaffId, mode)
 	public.UserService.SetUserMode(rmsg.SenderStaffId, mode)
 	switch mode {
 	switch mode {
 	case "单聊":
 	case "单聊":
-		reply, err := SingleQa(rmsg.Text.Content, rmsg.SenderNick)
+		reply, err := SingleQa(rmsg.Text.Content, rmsg.SenderStaffId)
 		if err != nil {
 		if err != nil {
 			logger.Info(fmt.Errorf("gpt request error: %v", err))
 			logger.Info(fmt.Errorf("gpt request error: %v", err))
 			if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
 			if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
@@ -195,14 +195,14 @@ func Do(mode string, rmsg public.ReceiveMsg) error {
 
 
 func SingleQa(question, userId string) (answer string, err error) {
 func SingleQa(question, userId string) (answer string, err error) {
 	cfg := config.LoadConfig()
 	cfg := config.LoadConfig()
-	chat := chatgpt.New(cfg.ApiKey, userId, cfg.SessionTimeout)
+	chat := chatgpt.New(cfg.ApiKey, cfg.HttpProxy, userId, cfg.SessionTimeout)
 	defer chat.Close()
 	defer chat.Close()
 	return chat.ChatWithContext(question)
 	return chat.ChatWithContext(question)
 }
 }
 
 
 func ContextQa(question, userId string) (chat *chatgpt.ChatGPT, answer string, err error) {
 func ContextQa(question, userId string) (chat *chatgpt.ChatGPT, answer string, err error) {
 	cfg := config.LoadConfig()
 	cfg := config.LoadConfig()
-	chat = chatgpt.New(cfg.ApiKey, userId, cfg.SessionTimeout)
+	chat = chatgpt.New(cfg.ApiKey, cfg.HttpProxy, userId, cfg.SessionTimeout)
 	if public.UserService.GetUserSessionContext(userId) != "" {
 	if public.UserService.GetUserSessionContext(userId) != "" {
 		err = chat.ChatContext.LoadConversation(userId)
 		err = chat.ChatContext.LoadConversation(userId)
 		if err != nil {
 		if err != nil {

+ 14 - 3
pkg/chatgpt/chatgpt.go

@@ -2,6 +2,8 @@ package chatgpt
 
 
 import (
 import (
 	"context"
 	"context"
+	"net/http"
+	"net/url"
 	"time"
 	"time"
 
 
 	gogpt "github.com/sashabaranov/go-gpt3"
 	gogpt "github.com/sashabaranov/go-gpt3"
@@ -21,7 +23,7 @@ type ChatGPT struct {
 	ChatContext *ChatContext
 	ChatContext *ChatContext
 }
 }
 
 
-func New(ApiKey, UserId string, timeOut time.Duration) *ChatGPT {
+func New(apiKey, proxyUrl, userId string, timeOut time.Duration) *ChatGPT {
 	var ctx context.Context
 	var ctx context.Context
 	var cancel func()
 	var cancel func()
 	if timeOut == 0 {
 	if timeOut == 0 {
@@ -34,10 +36,19 @@ func New(ApiKey, UserId string, timeOut time.Duration) *ChatGPT {
 		<-ctx.Done()
 		<-ctx.Done()
 		timeOutChan <- struct{}{} // 发送超时信号,或是提示结束,用于聊天机器人场景,配合GetTimeOutChan() 使用
 		timeOutChan <- struct{}{} // 发送超时信号,或是提示结束,用于聊天机器人场景,配合GetTimeOutChan() 使用
 	}()
 	}()
+
+	config := gogpt.DefaultConfig(apiKey)
+	if proxyUrl != "" {
+		config.HTTPClient.Transport = &http.Transport{
+			// 设置代理
+			Proxy: func(req *http.Request) (*url.URL, error) {
+				return url.Parse(proxyUrl)
+			}}
+	}
 	return &ChatGPT{
 	return &ChatGPT{
-		client:         gogpt.NewClient(ApiKey),
+		client:         gogpt.NewClientWithConfig(config),
 		ctx:            ctx,
 		ctx:            ctx,
-		userId:         UserId,
+		userId:         userId,
 		maxQuestionLen: 2048, // 最大问题长度
 		maxQuestionLen: 2048, // 最大问题长度
 		maxAnswerLen:   2048, // 最大答案长度
 		maxAnswerLen:   2048, // 最大答案长度
 		maxText:        4096, // 最大文本 = 问题 + 回答, 接口限制
 		maxText:        4096, // 最大文本 = 问题 + 回答, 接口限制

+ 1 - 1
pkg/chatgpt/chatgpt_test.go

@@ -7,7 +7,7 @@ import (
 )
 )
 
 
 func TestChatGPT_ChatWithContext(t *testing.T) {
 func TestChatGPT_ChatWithContext(t *testing.T) {
-	chat := New("CHATGPT_API_KEY", "", 10*time.Minute)
+	chat := New("CHATGPT_API_KEY", "", "", 10*time.Minute)
 	defer chat.Close()
 	defer chat.Close()
 	//go func() {
 	//go func() {
 	//	select {
 	//	select {

+ 2 - 2
pkg/chatgpt/context_test.go

@@ -13,7 +13,7 @@ func TestOfflineContext(t *testing.T) {
 	if key == "" {
 	if key == "" {
 		t.Skip("CHATGPT_API_KEY is not set")
 		t.Skip("CHATGPT_API_KEY is not set")
 	}
 	}
-	cli := New(key, "user1", time.Second*30)
+	cli := New(key, "", "user1", time.Second*30)
 	reply, err := cli.ChatWithContext("我叫老三,你是?")
 	reply, err := cli.ChatWithContext("我叫老三,你是?")
 	if err != nil {
 	if err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
@@ -56,7 +56,7 @@ func TestMaintainContext(t *testing.T) {
 	if key == "" {
 	if key == "" {
 		t.Skip("CHATGPT_API_KEY is not set")
 		t.Skip("CHATGPT_API_KEY is not set")
 	}
 	}
-	cli := New(key, "user1", time.Second*30)
+	cli := New(key, "", "user1", time.Second*30)
 	cli.ChatContext = NewContext(
 	cli.ChatContext = NewContext(
 		WithMaxSeqTimes(1),
 		WithMaxSeqTimes(1),
 		WithMaintainSeqTimes(true),
 		WithMaintainSeqTimes(true),