소스 검색

更新模块儿版本,支持自定义 baseurl (#92)

二丫讲梵 2 년 전
부모
커밋
4e56dad105
10개의 변경된 파일47개의 추가작업 그리고 66개의 파일을 삭제
  1. 2 1
      README.md
  2. 1 0
      config.dev.json
  3. 13 7
      config/config.go
  4. 1 2
      go.mod
  5. 2 4
      go.sum
  6. 2 2
      main.go
  7. 17 37
      pkg/chatgpt/chatgpt.go
  8. 1 2
      pkg/chatgpt/chatgpt_test.go
  9. 5 5
      pkg/chatgpt/context.go
  10. 3 6
      pkg/chatgpt/context_test.go

+ 2 - 1
README.md

@@ -82,7 +82,7 @@
 
 ```sh
 # 运行项目
-$ docker run -itd --name chatgpt -p 8090:8090 --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest
+$ docker run -itd --name chatgpt -p 8090:8090 --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest
 ```
 
 `📢 注意:`如果使用docker部署,那么proxy地址可以直接使用如上方式部署,`host.docker.internal`会指向容器所在宿主机的IP,只需要更改端口为你的代理端口即可。参见:[Docker容器如何优雅地访问宿主机网络](https://wiki.eryajf.net/pages/674f53/)
@@ -239,6 +239,7 @@ $ go run main.go
 ```json
 {
     "api_key": "xxxxxxxxx",   // openai api_key
+    "base_url": "api.openai.com", //  如果你想指定请求url的地址,可通过这个参数进行配置,默认为官方地址,不需要再添加 /v1
     "model": "gpt-3.5-turbo", // 指定模型,默认为 gpt-3.5-turbo ,具体选项参考官网训练场
     "session_timeout": 600,   // 会话超时时间,默认600秒,在会话时间内所有发送给机器人的信息会作为上下文
     "http_proxy": "",         // 指定请求时使用的代理,如果为空,则不使用代理

+ 1 - 0
config.dev.json

@@ -1,5 +1,6 @@
 {
     "api_key": "xxxxxxxxx",
+    "base_url": "",
     "model": "gpt-3.5-turbo",
     "session_timeout": 600,
     "http_proxy": "",

+ 13 - 7
config/config.go

@@ -15,6 +15,8 @@ import (
 type Configuration struct {
 	// gtp apikey
 	ApiKey string `json:"api_key"`
+	// 请求的 URL 地址
+	BaseURL string `json:"base_url"`
 	// 使用模型
 	Model string `json:"model"`
 	// 会话超时时间
@@ -46,18 +48,22 @@ func LoadConfig() *Configuration {
 			return
 		}
 		// 如果环境变量有配置,读取环境变量
-		ApiKey := os.Getenv("APIKEY")
+		apiKey := os.Getenv("APIKEY")
+		baseURL := os.Getenv("BASE_URL")
 		model := os.Getenv("MODEL")
-		SessionTimeout := os.Getenv("SESSION_TIMEOUT")
+		sessionTimeout := os.Getenv("SESSION_TIMEOUT")
 		defaultMode := os.Getenv("DEFAULT_MODE")
 		httpProxy := os.Getenv("HTTP_PROXY")
-		if ApiKey != "" {
-			config.ApiKey = ApiKey
+		if apiKey != "" {
+			config.ApiKey = apiKey
 		}
-		if SessionTimeout != "" {
-			duration, err := strconv.ParseInt(SessionTimeout, 10, 64)
+		if baseURL != "" {
+			config.BaseURL = baseURL
+		}
+		if sessionTimeout != "" {
+			duration, err := strconv.ParseInt(sessionTimeout, 10, 64)
 			if err != nil {
-				logger.Danger(fmt.Sprintf("config session timeout err: %v ,get is %v", err, SessionTimeout))
+				logger.Danger(fmt.Sprintf("config session timeout err: %v ,get is %v", err, sessionTimeout))
 				return
 			}
 			config.SessionTimeout = time.Duration(duration) * time.Second

+ 1 - 2
go.mod

@@ -9,8 +9,7 @@ require (
 )
 
 require (
-	github.com/joho/godotenv v1.5.1 // indirect
-	github.com/sashabaranov/go-gpt3 v1.3.0 // indirect
+	github.com/sashabaranov/go-openai v1.5.0 // indirect
 	golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect
 )
 

+ 2 - 4
go.sum

@@ -1,11 +1,9 @@
 github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY=
 github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I=
-github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
-github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
 github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
 github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
-github.com/sashabaranov/go-gpt3 v1.3.0 h1:IbvaK2yTnlm7f/oiC2HC9cbzu/4Znt4GkarFiwZ60uI=
-github.com/sashabaranov/go-gpt3 v1.3.0/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
+github.com/sashabaranov/go-openai v1.5.0 h1:4Gr/7g/KtVzW0ddn7TC2aUlyzvhZBIM+qRZ6Ae2kMa0=
+github.com/sashabaranov/go-openai v1.5.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
 golang.org/x/net v0.0.0-20211029224645-99673261e6eb h1:pirldcYWx7rx7kE5r+9WsOXPXK0+WH5+uZ7uPmJ44uM=
 golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

+ 2 - 2
main.go

@@ -212,13 +212,13 @@ func Do(mode string, rmsg public.ReceiveMsg) error {
 }
 
 func SingleQa(question, userId string) (answer string, err error) {
-	chat := chatgpt.New(public.Config.ApiKey, public.Config.HttpProxy, userId, public.Config.SessionTimeout)
+	chat := chatgpt.New(userId)
 	defer chat.Close()
 	return chat.ChatWithContext(question)
 }
 
 func ContextQa(question, userId string) (chat *chatgpt.ChatGPT, answer string, err error) {
-	chat = chatgpt.New(public.Config.ApiKey, public.Config.HttpProxy, userId, public.Config.SessionTimeout)
+	chat = chatgpt.New(userId)
 	if public.UserService.GetUserSessionContext(userId) != "" {
 		err = chat.ChatContext.LoadConversation(userId)
 		if err != nil {

+ 17 - 37
pkg/chatgpt/chatgpt.go

@@ -6,11 +6,12 @@ import (
 	"net/url"
 	"time"
 
-	gogpt "github.com/sashabaranov/go-gpt3"
+	"github.com/eryajf/chatgpt-dingtalk/public"
+	openai "github.com/sashabaranov/go-openai"
 )
 
 type ChatGPT struct {
-	client         *gogpt.Client
+	client         *openai.Client
 	ctx            context.Context
 	userId         string
 	maxQuestionLen int
@@ -23,13 +24,15 @@ type ChatGPT struct {
 	ChatContext *ChatContext
 }
 
-func New(apiKey, proxyUrl, userId string, timeOut time.Duration) *ChatGPT {
+func New(userId string) *ChatGPT {
 	var ctx context.Context
 	var cancel func()
-	if timeOut == 0 {
+
+	// public.Config.BaseURL, public.Config.ApiKey, public.Config.HttpProxy
+	if public.Config.SessionTimeout == 0 {
 		ctx, cancel = context.WithCancel(context.Background())
 	} else {
-		ctx, cancel = context.WithTimeout(context.Background(), timeOut)
+		ctx, cancel = context.WithTimeout(context.Background(), public.Config.SessionTimeout)
 	}
 	timeOutChan := make(chan struct{}, 1)
 	go func() {
@@ -37,22 +40,26 @@ func New(apiKey, proxyUrl, userId string, timeOut time.Duration) *ChatGPT {
 		timeOutChan <- struct{}{} // 发送超时信号,或是提示结束,用于聊天机器人场景,配合GetTimeOutChan() 使用
 	}()
 
-	config := gogpt.DefaultConfig(apiKey)
-	if proxyUrl != "" {
+	config := openai.DefaultConfig(public.Config.ApiKey)
+	if public.Config.HttpProxy != "" {
 		config.HTTPClient.Transport = &http.Transport{
 			// 设置代理
 			Proxy: func(req *http.Request) (*url.URL, error) {
-				return url.Parse(proxyUrl)
+				return url.Parse(public.Config.HttpProxy)
 			}}
 	}
+	if public.Config.BaseURL != "" {
+		config.BaseURL = public.Config.BaseURL + "/v1"
+	}
+
 	return &ChatGPT{
-		client:         gogpt.NewClientWithConfig(config),
+		client:         openai.NewClientWithConfig(config),
 		ctx:            ctx,
 		userId:         userId,
 		maxQuestionLen: 2048, // 最大问题长度
 		maxAnswerLen:   2048, // 最大答案长度
 		maxText:        4096, // 最大文本 = 问题 + 回答, 接口限制
-		timeOut:        timeOut,
+		timeOut:        public.Config.SessionTimeout,
 		doneChan:       timeOutChan,
 		cancel: func() {
 			cancel()
@@ -75,30 +82,3 @@ func (c *ChatGPT) SetMaxQuestionLen(maxQuestionLen int) int {
 	c.maxQuestionLen = maxQuestionLen
 	return c.maxQuestionLen
 }
-
-// func (c *ChatGPT) Chat(question string) (answer string, err error) {
-// 	question = question + "."
-// 	if len(question) > c.maxQuestionLen {
-// 		return "", OverMaxQuestionLength
-// 	}
-// 	if len(question)+c.maxAnswerLen > c.maxText {
-// 		question = question[:c.maxText-c.maxAnswerLen]
-// 	}
-// 	req := gogpt.CompletionRequest{
-// 		Model:            gogpt.GPT3TextDavinci003,
-// 		MaxTokens:        c.maxAnswerLen,
-// 		Prompt:           question,
-// 		Temperature:      0.9,
-// 		TopP:             1,
-// 		N:                1,
-// 		FrequencyPenalty: 0,
-// 		PresencePenalty:  0.5,
-// 		User:             c.userId,
-// 		Stop:             []string{},
-// 	}
-// 	resp, err := c.client.CreateCompletion(c.ctx, req)
-// 	if err != nil {
-// 		return "", err
-// 	}
-// 	return formatAnswer(resp.Choices[0].Text), err
-// }

+ 1 - 2
pkg/chatgpt/chatgpt_test.go

@@ -3,11 +3,10 @@ package chatgpt
 import (
 	"fmt"
 	"testing"
-	"time"
 )
 
 func TestChatGPT_ChatWithContext(t *testing.T) {
-	chat := New("CHATGPT_API_KEY", "", "", 10*time.Minute)
+	chat := New("")
 	defer chat.Close()
 	//go func() {
 	//	select {

+ 5 - 5
pkg/chatgpt/context.go

@@ -7,7 +7,7 @@ import (
 	"strings"
 
 	"github.com/eryajf/chatgpt-dingtalk/public"
-	gogpt "github.com/sashabaranov/go-gpt3"
+	openai "github.com/sashabaranov/go-openai"
 )
 
 var (
@@ -157,10 +157,10 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
 		return "", OverMaxTextLength
 	}
 
-	if public.Config.Model == gogpt.GPT3Dot5Turbo0301 || public.Config.Model == gogpt.GPT3Dot5Turbo {
-		req := gogpt.ChatCompletionRequest{
+	if public.Config.Model == openai.GPT3Dot5Turbo0301 || public.Config.Model == openai.GPT3Dot5Turbo {
+		req := openai.ChatCompletionRequest{
 			Model: public.Config.Model,
-			Messages: []gogpt.ChatCompletionMessage{
+			Messages: []openai.ChatCompletionMessage{
 				{
 					Role:    "user",
 					Content: prompt,
@@ -182,7 +182,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
 		c.ChatContext.seqTimes++
 		return resp.Choices[0].Message.Content, nil
 	} else {
-		req := gogpt.CompletionRequest{
+		req := openai.CompletionRequest{
 			Model:            public.Config.Model,
 			MaxTokens:        c.maxAnswerLen,
 			Prompt:           prompt,

+ 3 - 6
pkg/chatgpt/context_test.go

@@ -3,9 +3,6 @@ package chatgpt
 import (
 	"os"
 	"testing"
-	"time"
-
-	"github.com/joho/godotenv"
 )
 
 func TestOfflineContext(t *testing.T) {
@@ -13,7 +10,7 @@ func TestOfflineContext(t *testing.T) {
 	if key == "" {
 		t.Skip("CHATGPT_API_KEY is not set")
 	}
-	cli := New(key, "", "user1", time.Second*30)
+	cli := New("")
 	reply, err := cli.ChatWithContext("我叫老三,你是?")
 	if err != nil {
 		t.Fatal(err)
@@ -56,7 +53,7 @@ func TestMaintainContext(t *testing.T) {
 	if key == "" {
 		t.Skip("CHATGPT_API_KEY is not set")
 	}
-	cli := New(key, "", "user1", time.Second*30)
+	cli := New("")
 	cli.ChatContext = NewContext(
 		WithMaxSeqTimes(1),
 		WithMaintainSeqTimes(true),
@@ -80,5 +77,5 @@ func TestMaintainContext(t *testing.T) {
 
 func init() {
 	// 本地加载适用于本地测试,如果要在github进行测试,可以透过传入 secrets 到环境参数
-	_ = godotenv.Load(".env.local")
+	// _ = godotenv.Load(".env.local")
 }