Browse Source

feat: 添加指定模型的能力 (#74)

二丫讲梵 2 years ago
parent
commit
acb5462a01
8 changed files with 72 additions and 47 deletions
  1. 1 11
      .pre-commit-config.yaml
  2. 3 1
      README.md
  3. 1 0
      config.dev.json
  4. 6 0
      config/config.go
  5. 2 5
      main.go
  6. 53 24
      pkg/chatgpt/context.go
  7. 3 4
      public/gpt.go
  8. 3 2
      public/public.go

+ 1 - 11
.pre-commit-config.yaml

@@ -3,14 +3,4 @@ repos:
     rev: v4.4.0
     hooks:
     - id: check-yaml
-    - id: check-added-large-files
--   repo: https://github.com/golangci/golangci-lint # golangci-lint hook repo
-    rev: v1.47.3 # golangci-lint hook repo revision
-    hooks:
-    - id: golangci-lint
-      name: golangci-lint
-      description: Fast linters runner for Go.
-      entry: golangci-lint run --fix
-      types: [go]
-      language: golang
-      pass_filenames: false
+    - id: check-added-large-files

+ 3 - 1
README.md

@@ -35,6 +35,7 @@
 - 支持在钉钉群聊中添加机器人,通过@机器人进行聊天交互。
 - 提问支持单聊与串聊两种模式,通过@机器人发关键字切换。
 - 支持添加代理,通过配置化指定。
+- 支持自定义指定的模型,通过配置化指定。
 - 支持自定义默认的聊天模式,通过配置化指定。
 
 ## 使用前提
@@ -78,7 +79,7 @@
 
 ```sh
 # 运行项目
-$ docker run -itd --name chatgpt -p 8090:8090 -e APIKEY=换成你的key -e SESSION_TIMEOUT=600 -e HTTP_PROXY="" -e DEFAULT_MODE="单聊" --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest
+$ docker run -itd --name chatgpt -p 8090:8090 -e APIKEY=换成你的key -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="" -e DEFAULT_MODE="单聊" --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest
 ```
 
 `📢 注意:`如果你使用docker部署,那么proxy指定地址的时候,请指定宿主机的IP,而不要写成127,以免代理不生效。
@@ -221,6 +222,7 @@ $ go run main.go
 ```json
 {
     "api_key": "xxxxxxxxx",   // openai api_key
+    "model": "gpt-3.5-turbo", // 指定模型,默认为 gpt-3.5-turbo ,具体选项参考官网训练场
     "session_timeout": 600,   // 会话超时时间,默认600秒,在会话时间内所有发送给机器人的信息会作为上下文
     "http_proxy": "",         // 指定请求时使用的代理,如果为空,则不使用代理
     "default_mode": "单聊"    // 默认对话模式,可根据实际场景自定义,如果不设置,默认为单聊

+ 1 - 0
config.dev.json

@@ -1,5 +1,6 @@
 {
     "api_key": "xxxxxxxxx",
+    "model": "gpt-3.5-turbo",
     "session_timeout": 600,
     "http_proxy": "",
     "default_mode": "单聊"

+ 6 - 0
config/config.go

@@ -15,6 +15,8 @@ import (
 type Configuration struct {
 	// gtp apikey
 	ApiKey string `json:"api_key"`
+	// 使用模型
+	Model string `json:"model"`
 	// 会话超时时间
 	SessionTimeout time.Duration `json:"session_timeout"`
 	// 默认对话模式
@@ -45,6 +47,7 @@ func LoadConfig() *Configuration {
 		}
 		// 如果环境变量有配置,读取环境变量
 		ApiKey := os.Getenv("APIKEY")
+		model := os.Getenv("MODEL")
 		SessionTimeout := os.Getenv("SESSION_TIMEOUT")
 		defaultMode := os.Getenv("DEFAULT_MODE")
 		httpProxy := os.Getenv("HTTP_PROXY")
@@ -67,6 +70,9 @@ func LoadConfig() *Configuration {
 		if httpProxy != "" {
 			config.HttpProxy = httpProxy
 		}
+		if model != "" {
+			config.Model = model
+		}
 	})
 	if config.DefaultMode == "" {
 		config.DefaultMode = "单聊"

+ 2 - 5
main.go

@@ -8,7 +8,6 @@ import (
 	"strings"
 	"time"
 
-	"github.com/eryajf/chatgpt-dingtalk/config"
 	"github.com/eryajf/chatgpt-dingtalk/public"
 	"github.com/eryajf/chatgpt-dingtalk/public/logger"
 	"github.com/solywsh/chatgpt"
@@ -213,15 +212,13 @@ func Do(mode string, rmsg public.ReceiveMsg) error {
 }
 
 func SingleQa(question, userId string) (answer string, err error) {
-	cfg := config.LoadConfig()
-	chat := chatgpt.New(cfg.ApiKey, cfg.HttpProxy, userId, cfg.SessionTimeout)
+	chat := chatgpt.New(public.Config.ApiKey, public.Config.HttpProxy, userId, public.Config.SessionTimeout)
 	defer chat.Close()
 	return chat.ChatWithContext(question)
 }
 
 func ContextQa(question, userId string) (chat *chatgpt.ChatGPT, answer string, err error) {
-	cfg := config.LoadConfig()
-	chat = chatgpt.New(cfg.ApiKey, cfg.HttpProxy, userId, cfg.SessionTimeout)
+	chat = chatgpt.New(public.Config.ApiKey, public.Config.HttpProxy, userId, public.Config.SessionTimeout)
 	if public.UserService.GetUserSessionContext(userId) != "" {
 		err = chat.ChatContext.LoadConversation(userId)
 		if err != nil {

+ 53 - 24
pkg/chatgpt/context.go

@@ -156,31 +156,60 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
 	if len(prompt) > c.maxText-c.maxAnswerLen {
 		return "", OverMaxTextLength
 	}
-	c1 := gogpt.ChatCompletionRequest{
-		Model: gogpt.GPT3Dot5Turbo,
-		Messages: []gogpt.ChatCompletionMessage{
-			{
-				Role:    "user",
-				Content: prompt,
-			},
-		}}
-	req := c1
-
-	resp, err := c.client.CreateChatCompletion(c.ctx, req)
-	if err != nil {
-		return "", err
+
+	if public.Config.Model == gogpt.GPT3Dot5Turbo0301 || public.Config.Model == gogpt.GPT3Dot5Turbo {
+		req := gogpt.ChatCompletionRequest{
+			Model: public.Config.Model,
+			Messages: []gogpt.ChatCompletionMessage{
+				{
+					Role:    "user",
+					Content: prompt,
+				},
+			}}
+		resp, err := c.client.CreateChatCompletion(c.ctx, req)
+		if err != nil {
+			return "", err
+		}
+		resp.Choices[0].Message.Content = formatAnswer(resp.Choices[0].Message.Content)
+		c.ChatContext.old = append(c.ChatContext.old, conversation{
+			Role:   c.ChatContext.humanRole,
+			Prompt: question,
+		})
+		c.ChatContext.old = append(c.ChatContext.old, conversation{
+			Role:   c.ChatContext.aiRole,
+			Prompt: resp.Choices[0].Message.Content,
+		})
+		c.ChatContext.seqTimes++
+		return resp.Choices[0].Message.Content, nil
+	} else {
+		req := gogpt.CompletionRequest{
+			Model:            public.Config.Model,
+			MaxTokens:        c.maxAnswerLen,
+			Prompt:           prompt,
+			Temperature:      0.9,
+			TopP:             1,
+			N:                1,
+			FrequencyPenalty: 0,
+			PresencePenalty:  0.5,
+			User:             c.userId,
+			Stop:             []string{c.ChatContext.aiRole.Name + ":", c.ChatContext.humanRole.Name + ":"},
+		}
+		resp, err := c.client.CreateCompletion(c.ctx, req)
+		if err != nil {
+			return "", err
+		}
+		resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text)
+		c.ChatContext.old = append(c.ChatContext.old, conversation{
+			Role:   c.ChatContext.humanRole,
+			Prompt: question,
+		})
+		c.ChatContext.old = append(c.ChatContext.old, conversation{
+			Role:   c.ChatContext.aiRole,
+			Prompt: resp.Choices[0].Text,
+		})
+		c.ChatContext.seqTimes++
+		return resp.Choices[0].Text, nil
 	}
-	resp.Choices[0].Message.Content = formatAnswer(resp.Choices[0].Message.Content)
-	c.ChatContext.old = append(c.ChatContext.old, conversation{
-		Role:   c.ChatContext.humanRole,
-		Prompt: question,
-	})
-	c.ChatContext.old = append(c.ChatContext.old, conversation{
-		Role:   c.ChatContext.aiRole,
-		Prompt: resp.Choices[0].Message.Content,
-	})
-	c.ChatContext.seqTimes++
-	return resp.Choices[0].Message.Content, nil
 }
 
 func WithMaxSeqTimes(times int) ChatContextOption {

+ 3 - 4
public/gpt.go

@@ -5,15 +5,14 @@ import (
 	"fmt"
 	"time"
 
-	"github.com/eryajf/chatgpt-dingtalk/config"
 	"github.com/go-resty/resty/v2"
 )
 
 func InitAiCli() *resty.Client {
-	if config.LoadConfig().HttpProxy != "" {
-		return resty.New().SetTimeout(30*time.Second).SetHeader("Authorization", fmt.Sprintf("Bearer %s", config.LoadConfig().ApiKey)).SetProxy(config.LoadConfig().HttpProxy).SetRetryCount(3).SetRetryWaitTime(5 * time.Second)
+	if Config.HttpProxy != "" {
+		return resty.New().SetTimeout(30*time.Second).SetHeader("Authorization", fmt.Sprintf("Bearer %s", Config.ApiKey)).SetProxy(Config.HttpProxy).SetRetryCount(3).SetRetryWaitTime(5 * time.Second)
 	}
-	return resty.New().SetTimeout(30*time.Second).SetHeader("Authorization", fmt.Sprintf("Bearer %s", config.LoadConfig().ApiKey)).SetRetryCount(3).SetRetryWaitTime(5 * time.Second)
+	return resty.New().SetTimeout(30*time.Second).SetHeader("Authorization", fmt.Sprintf("Bearer %s", Config.ApiKey)).SetRetryCount(3).SetRetryWaitTime(5 * time.Second)
 }
 
 type Billing struct {

+ 3 - 2
public/public.go

@@ -8,9 +8,10 @@ import (
 )
 
 var UserService service.UserServiceInterface
+var Config *config.Configuration
 
 func InitSvc() {
-	config.LoadConfig()
+	Config = config.LoadConfig()
 	UserService = service.NewUserService()
 	_, _ = GetBalance()
 }
@@ -18,7 +19,7 @@ func InitSvc() {
 func FirstCheck(rmsg ReceiveMsg) bool {
 	lc := UserService.GetUserMode(rmsg.SenderStaffId)
 	if lc == "" {
-		if config.LoadConfig().DefaultMode == "串聊" {
+		if Config.DefaultMode == "串聊" {
 			return true
 		} else {
 			return false