Browse Source

feat: 重构交互逻辑以及与openai的交互

eryajf 2 years ago
parent
commit
1c504bab48
20 changed files with 752 additions and 224 deletions
  1. 2 0
      .gitignore
  2. 36 17
      README.md
  3. 1 5
      config.dev.json
  4. 34 34
      config/config.go
  5. 10 2
      go.mod
  6. 20 0
      go.sum
  7. 0 83
      gpt/gpt.go
  8. 92 63
      main.go
  9. 21 0
      pkg/chatgpt/LICENSE
  10. 1 0
      pkg/chatgpt/README.md
  11. 93 0
      pkg/chatgpt/chatgpt.go
  12. 53 0
      pkg/chatgpt/chatgpt_test.go
  13. 210 0
      pkg/chatgpt/context.go
  14. 85 0
      pkg/chatgpt/context_test.go
  15. 12 0
      pkg/chatgpt/errors.go
  16. 12 0
      pkg/chatgpt/format.go
  17. 4 0
      pkg/chatgpt/go.mod
  18. 26 0
      pkg/chatgpt/tools.go
  19. 27 0
      public/gpt.go
  20. 13 20
      service/user.go

+ 2 - 0
.gitignore

@@ -18,3 +18,5 @@ chatgpt-dingtalk
 # Dependency directories (remove the comment below to include it)
 # vendor/
 config.json
+tmp
+test/

File diff suppressed because it is too large
+ 36 - 17
README.md


+ 1 - 5
config.dev.json

@@ -1,8 +1,4 @@
 {
     "api_key": "xxxxxxxxx",
-    "session_timeout": 180,
-    "max_tokens": 2000,
-    "model": "text-davinci-003",
-    "temperature": 0.9,
-    "session_clear_token": "清空会话"
+    "session_timeout": 600
 }

+ 34 - 34
config/config.go

@@ -17,14 +17,14 @@ type Configuration struct {
 	ApiKey string `json:"api_key"`
 	// 会话超时时间
 	SessionTimeout time.Duration `json:"session_timeout"`
-	// GPT请求最大字符数
-	MaxTokens uint `json:"max_tokens"`
-	// GPT模型
-	Model string `json:"model"`
-	// 热度
-	Temperature float64 `json:"temperature"`
-	// 自定义清空会话口令
-	SessionClearToken string `json:"session_clear_token"`
+	// // GPT请求最大字符数
+	// MaxTokens uint `json:"max_tokens"`
+	// // GPT模型
+	// Model string `json:"model"`
+	// // 热度
+	// Temperature float64 `json:"temperature"`
+	// // 自定义清空会话口令
+	// SessionClearToken string `json:"session_clear_token"`
 }
 
 var config *Configuration
