context.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. package chatgpt
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/base64"
  6. "encoding/gob"
  7. "errors"
  8. "fmt"
  9. "image/png"
  10. "os"
  11. "strings"
  12. "time"
  13. "github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
  14. "github.com/pandodao/tokenizer-go"
  15. "github.com/eryajf/chatgpt-dingtalk/public"
  16. openai "github.com/sashabaranov/go-openai"
  17. )
  // DefaultAiRole is the label used for the assistant's turns when no other
// role name is configured.
var DefaultAiRole = "AI"

// DefaultHumanRole is the label used for the user's turns when no other role
// name is configured.
var DefaultHumanRole = "Human"

// DefaultCharacter lists character traits that can describe the assistant.
var DefaultCharacter = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}

// DefaultBackground is a background template; its single %s verb is meant to
// receive a character description.
var DefaultBackground = "The following is a conversation with AI assistant. The assistant is %s"

// DefaultPreset is a preset opening dialogue template; its two %s verbs take
// the human role label and the AI role label, in that order.
var DefaultPreset = "\n%s: 你好,让我们开始愉快的谈话!\n%s: 我是 AI assistant ,请问你有什么问题?"
  // ChatContext holds the state used to build prompts for one conversation:
// the framing text, the two role labels, the stored history, and the turn
// counter together with its limit.
type ChatContext struct {
	background       string         // conversation background; first part of every prompt
	preset           string         // preset dialogue placed right after the background
	maxSeqTimes      int            // maximum number of conversation turns
	aiRole           *role          // role label for the AI's turns
	humanRole        *role          // role label for the human's turns
	old              []conversation // previous conversation entries (history)
	restartSeq       string         // marker that introduces the human's new question in a prompt
	startSeq         string         // marker appended at the end of a prompt to cue the AI's reply
	seqTimes         int            // number of conversation turns counted so far
	maintainSeqTimes bool           // when the limit is hit: evict old history (true) instead of failing (false)
}

// ChatContextOption mutates a ChatContext during construction.
type ChatContextOption func(*ChatContext)

// conversation is one history entry: the speaking role and the text spoken.
// Its exported fields are what gob serializes when history is saved.
type conversation struct {
	Role   *role
	Prompt string
}

