Prechádzať zdrojové kódy

fix: 修复上下文以及清空会话等bug,添加更多参数

eryajf 2 rokov pred
rodič
commit
e1351271d9
8 zmenil súbory, kde vykonal 232 pridanie a 63 odobranie
  1. 42 2
      README.md
  2. 5 1
      config.dev.json
  3. 52 7
      config/config.go
  4. 14 10
      gtp/gtp.go
  5. 76 40
      main.go
  6. 0 0
      public/dingtalk.go
  7. 40 0
      public/logger/logger.go
  8. 3 3
      service/user.go

+ 42 - 2
README.md

@@ -65,7 +65,7 @@
 
 ```sh
 # 运行项目
-$ docker run -itd --name chatgpt -p 8090:8090 -e ApiKey=xxxx -e SessionTimeout=60s --restart=always docker.mirrors.sjtug.sjtu.edu.cn/eryajf/chatgpt-dingtalk:latest
+$ docker run -itd --name chatgpt -p 8090:8090 -e APIKEY=换成你的key -e SESSIONTIMEOUT=60s -e MODEL=text-davinci-003 -e MAX_TOKENS=512 -e TEMPREATURE=0.9 -e SESSION_CLEAR_TOKEN=清空会话 --restart=always docker.mirrors.sjtug.sjtu.edu.cn/eryajf/chatgpt-dingtalk:latest
 ```
 
 运行命令中映射的配置文件参考下边的配置文件说明。