@@ -51,10 +51,10 @@ func LoadConfig() *Configuration {
 		// 如果环境变量有配置,读取环境变量
 		ApiKey := os.Getenv("APIKEY")
 		SessionTimeout := os.Getenv("SESSION_TIMEOUT")
-		Model := os.Getenv("MODEL")
-		MaxTokens := os.Getenv("MAX_TOKENS")
-		Temperature := os.Getenv("TEMPREATURE")
-		SessionClearToken := os.Getenv("SESSION_CLEAR_TOKEN")
+		// Model := os.Getenv("MODEL")
+		// MaxTokens := os.Getenv("MAX_TOKENS")
+		// Temperature := os.Getenv("TEMPREATURE")
+		// SessionClearToken := os.Getenv("SESSION_CLEAR_TOKEN")
 		if ApiKey != "" {
 			config.ApiKey = ApiKey
 		}
@@ -68,28 +68,28 @@ func LoadConfig() *Configuration {
 		} else {
 			config.SessionTimeout = time.Duration(config.SessionTimeout) * time.Second
 		}
-		if Model != "" {
-			config.Model = Model
-		}
-		if MaxTokens != "" {
-			max, err := strconv.Atoi(MaxTokens)
-			if err != nil {
-				logger.Danger(fmt.Sprintf("config MaxTokens err: %v ,get is %v", err, MaxTokens))
-				return
-			}
-			config.MaxTokens = uint(max)
-		}
-		if Temperature != "" {
-			temp, err := strconv.ParseFloat(Temperature, 64)
-			if err != nil {
-				logger.Danger(fmt.Sprintf("config Temperature err: %v ,get is %v", err, Temperature))
-				return
-			}
-			config.Temperature = temp
-		}
-		if SessionClearToken != "" {
-			config.SessionClearToken = SessionClearToken
-		}
+		// if Model != "" {
+		// 	config.Model = Model
+		// }
+		// if MaxTokens != "" {
+		// 	max, err := strconv.Atoi(MaxTokens)
+		// 	if err != nil {
+		// 		logger.Danger(fmt.Sprintf("config MaxTokens err: %v ,get is %v", err, MaxTokens))
+		// 		return
+		// 	}
+		// 	config.MaxTokens = uint(max)
+		// }
+		// if Temperature != "" {
+		// 	temp, err := strconv.ParseFloat(Temperature, 64)
+		// 	if err != nil {
+		// 		logger.Danger(fmt.Sprintf("config Temperature err: %v ,get is %v", err, Temperature))
+		// 		return
+		// 	}
+		// 	config.Temperature = temp
+		// }
+		// if SessionClearToken != "" {
+		// 	config.SessionClearToken = SessionClearToken
+		// }
 	})
 	if config.ApiKey == "" {
 		logger.Danger("config err: api key required")

+ 10 - 2
go.mod

@@ -1,10 +1,18 @@
 module github.com/eryajf/chatgpt-dingtalk
 
-go 1.17
+go 1.18
 
 require (
 	github.com/go-resty/resty/v2 v2.7.0
 	github.com/patrickmn/go-cache v2.1.0+incompatible
+	github.com/solywsh/chatgpt v0.0.14
 )
 
-require golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect
+require (
+	github.com/joho/godotenv v1.5.1 // indirect
+	github.com/sashabaranov/go-gpt3 v1.0.1 // indirect
+	github.com/stretchr/testify v1.8.1 // indirect
+	golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect
+)
+
+replace github.com/solywsh/chatgpt => ./pkg/chatgpt

+ 20 - 0
go.sum

@@ -1,7 +1,23 @@
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY=
 github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I=
+github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
+github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
 github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
 github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/sashabaranov/go-gpt3 v1.0.1 h1:KHwY4uroFlX1qI1Hui7d31ZI6uzbNGL9zAkh1FkfhuM=
+github.com/sashabaranov/go-gpt3 v1.0.1/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
+github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 golang.org/x/net v0.0.0-20211029224645-99673261e6eb h1:pirldcYWx7rx7kE5r+9WsOXPXK0+WH5+uZ7uPmJ44uM=
 golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -9,3 +25,7 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 0 - 83
gpt/gpt.go

@@ -1,83 +0,0 @@
-package gpt
-
-import (
-	"crypto/tls"
-	"encoding/json"
-	"fmt"
-	"time"
-
-	"github.com/eryajf/chatgpt-dingtalk/config"
-	"github.com/eryajf/chatgpt-dingtalk/public/logger"
-	"github.com/go-resty/resty/v2"
-)
-
-const BASEURL = "https://api.openai.com/v1/"
-
-// ChatGPTRequestBody 请求体
-type ChatGPTRequestBody struct {
-	Model       string  `json:"model"`
-	Prompt      string  `json:"prompt"`
-	MaxTokens   uint    `json:"max_tokens"`
-	Temperature float64 `json:"temperature"`
-}
-
-// ChatGPTResponseBody 响应体
-type ChatGPTResponseBody struct {
-	ID      string                 `json:"id"`
-	Object  string                 `json:"object"`
-	Created int                    `json:"created"`
-	Model   string                 `json:"model"`
-	Choices []ChoiceItem           `json:"choices"`
-	Usage   map[string]interface{} `json:"usage"`
-}
-
-type ChoiceItem struct {
-	Text         string `json:"text"`
-	Index        int    `json:"index"`
-	Logprobs     int    `json:"logprobs"`
-	FinishReason string `json:"finish_reason"`
-}
-
-// Completions gtp文本模型回复
-//curl https://api.openai.com/v1/completions
-//-H "Content-Type: application/json"
-//-H "Authorization: Bearer your chatGPT key"
-//-d '{"model": "text-davinci-003", "prompt": "give me good song", "temperature": 0, "max_tokens": 7}'
-func Completions(msg string) (string, error) {
-	cfg := config.LoadConfig()
-	requestBody := ChatGPTRequestBody{
-		Model:       cfg.Model,
-		Prompt:      msg,
-		MaxTokens:   cfg.MaxTokens,
-		Temperature: cfg.Temperature,
-	}
-
-	client := resty.New().
-		SetRetryCount(2).
-		SetRetryWaitTime(1*time.Second).
-		SetTimeout(cfg.SessionTimeout).
-		SetHeader("Content-Type", "application/json").
-		SetHeader("Authorization", "Bearer "+cfg.ApiKey).
-		SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true})
-
-	rsp, err := client.R().SetBody(requestBody).Post(BASEURL + "completions")
-	if err != nil {
-		return "", fmt.Errorf("request openai failed, err : %v", err)
-	}
-	if rsp.StatusCode() != 200 {
-		return "", fmt.Errorf("gtp api status code not equals 200, code is %d ,details:  %v ", rsp.StatusCode(), string(rsp.Body()))
-	} else {
-		logger.Info(fmt.Sprintf("response gtp json string : %v", string(rsp.Body())))
-	}
-
-	gptResponseBody := &ChatGPTResponseBody{}
-	err = json.Unmarshal(rsp.Body(), gptResponseBody)
-	if err != nil {
-		return "", err
-	}
-	var reply string
-	if len(gptResponseBody.Choices) > 0 {
-		reply = gptResponseBody.Choices[0].Text
-	}
-	return reply, nil
-}