// role names a participant in the conversation.
type role struct {
	Name string
}
  47. func NewContext(options ...ChatContextOption) *ChatContext {
  48. ctx := &ChatContext{
  49. aiRole: &role{Name: DefaultAiRole},
  50. humanRole: &role{Name: DefaultHumanRole},
  51. background: "",
  52. maxSeqTimes: 1000,
  53. preset: "",
  54. old: []conversation{},
  55. seqTimes: 0,
  56. restartSeq: "\n" + DefaultHumanRole + ": ",
  57. startSeq: "\n" + DefaultAiRole + ": ",
  58. maintainSeqTimes: false,
  59. }
  60. for _, option := range options {
  61. option(ctx)
  62. }
  63. return ctx
  64. }
  65. // PollConversation 移除最旧的一则对话
  66. func (c *ChatContext) PollConversation() {
  67. c.old = c.old[1:]
  68. c.seqTimes--
  69. }
  70. // ResetConversation 重置对话
  71. func (c *ChatContext) ResetConversation(userid string) {
  72. public.UserService.ClearUserSessionContext(userid)
  73. }
  74. // SaveConversation 保存对话
  75. func (c *ChatContext) SaveConversation(userid string) error {
  76. var buffer bytes.Buffer
  77. enc := gob.NewEncoder(&buffer)
  78. err := enc.Encode(c.old)
  79. if err != nil {
  80. return err
  81. }
  82. public.UserService.SetUserSessionContext(userid, buffer.String())
  83. return nil
  84. }
  85. // LoadConversation 加载对话
  86. func (c *ChatContext) LoadConversation(userid string) error {
  87. dec := gob.NewDecoder(strings.NewReader(public.UserService.GetUserSessionContext(userid)))
  88. err := dec.Decode(&c.old)
  89. if err != nil {
  90. return err
  91. }
  92. c.seqTimes = len(c.old)
  93. return nil
  94. }
  95. func (c *ChatContext) SetHumanRole(role string) {
  96. c.humanRole.Name = role
  97. c.restartSeq = "\n" + c.humanRole.Name + ": "
  98. }
  99. func (c *ChatContext) SetAiRole(role string) {
  100. c.aiRole.Name = role
  101. c.startSeq = "\n" + c.aiRole.Name + ": "
  102. }
  103. func (c *ChatContext) SetMaxSeqTimes(times int) {
  104. c.maxSeqTimes = times
  105. }
  106. func (c *ChatContext) GetMaxSeqTimes() int {
  107. return c.maxSeqTimes
  108. }
  109. func (c *ChatContext) SetBackground(background string) {
  110. c.background = background
  111. }
  112. func (c *ChatContext) SetPreset(preset string) {
  113. c.preset = preset
  114. }
  115. func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
  116. question = question + "."
  117. if tokenizer.MustCalToken(question) > c.maxQuestionLen {
  118. return "", OverMaxQuestionLength
  119. }
  120. if c.ChatContext.seqTimes >= c.ChatContext.maxSeqTimes {
  121. if c.ChatContext.maintainSeqTimes {
  122. c.ChatContext.PollConversation()
  123. } else {
  124. return "", OverMaxSequenceTimes
  125. }
  126. }
  127. var promptTable []string
  128. promptTable = append(promptTable, c.ChatContext.background)
  129. promptTable = append(promptTable, c.ChatContext.preset)
  130. for _, v := range c.ChatContext.old {
  131. if v.Role == c.ChatContext.humanRole {
  132. promptTable = append(promptTable, "\n"+v.Role.Name+": "+v.Prompt)
  133. } else {
  134. promptTable = append(promptTable, v.Role.Name+": "+v.Prompt)
  135. }
  136. }
  137. promptTable = append(promptTable, "\n"+c.ChatContext.restartSeq+question)
  138. prompt := strings.Join(promptTable, "\n")
  139. prompt += c.ChatContext.startSeq
  140. // 删除对话,直到prompt的长度满足条件
  141. for tokenizer.MustCalToken(prompt) > c.maxText {
  142. if len(c.ChatContext.old) > 1 { // 至少保留一条记录
  143. c.ChatContext.PollConversation() // 删除最旧的一条对话
  144. // 重新构建 prompt,计算长度
  145. promptTable = promptTable[1:] // 删除promptTable中对应的对话
  146. prompt = strings.Join(promptTable, "\n") + c.ChatContext.startSeq
  147. } else {
  148. break // 如果已经只剩一条记录,那么跳出循环
  149. }
  150. }
  151. // if tokenizer.MustCalToken(prompt) > c.maxText-c.maxAnswerLen {
  152. // return "", OverMaxTextLength
  153. // }
  154. model := public.Config.Model
  155. userId := c.userId
  156. if public.Config.AzureOn {
  157. userId = ""
  158. }
  159. if model == openai.GPT3Dot5Turbo || model == openai.GPT3Dot5Turbo0301 || model == openai.GPT3Dot5Turbo0613 ||
  160. model == openai.GPT3Dot5Turbo16K || model == openai.GPT3Dot5Turbo16K0613 ||
  161. model == openai.GPT4 || model == openai.GPT40314 || model == openai.GPT40613 ||
  162. model == openai.GPT432K || model == openai.GPT432K0314 || model == openai.GPT432K0613 {
  163. req := openai.ChatCompletionRequest{
  164. Model: model,
  165. Messages: []openai.ChatCompletionMessage{
  166. {
  167. Role: "user",
  168. Content: prompt,
  169. },
  170. },
  171. MaxTokens: c.maxAnswerLen,
  172. Temperature: 0.6,
  173. User: userId,
  174. }
  175. resp, err := c.client.CreateChatCompletion(c.ctx, req)
  176. if err != nil {
  177. return "", err
  178. }
  179. resp.Choices[0].Message.Content = formatAnswer(resp.Choices[0].Message.Content)
  180. c.ChatContext.old = append(c.ChatContext.old, conversation{
  181. Role: c.ChatContext.humanRole,
  182. Prompt: question,
  183. })
  184. c.ChatContext.old = append(c.ChatContext.old, conversation{
  185. Role: c.ChatContext.aiRole,
  186. Prompt: resp.Choices[0].Message.Content,
  187. })
  188. c.ChatContext.seqTimes++
  189. return resp.Choices[0].Message.Content, nil
  190. } else {
  191. req := openai.CompletionRequest{
  192. Model: model,
  193. MaxTokens: c.maxAnswerLen,
  194. Prompt: prompt,
  195. Temperature: 0.6,
  196. User: c.userId,
  197. Stop: []string{c.ChatContext.aiRole.Name + ":", c.ChatContext.humanRole.Name + ":"},
  198. }
  199. resp, err := c.client.CreateCompletion(c.ctx, req)
  200. if err != nil {
  201. return "", err
  202. }
  203. resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text)
  204. c.ChatContext.old = append(c.ChatContext.old, conversation{
  205. Role: c.ChatContext.humanRole,
  206. Prompt: question,
  207. })
  208. c.ChatContext.old = append(c.ChatContext.old, conversation{
  209. Role: c.ChatContext.aiRole,
  210. Prompt: resp.Choices[0].Text,
  211. })
  212. c.ChatContext.seqTimes++
  213. return resp.Choices[0].Text, nil
  214. }
  215. }
  216. func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) {
  217. model := public.Config.Model
  218. if model == openai.GPT3Dot5Turbo || model == openai.GPT3Dot5Turbo0301 || model == openai.GPT3Dot5Turbo0613 ||
  219. model == openai.GPT3Dot5Turbo16K || model == openai.GPT3Dot5Turbo16K0613 ||
  220. model == openai.GPT4 || model == openai.GPT40314 || model == openai.GPT40613 ||
  221. model == openai.GPT432K || model == openai.GPT432K0314 || model == openai.GPT432K0613 {
  222. req := openai.ImageRequest{
  223. Prompt: prompt,
  224. Size: openai.CreateImageSize1024x1024,
  225. ResponseFormat: openai.CreateImageResponseFormatB64JSON,
  226. N: 1,
  227. User: c.userId,
  228. }
  229. respBase64, err := c.client.CreateImage(c.ctx, req)
  230. if err != nil {
  231. return "", err
  232. }
  233. imgBytes, err := base64.StdEncoding.DecodeString(respBase64.Data[0].B64JSON)
  234. if err != nil {
  235. return "", err
  236. }
  237. r := bytes.NewReader(imgBytes)
  238. imgData, err := png.Decode(r)
  239. if err != nil {
  240. return "", err
  241. }
  242. imageName := time.Now().Format("20060102-150405") + ".png"
  243. clientId, _ := ctx.Value(public.DingTalkClientIdKeyName).(string)
  244. client := public.DingTalkClientManager.GetClientByOAuthClientID(clientId)
  245. mediaResult, uploadErr := &dingbot.MediaUploadResult{}, errors.New(fmt.Sprintf("unknown clientId: %s", clientId))
  246. if client != nil {
  247. mediaResult, uploadErr = client.UploadMedia(imgBytes, imageName, dingbot.MediaTypeImage, dingbot.MimeTypeImagePng)
  248. }
  249. err = os.MkdirAll("data/images", 0755)
  250. if err != nil {
  251. return "", err
  252. }
  253. file, err := os.Create("data/images/" + imageName)
  254. if err != nil {
  255. return "", err
  256. }
  257. defer file.Close()
  258. if err := png.Encode(file, imgData); err != nil {
  259. return "", err
  260. }
  261. if uploadErr == nil {
  262. return mediaResult.MediaID, nil
  263. } else {
  264. return public.Config.ServiceURL + "/images/" + imageName, nil
  265. }
  266. }
  267. return "", nil
  268. }
  269. func WithMaxSeqTimes(times int) ChatContextOption {
  270. return func(c *ChatContext) {
  271. c.SetMaxSeqTimes(times)
  272. }
  273. }
  274. // WithOldConversation 从文件中加载对话
  275. func WithOldConversation(userid string) ChatContextOption {
  276. return func(c *ChatContext) {
  277. _ = c.LoadConversation(userid)
  278. }
  279. }
  280. func WithMaintainSeqTimes(maintain bool) ChatContextOption {
  281. return func(c *ChatContext) {
  282. c.maintainSeqTimes = maintain
  283. }
  284. }