@@ -104,6 +104,42 @@ server {
 
 部署完成之后,就可以在群里艾特机器人进行体验了。
 
+Nginx配置完毕之后,可以先手动请求一下,通过服务日志输出判断服务是否正常可用:
+
+```sh
+curl --location --request POST 'http://chat.eryajf.net/' \
+  --header 'Content-type: application/json' \
+  --data-raw '{
+    "conversationId": "xxx",
+    "atUsers": [
+        {
+            "dingtalkId": "xxx",
+            "staffId":"xxx"
+        }
+    ],
+    "chatbotCorpId": "dinge8a565xxxx",
+    "chatbotUserId": "$:LWCP_v1:$Cxxxxx",
+    "msgId": "msg0xxxxx",
+    "senderNick": "eryajf",
+    "isAdmin": true,
+    "senderStaffId": "user123",
+    "sessionWebhookExpiredTime": 1613635652738,
+    "createAt": 1613630252678,
+    "senderCorpId": "dinge8a565xxxx",
+    "conversationType": "2",
+    "senderId": "$:LWCP_v1:$Ff09GIxxxxx",
+    "conversationTitle": "机器人测试-TEST",
+    "isInAtList": true,
+    "sessionWebhook": "https://oapi.dingtalk.com/robot/sendBySession?session=xxxxx",
+    "text": {
+        "content": " 你好"
+    },
+    "msgtype": "text"
+}'
+```
+
+如果手动请求没有问题,那么就可以在钉钉群里与机器人进行对话了。
+
 效果如下:
 
 ![image_20221209_163739](https://cdn.staticaly.com/gh/eryajf/tu/main/img/image_20221209_163739.png)
@@ -146,6 +182,10 @@ $ go run main.go
 ````json
 {
     "api_key": "xxxxxxxxx",  // openai api_key
-    "session_timeout": 60    // 会话超时时间,默认60秒,在会话时间内所有发送给机器人的信息会作为上下文
+    "session_timeout": 60,   // 会话超时时间,默认60秒,在会话时间内所有发送给机器人的信息会作为上下文
+    "max_tokens": 1024,      // GPT响应字符数,最大2048,默认值512。值大小会影响接口响应速度,越大响应越慢。
+    "model": "text-davinci-003", // GPT选用模型,默认text-davinci-003,具体选项参考官网训练场
+    "temperature": 0.9, // GPT热度,0到1,默认0.9。数字越大创造力越强,但更偏离训练事实,越低越接近训练事实
+    "session_clear_token": "清空会话" // 会话清空口令,默认`清空会话`
 }
 ````

+ 5 - 1
config.dev.json

@@ -1,4 +1,8 @@
 {
     "api_key": "xxxxxxxxx",
-    "session_timeout": 60
+    "session_timeout": 60,
+    "max_tokens": 1024,
+    "model": "text-davinci-003",
+    "temperature": 0.9,
+    "session_clear_token": "清空会话"
 }

+ 52 - 7
config/config.go

@@ -2,10 +2,13 @@ package config
 
 import (
 	"encoding/json"
-	"log"
+	"fmt"
 	"os"
+	"strconv"
 	"sync"
 	"time"
+
+	"github.com/eryajf/chatgpt-dingtalk/public/logger"
 )
 
 // Configuration 项目配置
@@ -14,6 +17,14 @@ type Configuration struct {
 	ApiKey string `json:"api_key"`
 	// 会话超时时间
 	SessionTimeout time.Duration `json:"session_timeout"`
+	// GPT请求最大字符数
+	MaxTokens uint `json:"max_tokens"`
+	// GPT模型
+	Model string `json:"model"`
+	// 热度
+	Temperature float64 `json:"temperature"`
+	// 自定义清空会话口令
+	SessionClearToken string `json:"session_clear_token"`
 }
 
 var config *Configuration
@@ -24,35 +35,69 @@ func LoadConfig() *Configuration {
 	once.Do(func() {
 		// 从文件中读取
 		config = &Configuration{
-			SessionTimeout: 1,
+			SessionTimeout:    60,
+			MaxTokens:         512,
+			Model:             "text-davinci-003",
+			Temperature:       0.9,
+			SessionClearToken: "下一个问题",
 		}
 		f, err := os.Open("config.json")
 		if err != nil {
-			log.Fatalf("open config err: %v", err)
+			logger.Danger("open config err: %v", err)
 			return
 		}
 		defer f.Close()
 		encoder := json.NewDecoder(f)
 		err = encoder.Decode(config)
 		if err != nil {
-			log.Fatalf("decode config err: %v", err)
+			logger.Warning("decode config err: %v", err)
 			return
 		}
 
 		// 如果环境变量有配置,读取环境变量
-		ApiKey := os.Getenv("ApiKey")
-		SessionTimeout := os.Getenv("SessionTimeout")
+		// 有环境变量使用环境变量
+		ApiKey := os.Getenv("APIKEY")
+		SessionTimeout := os.Getenv("SESSION_TIMEOUT")
+		Model := os.Getenv("MODEL")
+		MaxTokens := os.Getenv("MAX_TOKENS")
+		Temperature := os.Getenv("TEMPREATURE")
+		SessionClearToken := os.Getenv("SESSION_CLEAR_TOKEN")
 		if ApiKey != "" {
 			config.ApiKey = ApiKey
 		}
 		if SessionTimeout != "" {
 			duration, err := time.ParseDuration(SessionTimeout)
 			if err != nil {
-				log.Fatalf("config decode session timeout err: %v ,get is %v", err, SessionTimeout)
+				logger.Danger(fmt.Sprintf("config session timeout err: %v ,get is %v", err, SessionTimeout))
 				return
 			}
 			config.SessionTimeout = duration
 		}
+		if Model != "" {
+			config.Model = Model
+		}
+		if MaxTokens != "" {
+			max, err := strconv.Atoi(MaxTokens)
+			if err != nil {
+				logger.Danger(fmt.Sprintf("config MaxTokens err: %v ,get is %v", err, MaxTokens))
+				return
+			}
+			config.MaxTokens = uint(max)
+		}
+		if Temperature != "" {
+			temp, err := strconv.ParseFloat(Temperature, 64)
+			if err != nil {
+				logger.Danger(fmt.Sprintf("config Temperature err: %v ,get is %v", err, Temperature))
+				return
+			}
+			config.Temperature = temp
+		}
+		if SessionClearToken != "" {
+			config.SessionClearToken = SessionClearToken
+		}
 	})
+	if config.ApiKey == "" {
+		logger.Danger("config err: api key required")
+	}
 	return config
 }

+ 14 - 10
gtp/gtp.go

@@ -8,8 +8,10 @@ import (
 	"io/ioutil"
 	"log"
 	"net/http"
+	"time"
 
 	"github.com/eryajf/chatgpt-dingtalk/config"
+	"github.com/eryajf/chatgpt-dingtalk/public/logger"
 )
 
 const BASEURL = "https://api.openai.com/v1/"
@@ -35,8 +37,8 @@ type ChoiceItem struct {
 type ChatGPTRequestBody struct {
 	Model            string  `json:"model"`
 	Prompt           string  `json:"prompt"`
-	MaxTokens        int     `json:"max_tokens"`
-	Temperature      float32 `json:"temperature"`
+	MaxTokens        uint    `json:"max_tokens"`
+	Temperature      float64 `json:"temperature"`
 	TopP             int     `json:"top_p"`
 	FrequencyPenalty int     `json:"frequency_penalty"`
 	PresencePenalty  int     `json:"presence_penalty"`
@@ -48,21 +50,21 @@ type ChatGPTRequestBody struct {
 //-H "Authorization: Bearer your chatGPT key"
 //-d '{"model": "text-davinci-003", "prompt": "give me good song", "temperature": 0, "max_tokens": 7}'
 func Completions(msg string) (string, error) {
+	cfg := config.LoadConfig()
 	requestBody := ChatGPTRequestBody{
-		Model:            "text-davinci-003",
+		Model:            cfg.Model,
 		Prompt:           msg,
-		MaxTokens:        1024,
-		Temperature:      0.7,
+		MaxTokens:        cfg.MaxTokens,
+		Temperature:      cfg.Temperature,
 		TopP:             1,
 		FrequencyPenalty: 0,
 		PresencePenalty:  0,
 	}
 	requestData, err := json.Marshal(requestBody)
-
 	if err != nil {
 		return "", err
 	}
-	log.Printf("request gtp json string : %v", string(requestData))
+	logger.Info(fmt.Sprintf("request gtp json string : %v", string(requestData)))
 	req, err := http.NewRequest("POST", BASEURL+"completions", bytes.NewBuffer(requestData))
 	if err != nil {
 		return "", err
@@ -71,19 +73,21 @@ func Completions(msg string) (string, error) {
 	apiKey := config.LoadConfig().ApiKey
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("Authorization", "Bearer "+apiKey)
-	client := &http.Client{}
+	client := &http.Client{Timeout: 30 * time.Second}
 	response, err := client.Do(req)
 	if err != nil {
 		return "", err
 	}
 	defer response.Body.Close()
 	if response.StatusCode != 200 {
-		return "", errors.New(fmt.Sprintf("gtp api status code not equals 200,code is %d", response.StatusCode))
+		body, _ := ioutil.ReadAll(response.Body)
+		return "", errors.New(fmt.Sprintf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details:  %v ", response.StatusCode, string(body)))
 	}
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
 		return "", err
 	}
+	logger.Info(fmt.Sprintf("response gtp json string : %v", string(body)))
 
 	gptResponseBody := &ChatGPTResponseBody{}
 	log.Println(string(body))
@@ -96,6 +100,6 @@ func Completions(msg string) (string, error) {
 	if len(gptResponseBody.Choices) > 0 {
 		reply = gptResponseBody.Choices[0].Text
 	}
-	log.Printf("gpt response text: %s \n", reply)
+	logger.Info(fmt.Sprintf("gpt response text: %s ", reply))
 	return reply, nil
 }

+ 76 - 40
main.go

@@ -2,13 +2,14 @@ package main
 
 import (
 	"encoding/json"
+	"fmt"
 	"io/ioutil"
-	"log"
 	"net/http"
 	"strings"
 
 	"github.com/eryajf/chatgpt-dingtalk/gtp"
 	"github.com/eryajf/chatgpt-dingtalk/public"
+	"github.com/eryajf/chatgpt-dingtalk/public/logger"
 	"github.com/eryajf/chatgpt-dingtalk/service"
 )
 
@@ -19,26 +20,35 @@ func init() {
 }
 
 func main() {
+	Start()
+}
+
+func Start() {
 	// 定义一个处理器函数
 	handler := func(w http.ResponseWriter, r *http.Request) {
 		data, err := ioutil.ReadAll(r.Body)
 		if err != nil {
 			http.Error(w, err.Error(), http.StatusBadRequest)
+			logger.Warning("read request body failed: %v\n", err.Error())
 			return
 		}
 		// TODO: 校验请求
 		// fmt.Println(r.Header)
-
-		var msgObj = new(public.ReceiveMsg)
-		err = json.Unmarshal(data, &msgObj)
-		if err != nil {
-			log.Printf("unmarshal request body failed: %v\n", err)
-		}
-		err = ProcessRequest(*msgObj)
-		if err != nil {
-			log.Printf("process request failed: %v\n", err)
+		if len(data) == 0 {
+			logger.Warning("回调参数为空,以至于无法正常解析,请检查原因")
+			return
+		} else {
+			var msgObj = new(public.ReceiveMsg)
+			err = json.Unmarshal(data, &msgObj)
+			if err != nil {
+				logger.Warning("unmarshal request body failed: %v\n", err)
+			}
+			logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
+			err = ProcessRequest(*msgObj)
+			if err != nil {
+				logger.Warning("process request failed: %v\n", err)
+			}
 		}
-
 	}
 
 	// 创建一个新的 HTTP 服务器
@@ -48,47 +58,73 @@ func main() {
 	}
 
 	// 启动服务器
-	log.Print("Start Listen On ", server.Addr)
+	logger.Info("Start Listen On ", server.Addr)
 	err := server.ListenAndServe()
 	if err != nil {
-		log.Fatal(err)
+		logger.Danger(err)
 	}
 }
 
 func ProcessRequest(rmsg public.ReceiveMsg) error {
-	// 获取问题的答案
-	reply, err := gtp.Completions(rmsg.Text.Content)
-	if err != nil {
-		log.Printf("gtp request error: %v \n", err)
-		_, err = rmsg.ReplyText("机器人太累了,让她休息会儿,过一会儿再来请求。")
+	atText := "@" + rmsg.SenderNick + "\n" + " "
+	if UserService.ClearUserSessionContext(rmsg.SenderID, rmsg.Text.Content) {
+		_, err := rmsg.ReplyText(atText + "上下文已经清空了,你可以问下一个问题啦。")
 		if err != nil {
-			log.Printf("send message error: %v \n", err)
+			logger.Warning("response user error: %v \n", err)
 			return err
 		}
-		log.Printf("request openai error: %v\n", err)
-		return err
-	}
-	if reply == "" {
-		return nil
-	}
-	// 回复@我的用户
-	reply = strings.TrimSpace(reply)
-	reply = strings.Trim(reply, "\n")
-	atText := "@" + rmsg.SenderNick + "\n" + " "
-	// 设置上下文
-	if UserService.ClearUserSessionContext(rmsg.SenderID, rmsg.Text.Content) {
-		_, err = rmsg.ReplyText(atText + "上下文已经清空了,你可以问下一个问题啦。")
+	} else {
+		requestText := getRequestText(rmsg)
+		// 获取问题的答案
+		reply, err := gtp.Completions(requestText)
 		if err != nil {
-			log.Printf("response user error: %v \n", err)
+			logger.Info("gtp request error: %v \n", err)
+			_, err = rmsg.ReplyText("机器人太累了,让她休息会儿,过一会儿再来请求。")
+			if err != nil {
+				logger.Warning("send message error: %v \n", err)
+				return err
+			}
+			logger.Info("request openai error: %v\n", err)
+			return err
+		}
+		if reply == "" {
+			logger.Warning("get gpt result falied: %v\n", err)
+			return nil
+		}
+		// 回复@我的用户
+		reply = strings.TrimSpace(reply)
+		reply = strings.Trim(reply, "\n")
+
+		UserService.SetUserSessionContext(rmsg.SenderID, requestText, reply)
+		replyText := atText + reply
+		_, err = rmsg.ReplyText(replyText)
+		if err != nil {
+			logger.Info("send message error: %v \n", err)
 			return err
 		}
-	}
-	UserService.SetUserSessionContext(rmsg.SenderID, rmsg.Text.Content, reply)
-	replyText := atText + reply
-	_, err = rmsg.ReplyText(replyText)
-	if err != nil {
-		log.Printf("send message error: %v \n", err)
-		return err
 	}
 	return nil
 }
+
+// getRequestText 获取请求接口的文本,要做一些清洗
+func getRequestText(rmsg public.ReceiveMsg) string {
+	// 1.去除空格以及换行
+	requestText := strings.TrimSpace(rmsg.Text.Content)
+	requestText = strings.Trim(rmsg.Text.Content, "\n")
+
+	// 2.替换掉当前用户名称
+	replaceText := "@" + rmsg.SenderNick
+	requestText = strings.TrimSpace(strings.ReplaceAll(rmsg.Text.Content, replaceText, ""))
+	if requestText == "" {
+		return ""
+	}
+
+	// 3.获取上下文,拼接在一起,如果字符长度超出4000,截取为4000。(GPT按字符长度算)
+	requestText = UserService.GetUserSessionContext(rmsg.SenderID) + requestText
+	if len(requestText) >= 4000 {
+		requestText = requestText[:4000]
+	}
+
+	// 4.返回请求文本
+	return requestText
+}

public/base.go → public/dingtalk.go


+ 40 - 0
public/logger/logger.go

@@ -0,0 +1,40 @@
+package logger
+
+import (
+	"log"
+	"os"
+	"sync"
+)
+
+var Logger *log.Logger
+var once sync.Once
+
+func init() {
+	once.Do(func() {
+		Logger = log.New(os.Stdout, "INFO", log.Ldate|log.Ltime|log.Lshortfile)
+	})
+}
+
+// Info 详情
+func Info(args ...interface{}) {
+	Logger.SetPrefix("[INFO]")
+	Logger.Println(args...)
+}
+
+// Danger 错误 为什么不命名为 error?避免和 error 类型重名
+func Danger(args ...interface{}) {
+	Logger.SetPrefix("[ERROR]")
+	Logger.Fatal(args...)
+}
+
+// Warning 警告
+func Warning(args ...interface{}) {
+	Logger.SetPrefix("[WARNING]")
+	Logger.Println(args...)
+}
+
+// DeBug debug
+func DeBug(args ...interface{}) {
+	Logger.SetPrefix("[DeBug]")
+	Logger.Println(args...)
+}

+ 3 - 3
service/user.go

@@ -3,7 +3,6 @@ package service
 import (
 	"strings"
 	"time"
-	"unicode/utf8"
 
 	"github.com/eryajf/chatgpt-dingtalk/config"
 	"github.com/patrickmn/go-cache"
@@ -24,9 +23,10 @@ type UserService struct {
 	cache *cache.Cache
 }
 
-// ClearUserSessionContext 清空GTP上下文,接收文本中包含`我要问下一个问题`,并且Unicode 字符数量不超过20就清空
+// ClearUserSessionContext 清空GTP上下文,接收文本中包含 SessionClearToken
 func (s *UserService) ClearUserSessionContext(userId string, msg string) bool {
-	if strings.Contains(msg, "我要问下一个问题") && utf8.RuneCountInString(msg) < 20 {
+	// 清空会话
+	if strings.Contains(msg, config.LoadConfig().SessionClearToken) {
 		s.cache.Delete(userId)
 		return true
 	}