+ 92 - 63
main.go

@@ -5,9 +5,9 @@ import (
 	"fmt"
 	"io/ioutil"
 	"net/http"
+	"os"
 	"strings"
 
-	"github.com/eryajf/chatgpt-dingtalk/gpt"
 	"github.com/eryajf/chatgpt-dingtalk/public"
 	"github.com/eryajf/chatgpt-dingtalk/public/logger"
 	"github.com/eryajf/chatgpt-dingtalk/service"
@@ -23,6 +23,18 @@ func main() {
 	Start()
 }
 
+var Welcome string = `Commands:
+=================================
+🙋 单聊 👉 单独聊天,缺省
+🗣 串聊 👉 带上下文聊天
+🔃 重置 👉 重置带上下文聊天
+🚀 帮助 👉 显示帮助信息
+=================================
+例:@我发送 空 或 帮助 将返回此帮助信息
+`
+
+// 💵 余额 👉 查看接口可调用额度
+
 func Start() {
 	// 定义一个处理器函数
 	handler := func(w http.ResponseWriter, r *http.Request) {
@@ -32,16 +44,16 @@ func Start() {
 			logger.Warning("read request body failed: %v\n", err.Error())
 			return
 		}
+		var msgObj = new(public.ReceiveMsg)
+		err = json.Unmarshal(data, &msgObj)
+		if err != nil {
+			logger.Warning("unmarshal request body failed: %v\n", err)
+		}
 		// TODO: 校验请求
-		if len(data) == 0 {
-			logger.Warning("回调参数为空,以至于无法正常解析,请检查原因")
-			return
+		if len(msgObj.Text.Content) == 1 || msgObj.Text.Content == " 帮助" {
+			// 欢迎信息
+			msgObj.ReplyText(Welcome)
 		} else {
-			var msgObj = new(public.ReceiveMsg)
-			err = json.Unmarshal(data, &msgObj)
-			if err != nil {
-				logger.Warning("unmarshal request body failed: %v\n", err)
-			}
 			logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
 			err = ProcessRequest(*msgObj)
 			if err != nil {
@@ -64,65 +76,82 @@ func Start() {
 	}
 }
 
+func FirstCheck(rmsg public.ReceiveMsg) bool {
+	lc := UserService.GetUserMode(rmsg.SenderNick)
+	if lc != "" && strings.Contains(lc, "串聊") {
+		return true
+	}
+	return false
+}
+
 func ProcessRequest(rmsg public.ReceiveMsg) error {
-	atText := "@" + rmsg.SenderNick + "\n" + " "
-	if UserService.ClearUserSessionContext(rmsg.SenderID, rmsg.Text.Content) {
-		_, err := rmsg.ReplyText(atText + "上下文已经清空了,你可以问下一个问题啦。")
-		if err != nil {
-			logger.Warning("response user error: %v \n", err)
-			return err
+	switch rmsg.Text.Content {
+	case " 单聊":
+		UserService.SetUserMode(rmsg.SenderNick, rmsg.Text.Content)
+		rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈单聊的模式 =====", rmsg.SenderNick))
+	case " 串聊":
+		UserService.SetUserMode(rmsg.SenderNick, rmsg.Text.Content)
+		rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈串聊的模式 =====", rmsg.SenderNick))
+	case " 重置":
+		UserService.ClearUserMode(rmsg.SenderNick)
+		err := os.Remove("openaiCache/" + rmsg.SenderNick)
+		if err != nil && !strings.Contains(fmt.Sprintf("%s", err), "no such file or directory") {
+			rmsg.ReplyText(fmt.Sprintf("=====清理与👉%s👈的对话缓存失败,错误信息: %v\n请检查=====", rmsg.SenderNick, err))
+		} else {
+			rmsg.ReplyText(fmt.Sprintf("=====已重置与👉%s👈的对话模式,可以开始新的对话=====", rmsg.SenderNick))
 		}
-	} else {
-		requestText := getRequestText(rmsg)
-		// 获取问题的答案
-		reply, err := gpt.Completions(requestText)
-		if err != nil {
-			logger.Info("gpt request error: %v \n", err)
-			_, err = rmsg.ReplyText("机器人太累了,让她休息会儿,过一会儿再来请求。")
+	default:
+		if FirstCheck(rmsg) {
+			cli, reply, err := public.ContextQa(rmsg.Text.Content, rmsg.SenderNick)
 			if err != nil {
-				logger.Warning("send message error: %v \n", err)
-				return err
+				logger.Info("gpt request error: %v \n", err)
+				_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err))
+				if err != nil {
+					logger.Warning("send message error: %v \n", err)
+					return err
+				}
+			}
+			if reply == "" {
+				logger.Warning("get gpt result falied: %v\n", err)
+				return nil
+			} else {
+				reply = strings.TrimSpace(reply)
+				reply = strings.Trim(reply, "\n")
+				// 回复@我的用户
+				replyText := "@" + rmsg.SenderNick + "\n" + reply
+				_, err = rmsg.ReplyText(replyText)
+				if err != nil {
+					logger.Warning("send message error: %v \n", err)
+					return err
+				}
+				path := "openaiCache/" + rmsg.SenderNick
+				cli.ChatContext.SaveConversation(path)
+			}
+		} else {
+			reply, err := public.SingleQa(rmsg.Text.Content, rmsg.SenderNick)
+			if err != nil {
+				logger.Info("gpt request error: %v \n", err)
+				_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err))
+				if err != nil {
+					logger.Warning("send message error: %v \n", err)
+					return err
+				}
+			}
+			if reply == "" {
+				logger.Warning("get gpt result falied: %v\n", err)
+				return nil
+			} else {
+				reply = strings.TrimSpace(reply)
+				reply = strings.Trim(reply, "\n")
+				// 回复@我的用户
+				replyText := "@" + rmsg.SenderNick + "\n" + reply
+				_, err = rmsg.ReplyText(replyText)
+				if err != nil {
+					logger.Warning("send message error: %v \n", err)
+					return err
+				}
 			}
-			logger.Info("request openai error: %v\n", err)
-			return err
-		}
-		if reply == "" {
-			logger.Warning("get gpt result falied: %v\n", err)
-			return nil
-		}
-		// 回复@我的用户
-		reply = strings.TrimSpace(reply)
-		reply = strings.Trim(reply, "\n")
-
-		UserService.SetUserSessionContext(rmsg.SenderID, requestText, reply)
-		replyText := atText + reply
-		_, err = rmsg.ReplyText(replyText)
-		if err != nil {
-			logger.Info("send message error: %v \n", err)
-			return err
 		}
 	}
 	return nil
 }
-
-// getRequestText 获取请求接口的文本,要做一些清洗
-func getRequestText(rmsg public.ReceiveMsg) string {
-	// 1.去除空格以及换行
-	requestText := strings.TrimSpace(rmsg.Text.Content)
-	requestText = strings.Trim(rmsg.Text.Content, "\n")
-	// 2.替换掉当前用户名称
-	replaceText := "@" + rmsg.SenderNick
-	requestText = strings.TrimSpace(strings.ReplaceAll(rmsg.Text.Content, replaceText, ""))
-	if requestText == "" {
-		return ""
-	}
-
-	// 3.获取上下文,拼接在一起,如果字符长度超出4000,截取为4000。(GPT按字符长度算)
-	requestText = UserService.GetUserSessionContext(rmsg.SenderID) + requestText
-	if len(requestText) >= 4000 {
-		requestText = requestText[:4000]
-	}
-
-	// 4.返回请求文本
-	return requestText
-}

+ 21 - 0
pkg/chatgpt/LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Shihao
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 1 - 0
pkg/chatgpt/README.md

@@ -0,0 +1 @@
+> 因为三方包写死了很多参数,这里转到本地,便于二次改造。 感谢:https://github.com/solywsh/chatgpt

+ 93 - 0
pkg/chatgpt/chatgpt.go

@@ -0,0 +1,93 @@
+package chatgpt
+
+import (
+	"context"
+	"time"
+
+	gogpt "github.com/sashabaranov/go-gpt3"
+)
+
+type ChatGPT struct {
+	client         *gogpt.Client
+	ctx            context.Context
+	userId         string
+	maxQuestionLen int
+	maxText        int
+	maxAnswerLen   int
+	timeOut        time.Duration // 超时时间, 0表示不超时
+	doneChan       chan struct{}
+	cancel         func()
+
+	ChatContext *ChatContext
+}
+
+func New(ApiKey, UserId string, timeOut time.Duration) *ChatGPT {
+	var ctx context.Context
+	var cancel func()
+	if timeOut == 0 {
+		ctx, cancel = context.WithCancel(context.Background())
+	} else {
+		ctx, cancel = context.WithTimeout(context.Background(), timeOut)
+	}
+	timeOutChan := make(chan struct{}, 1)
+	go func() {
+		<-ctx.Done()
+		timeOutChan <- struct{}{} // 发送超时信号,或是提示结束,用于聊天机器人场景,配合GetTimeOutChan() 使用
+	}()
+	return &ChatGPT{
+		client:         gogpt.NewClient(ApiKey),
+		ctx:            ctx,
+		userId:         UserId,
+		maxQuestionLen: 2048, // 最大问题长度
+		maxAnswerLen:   2048, // 最大答案长度
+		maxText:        4096, // 最大文本 = 问题 + 回答, 接口限制
+		timeOut:        timeOut,
+		doneChan:       timeOutChan,
+		cancel: func() {
+			cancel()
+		},
+		ChatContext: NewContext(),
+	}
+}
+func (c *ChatGPT) Close() {
+	c.cancel()
+}
+
+func (c *ChatGPT) GetDoneChan() chan struct{} {
+	return c.doneChan
+}
+
+func (c *ChatGPT) SetMaxQuestionLen(maxQuestionLen int) int {
+	if maxQuestionLen > c.maxText-c.maxAnswerLen {
+		maxQuestionLen = c.maxText - c.maxAnswerLen
+	}
+	c.maxQuestionLen = maxQuestionLen
+	return c.maxQuestionLen
+}
+
+func (c *ChatGPT) Chat(question string) (answer string, err error) {
+	question = question + "."
+	if len(question) > c.maxQuestionLen {
+		return "", OverMaxQuestionLength
+	}
+	if len(question)+c.maxAnswerLen > c.maxText {
+		question = question[:c.maxText-c.maxAnswerLen]
+	}
+	req := gogpt.CompletionRequest{
+		Model:            gogpt.GPT3TextDavinci003,
+		MaxTokens:        c.maxAnswerLen,
+		Prompt:           question,
+		Temperature:      0.9,
+		TopP:             1,
+		N:                1,
+		FrequencyPenalty: 0,
+		PresencePenalty:  0.5,
+		User:             c.userId,
+		Stop:             []string{},
+	}
+	resp, err := c.client.CreateCompletion(c.ctx, req)
+	if err != nil {
+		return "", err
+	}
+	return formatAnswer(resp.Choices[0].Text), err
+}

+ 53 - 0
pkg/chatgpt/chatgpt_test.go

@@ -0,0 +1,53 @@
+package chatgpt
+
+import (
+	"fmt"
+	"testing"
+	"time"
+)
+
+func TestChatGPT(t *testing.T) {
+	chat := New("CHATGPT_API_KEY", "", 0)
+	defer chat.Close()
+
+	//select {
+	//case <-chat.GetDoneChan():
+	//	fmt.Println("time out")
+	//}
+	question := "你认为2022年世界杯的冠军是谁?\n"
+	fmt.Printf("Q: %s\n", question)
+	answer, err := chat.Chat(question)
+	if err != nil {
+		fmt.Println(err)
+	}
+	fmt.Printf("A: %s\n", answer)
+
+	//Q: 你认为2022年世界杯的冠军是谁?
+	//A: 这个问题很难回答,因为2022年世界杯还没有开始,所以没有人知道冠军是谁。
+
+}
+
+func TestChatGPT_ChatWithContext(t *testing.T) {
+	chat := New("CHATGPT_API_KEY", "", 10*time.Minute)
+	defer chat.Close()
+	//go func() {
+	//	select {
+	//	case <-chat.GetDoneChan():
+	//		fmt.Println("time out")
+	//	}
+	//}()
+	question := "现在你是一只猫,接下来你只能用\"喵喵喵\"回答."
+	fmt.Printf("Q: %s\n", question)
+	answer, err := chat.ChatWithContext(question)
+	if err != nil {
+		fmt.Println(err)
+	}
+	fmt.Printf("A: %s\n", answer)
+	question = "你是一只猫吗?"
+	fmt.Printf("Q: %s\n", question)
+	answer, err = chat.ChatWithContext(question)
+	if err != nil {
+		fmt.Println(err)
+	}
+	fmt.Printf("A: %s\n", answer)
+}

+ 210 - 0
pkg/chatgpt/context.go

@@ -0,0 +1,210 @@
+package chatgpt
+
+import (
+	"bytes"
+	"encoding/gob"
+	"fmt"
+	"os"
+	"strings"
+
+	gogpt "github.com/sashabaranov/go-gpt3"
+)
+
+var (
+	DefaultAiRole    = "AI"
+	DefaultHumanRole = "Human"
+
+	DefaultCharacter  = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}
+	DefaultBackground = "The following is a conversation with AI assistant. The assistant is %s"
+	DefaultPreset     = "\n%s: 你好,让我们开始愉快的谈话!\n%s: 我是 AI assistant ,请问你有什么问题?"
+)
+
+type (
+	ChatContext struct {
+		background  string // 对话背景
+		preset      string // 预设对话
+		maxSeqTimes int    // 最大对话次数
+		aiRole      *role  // AI角色
+		humanRole   *role  // 人类角色
+
+		old        []conversation // 旧对话
+		restartSeq string         // 重新开始对话的标识
+		startSeq   string         // 开始对话的标识
+
+		seqTimes int // 对话次数
+
+		maintainSeqTimes bool // 是否维护对话次数 (自动移除旧对话)
+	}
+
+	ChatContextOption func(*ChatContext)
+
+	conversation struct {
+		Role   *role
+		Prompt string
+	}
+
+	role struct {
+		Name string
+	}
+)
+
+func NewContext(options ...ChatContextOption) *ChatContext {
+	ctx := &ChatContext{
+		aiRole:           &role{Name: DefaultAiRole},
+		humanRole:        &role{Name: DefaultHumanRole},
+		background:       fmt.Sprintf(DefaultBackground, strings.Join(DefaultCharacter, ", ")+"."),
+		maxSeqTimes:      1000,
+		preset:           fmt.Sprintf(DefaultPreset, DefaultHumanRole, DefaultAiRole),
+		old:              []conversation{},
+		seqTimes:         0,
+		restartSeq:       "\n" + DefaultHumanRole + ": ",
+		startSeq:         "\n" + DefaultAiRole + ": ",
+		maintainSeqTimes: false,
+	}
+
+	for _, option := range options {
+		option(ctx)
+	}
+	return ctx
+}
+
+// PollConversation 移除最旧的一则对话
+func (c *ChatContext) PollConversation() {
+	c.old = c.old[1:]
+	c.seqTimes--
+}
+
+// ResetConversation 重置对话
+func (c *ChatContext) ResetConversation() {
+	c.old = []conversation{}
+	c.seqTimes = 0
+}
+
+// SaveConversation 保存对话
+func (c *ChatContext) SaveConversation(path string) error {
+	var buffer bytes.Buffer
+	enc := gob.NewEncoder(&buffer)
+	err := enc.Encode(c.old)
+	if err != nil {
+		return err
+	}
+	return WriteToFile(path, buffer.Bytes())
+}
+
+// LoadConversation 加载对话
+func (c *ChatContext) LoadConversation(path string) error {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return err
+	}
+	buffer := bytes.NewBuffer(data)
+	dec := gob.NewDecoder(buffer)
+	err = dec.Decode(&c.old)
+	if err != nil {
+		return err
+	}
+	c.seqTimes = len(c.old)
+	return nil
+}
+
+func (c *ChatContext) SetHumanRole(role string) {
+	c.humanRole.Name = role
+	c.restartSeq = "\n" + c.humanRole.Name + ": "
+}
+
+func (c *ChatContext) SetAiRole(role string) {
+	c.aiRole.Name = role
+	c.startSeq = "\n" + c.aiRole.Name + ": "
+}
+
+func (c *ChatContext) SetMaxSeqTimes(times int) {
+	c.maxSeqTimes = times
+}
+
+func (c *ChatContext) GetMaxSeqTimes() int {
+	return c.maxSeqTimes
+}
+
+func (c *ChatContext) SetBackground(background string) {
+	c.background = background
+}
+
+func (c *ChatContext) SetPreset(preset string) {
+	c.preset = preset
+}
+
+func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
+	question = question + "."
+	if len(question) > c.maxQuestionLen {
+		return "", OverMaxQuestionLength
+	}
+	if c.ChatContext.seqTimes >= c.ChatContext.maxSeqTimes {
+		if c.ChatContext.maintainSeqTimes {
+			c.ChatContext.PollConversation()
+		} else {
+			return "", OverMaxSequenceTimes
+		}
+	}
+	var promptTable []string
+	promptTable = append(promptTable, c.ChatContext.background)
+	promptTable = append(promptTable, c.ChatContext.preset)
+	for _, v := range c.ChatContext.old {
+		if v.Role == c.ChatContext.humanRole {
+			promptTable = append(promptTable, "\n"+v.Role.Name+": "+v.Prompt)
+		} else {
+			promptTable = append(promptTable, v.Role.Name+": "+v.Prompt)
+		}
+	}
+	promptTable = append(promptTable, "\n"+c.ChatContext.restartSeq+question)
+	prompt := strings.Join(promptTable, "\n")
+	prompt += c.ChatContext.startSeq
+	if len(prompt) > c.maxText-c.maxAnswerLen {
+		return "", OverMaxTextLength
+	}
+	req := gogpt.CompletionRequest{
+		Model:            gogpt.GPT3TextDavinci003,
+		MaxTokens:        c.maxAnswerLen,
+		Prompt:           prompt,
+		Temperature:      0.9,
+		TopP:             1,
+		N:                1,
+		FrequencyPenalty: 0,
+		PresencePenalty:  0.5,
+		User:             c.userId,
+		Stop:             []string{c.ChatContext.aiRole.Name + ":", c.ChatContext.humanRole.Name + ":"},
+	}
+	resp, err := c.client.CreateCompletion(c.ctx, req)
+	if err != nil {
+		return "", err
+	}
+	resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text)
+	c.ChatContext.old = append(c.ChatContext.old, conversation{
+		Role:   c.ChatContext.humanRole,
+		Prompt: question,
+	})
+	c.ChatContext.old = append(c.ChatContext.old, conversation{
+		Role:   c.ChatContext.aiRole,
+		Prompt: resp.Choices[0].Text,
+	})
+	c.ChatContext.seqTimes++
+	return resp.Choices[0].Text, nil
+}
+
+func WithMaxSeqTimes(times int) ChatContextOption {
+	return func(c *ChatContext) {
+		c.SetMaxSeqTimes(times)
+	}
+}
+
+// WithOldConversation 从文件中加载对话
+func WithOldConversation(path string) ChatContextOption {
+	return func(c *ChatContext) {
+		_ = c.LoadConversation(path)
+	}
+}
+
+func WithMaintainSeqTimes(maintain bool) ChatContextOption {
+	return func(c *ChatContext) {
+		c.maintainSeqTimes = maintain
+	}
+}

+ 85 - 0
pkg/chatgpt/context_test.go

@@ -0,0 +1,85 @@
+package chatgpt
+
+import (
+	"os"
+	"testing"
+	"time"
+
+	"github.com/joho/godotenv"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestOfflineContext(t *testing.T) {
+	key := os.Getenv("CHATGPT_API_KEY")
+	if key == "" {
+		t.Skip("CHATGPT_API_KEY is not set")
+	}
+	cli := New(key, "user1", time.Second*30)
+	reply, err := cli.ChatWithContext("我叫老三,你是?")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("我叫老三,你是? => %s", reply)
+
+	err = cli.ChatContext.SaveConversation("test.conversation")
+	if err != nil {
+		t.Fatalf("储存对话记录失败: %v", err)
+	}
+	cli.ChatContext.ResetConversation()
+
+	reply, err = cli.ChatWithContext("你知道我是谁吗?")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("你知道我是谁吗? => %s", reply)
+	assert.NotContains(t, reply, "老三")
+
+	err = cli.ChatContext.LoadConversation("test.conversation")
+	if err != nil {
+		t.Fatalf("读取对话记录失败: %v", err)
+	}
+
+	reply, err = cli.ChatWithContext("你知道我是谁吗?")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Logf("你知道我是谁吗? => %s", reply)
+
+	// AI 理应知道他叫老三
+	assert.Contains(t, reply, "老三")
+}
+
+func TestMaintainContext(t *testing.T) {
+	key := os.Getenv("CHATGPT_API_KEY")
+	if key == "" {
+		t.Skip("CHATGPT_API_KEY is not set")
+	}
+	cli := New(key, "user1", time.Second*30)
+	cli.ChatContext = NewContext(
+		WithMaxSeqTimes(1),
+		WithMaintainSeqTimes(true),
+	)
+
+	reply, err := cli.ChatWithContext("我叫老三,你是?")
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("我叫老三,你是? => %s", reply)
+
+	reply, err = cli.ChatWithContext("你知道我是谁吗?")
+	if err != nil {
+		t.Fatal(err)
+	}
+	t.Logf("你知道我是谁吗? => %s", reply)
+
+	// 对话次数已经超过 1 次,因此最先前的对话已被移除,AI 理应不知道他叫老三
+	assert.NotContains(t, reply, "老三")
+}
+
+func init() {
+	// 本地加载适用于本地测试,如果要在github进行测试,可以透过传入 secrets 到环境参数
+	_ = godotenv.Load(".env.local")
+}

+ 12 - 0
pkg/chatgpt/errors.go

@@ -0,0 +1,12 @@
+package chatgpt
+
+import "errors"
+
+// OverMaxSequenceTimes 超过最大对话时间
+var OverMaxSequenceTimes = errors.New("maximum conversation times exceeded")
+
+// OverMaxTextLength 超过最大文本长度
+var OverMaxTextLength = errors.New("maximum text length exceeded")
+
+// OverMaxQuestionLength 超过最大问题长度
+var OverMaxQuestionLength = errors.New("maximum question length exceeded")

+ 12 - 0
pkg/chatgpt/format.go

@@ -0,0 +1,12 @@
+package chatgpt
+
+func formatAnswer(answer string) string {
+	for len(answer) > 0 {
+		if answer[:1] == "\n" || answer[0] == ' ' {
+			answer = answer[1:]
+		} else {
+			break
+		}
+	}
+	return answer
+}

+ 4 - 0
pkg/chatgpt/go.mod

@@ -0,0 +1,4 @@
+module chatgpt
+
+go 1.18
+

+ 26 - 0
pkg/chatgpt/tools.go

@@ -0,0 +1,26 @@
+package chatgpt
+
+import (
+	"fmt"
+	"io/ioutil"
+	"os"
+	"strings"
+)
+
+// 将内容写入到文件,如果文件名带路径,则会判断路径是否存在,不存在则创建
+func WriteToFile(path string, data []byte) error {
+	tmp := strings.Split(path, "/")
+	if len(tmp) > 0 {
+		tmp = tmp[:len(tmp)-1]
+	}
+	fmt.Println(tmp)
+	err := os.MkdirAll(strings.Join(tmp, "/"), os.ModePerm)
+	if err != nil {
+		return err
+	}
+	err = ioutil.WriteFile(path, data, 0755)
+	if err != nil {
+		return err
+	}
+	return nil
+}

+ 27 - 0
public/gpt.go

@@ -0,0 +1,27 @@
+package public
+
+import (
+	"fmt"
+
+	"github.com/eryajf/chatgpt-dingtalk/config"
+	"github.com/solywsh/chatgpt"
+)
+
+func SingleQa(question, userId string) (answer string, err error) {
+	cfg := config.LoadConfig()
+	chat := chatgpt.New(cfg.ApiKey, userId, cfg.SessionTimeout)
+	defer chat.Close()
+	return chat.ChatWithContext(question)
+}
+
+func ContextQa(question, userId string) (chat *chatgpt.ChatGPT, answer string, err error) {
+	cfg := config.LoadConfig()
+	chat = chatgpt.New(cfg.ApiKey, userId, cfg.SessionTimeout)
+	path := "openaiCache/" + userId
+	err = chat.ChatContext.LoadConversation(path)
+	if err != nil {
+		fmt.Printf("load station failed: %v\n", err)
+	}
+	answer, err = chat.ChatWithContext(question)
+	return
+}

+ 13 - 20
service/user.go

@@ -1,7 +1,6 @@
 package service
 
 import (
-	"strings"
 	"time"
 
 	"github.com/eryajf/chatgpt-dingtalk/config"
@@ -10,9 +9,9 @@ import (
 
 // UserServiceInterface 用户业务接口
 type UserServiceInterface interface {
-	GetUserSessionContext(userId string) string
-	SetUserSessionContext(userId string, question, reply string)
-	ClearUserSessionContext(userId string, msg string) bool
+	GetUserMode(userId string) string
+	SetUserMode(userId string, mode string)
+	ClearUserMode(userId string)
 }
 
 var _ UserServiceInterface = (*UserService)(nil)
@@ -23,23 +22,13 @@ type UserService struct {
 	cache *cache.Cache
 }
 
-// ClearUserSessionContext 清空GTP上下文,接收文本中包含 SessionClearToken
-func (s *UserService) ClearUserSessionContext(userId string, msg string) bool {
-	// 清空会话
-	if strings.Contains(msg, config.LoadConfig().SessionClearToken) {
-		s.cache.Delete(userId)
-		return true
-	}
-	return false
-}
-
 // NewUserService 创建新的业务层
 func NewUserService() UserServiceInterface {
 	return &UserService{cache: cache.New(time.Second*config.LoadConfig().SessionTimeout, time.Minute*10)}
 }
 
-// GetUserSessionContext 获取用户会话上下文文本
-func (s *UserService) GetUserSessionContext(userId string) string {
+// GetUserMode 获取当前对话模式
+func (s *UserService) GetUserMode(userId string) string {
 	sessionContext, ok := s.cache.Get(userId)
 	if !ok {
 		return ""
@@ -47,8 +36,12 @@ func (s *UserService) GetUserSessionContext(userId string) string {
 	return sessionContext.(string)
 }
 
-// SetUserSessionContext 设置用户会话上下文文本,question用户提问内容,GTP回复内容
-func (s *UserService) SetUserSessionContext(userId string, question, reply string) {
-	value := question + "\n" + reply
-	s.cache.Set(userId, value, time.Second*config.LoadConfig().SessionTimeout)
+// SetUserMode 设置用户对话模式
+func (s *UserService) SetUserMode(userId string, mode string) {
+	s.cache.Set(userId, mode, time.Second*config.LoadConfig().SessionTimeout)
+}
+
+// ClearUserMode 重置用户对话模式
+func (s *UserService) ClearUserMode(userId string) {
+	s.cache.Delete(userId)
 }