context.go 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. package chatgpt
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/base64"
  6. "encoding/gob"
  7. "errors"
  8. "fmt"
  9. "golang.org/x/image/webp"
  10. "image"
  11. _ "image/gif"
  12. _ "image/jpeg"
  13. "image/png"
  14. "os"
  15. "strings"
  16. "time"
  17. "github.com/pandodao/tokenizer-go"
  18. openai "github.com/sashabaranov/go-openai"
  19. "github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
  20. "github.com/eryajf/chatgpt-dingtalk/public"
  21. )
  // DefaultAiRole is the speaker label given to the assistant by NewContext.
  var DefaultAiRole = "AI"

  // DefaultHumanRole is the speaker label given to the user by NewContext.
  var DefaultHumanRole = "Human"

  // DefaultCharacter lists adjectives describing the assistant's personality.
  // NOTE(review): presumably substituted into DefaultBackground by a caller
  // outside this file - confirm.
  var DefaultCharacter = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}

  // DefaultBackground is a format string with a single %s verb that describes
  // the conversation setting.
  var DefaultBackground = "The following is a conversation with AI assistant. The assistant is %s"

  // DefaultPreset is a format string with two %s verbs that seeds the dialogue
  // with one opening line per speaker.
  // NOTE(review): the verbs look like they take the human role name followed by
  // the AI role name - confirm against the caller.
  var DefaultPreset = "\n%s: 你好,让我们开始愉快的谈话!\n%s: 我是 AI assistant ,请问你有什么问题?"
  // ChatContext holds the state needed to turn a running dialogue into a single
  // text prompt: the fixed preamble, the speaker labels, and the stored history.
  type ChatContext struct {
  	background       string         // conversation background placed first in the prompt
  	preset           string         // preset dialogue placed after the background
  	maxSeqTimes      int            // maximum number of conversation rounds
  	aiRole           *role          // speaker label for the assistant
  	humanRole        *role          // speaker label for the user
  	old              []conversation // stored history, oldest entry first
  	restartSeq       string         // marker that introduces the user's next turn
  	startSeq         string         // marker that introduces the assistant's reply
  	seqTimes         int            // current conversation counter compared against maxSeqTimes
  	maintainSeqTimes bool           // when true, drop the oldest entry instead of failing at the limit
  }

  // ChatContextOption customises a ChatContext while NewContext builds it.
  type ChatContextOption func(*ChatContext)

  // conversation is one stored utterance: who spoke and what was said. Its
  // fields are exported so that encoding/gob can serialise the history.
  type conversation struct {
  	Role   *role
  	Prompt string
  }

  // role names a speaker in the dialogue.
  type role struct {
  	Name string
  }
  51. func NewContext(options ...ChatContextOption) *ChatContext {
  52. ctx := &ChatContext{
  53. aiRole: &role{Name: DefaultAiRole},
  54. humanRole: &role{Name: DefaultHumanRole},
  55. background: "",
  56. maxSeqTimes: 1000,
  57. preset: "",
  58. old: []conversation{},
  59. seqTimes: 0,
  60. restartSeq: "\n" + DefaultHumanRole + ": ",
  61. startSeq: "\n" + DefaultAiRole + ": ",
  62. maintainSeqTimes: false,
  63. }
  64. for _, option := range options {
  65. option(ctx)
  66. }
  67. return ctx
  68. }
  69. // PollConversation 移除最旧的一则对话
  70. func (c *ChatContext) PollConversation() {
  71. c.old = c.old[1:]
  72. c.seqTimes--
  73. }
  74. // ResetConversation 重置对话
  75. func (c *ChatContext) ResetConversation(userid string) {
  76. public.UserService.ClearUserSessionContext(userid)
  77. }
  78. // SaveConversation 保存对话
  79. func (c *ChatContext) SaveConversation(userid string) error {
  80. var buffer bytes.Buffer
  81. enc := gob.NewEncoder(&buffer)
  82. err := enc.Encode(c.old)
  83. if err != nil {
  84. return err
  85. }
  86. public.UserService.SetUserSessionContext(userid, buffer.String())
  87. return nil
  88. }
  89. // LoadConversation 加载对话
  90. func (c *ChatContext) LoadConversation(userid string) error {
  91. dec := gob.NewDecoder(strings.NewReader(public.UserService.GetUserSessionContext(userid)))
  92. err := dec.Decode(&c.old)
  93. if err != nil {
  94. return err
  95. }
  96. c.seqTimes = len(c.old)
  97. return nil
  98. }
  99. func (c *ChatContext) SetHumanRole(role string) {
  100. c.humanRole.Name = role
  101. c.restartSeq = "\n" + c.humanRole.Name + ": "
  102. }
  103. func (c *ChatContext) SetAiRole(role string) {
  104. c.aiRole.Name = role
  105. c.startSeq = "\n" + c.aiRole.Name + ": "
  106. }
  107. func (c *ChatContext) SetMaxSeqTimes(times int) {
  108. c.maxSeqTimes = times
  109. }
  110. func (c *ChatContext) GetMaxSeqTimes() int {
  111. return c.maxSeqTimes
  112. }
  113. func (c *ChatContext) SetBackground(background string) {
  114. c.background = background
  115. }
  116. func (c *ChatContext) SetPreset(preset string) {
  117. c.preset = preset
  118. }
  // getImageTypeFromBase64 guesses an image format from the leading characters
  // of its base64 encoding. It returns "JPEG", "PNG", "GIF" or "WebP" when the
  // string starts with the corresponding prefix, and "Unknown" otherwise.
  //
  // Only the prefix is inspected; the payload is not decoded or validated.
  func getImageTypeFromBase64(base64Str string) string {
  	// Checked in order; the first matching prefix wins.
  	signatures := [...]struct{ prefix, kind string }{
  		{"/9j/", "JPEG"},
  		{"iVBOR", "PNG"},
  		{"R0lG", "GIF"},
  		{"UklG", "WebP"},
  	}
  	for _, sig := range signatures {
  		if strings.HasPrefix(base64Str, sig.prefix) {
  			return sig.kind
  		}
  	}
  	return "Unknown"
  }
  134. func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
  135. question = question + "."
  136. if tokenizer.MustCalToken(question) > c.maxQuestionLen {
  137. return "", OverMaxQuestionLength
  138. }
  139. if c.ChatContext.seqTimes >= c.ChatContext.maxSeqTimes {
  140. if c.ChatContext.maintainSeqTimes {
  141. c.ChatContext.PollConversation()
  142. } else {
  143. return "", OverMaxSequenceTimes
  144. }
  145. }
  146. var promptTable []string
  147. promptTable = append(promptTable, c.ChatContext.background)
  148. promptTable = append(promptTable, c.ChatContext.preset)
  149. for _, v := range c.ChatContext.old {
  150. if v.Role == c.ChatContext.humanRole {
  151. promptTable = append(promptTable, "\n"+v.Role.Name+": "+v.Prompt)
  152. } else {
  153. promptTable = append(promptTable, v.Role.Name+": "+v.Prompt)
  154. }
  155. }
  156. promptTable = append(promptTable, "\n"+c.ChatContext.restartSeq+question)
  157. prompt := strings.Join(promptTable, "\n")
  158. prompt += c.ChatContext.startSeq
  159. // 删除对话,直到prompt的长度满足条件
  160. for tokenizer.MustCalToken(prompt) > c.maxText {
  161. if len(c.ChatContext.old) > 1 { // 至少保留一条记录
  162. c.ChatContext.PollConversation() // 删除最旧的一条对话
  163. // 重新构建 prompt,计算长度
  164. promptTable = promptTable[1:] // 删除promptTable中对应的对话
  165. prompt = strings.Join(promptTable, "\n") + c.ChatContext.startSeq
  166. } else {
  167. break // 如果已经只剩一条记录,那么跳出循环
  168. }
  169. }
  170. // if tokenizer.MustCalToken(prompt) > c.maxText-c.maxAnswerLen {
  171. // return "", OverMaxTextLength
  172. // }
  173. model := public.Config.Model
  174. userId := c.userId
  175. if public.Config.AzureOn {
  176. userId = ""
  177. }
  178. if isModelSupportedChatCompletions(model) {
  179. req := openai.ChatCompletionRequest{
  180. Model: model,
  181. Messages: []openai.ChatCompletionMessage{
  182. {
  183. Role: "user",
  184. Content: prompt,
  185. },
  186. },
  187. MaxTokens: c.maxAnswerLen,
  188. Temperature: 0.6,
  189. User: userId,
  190. }
  191. resp, err := c.client.CreateChatCompletion(c.ctx, req)
  192. if err != nil {
  193. return "", err
  194. }
  195. resp.Choices[0].Message.Content = formatAnswer(resp.Choices[0].Message.Content)
  196. c.ChatContext.old = append(c.ChatContext.old, conversation{
  197. Role: c.ChatContext.humanRole,
  198. Prompt: question,
  199. })
  200. c.ChatContext.old = append(c.ChatContext.old, conversation{
  201. Role: c.ChatContext.aiRole,
  202. Prompt: resp.Choices[0].Message.Content,
  203. })
  204. c.ChatContext.seqTimes++
  205. return resp.Choices[0].Message.Content, nil
  206. } else {
  207. req := openai.CompletionRequest{
  208. Model: model,
  209. MaxTokens: c.maxAnswerLen,
  210. Prompt: prompt,
  211. Temperature: 0.6,
  212. User: c.userId,
  213. Stop: []string{c.ChatContext.aiRole.Name + ":", c.ChatContext.humanRole.Name + ":"},
  214. }
  215. resp, err := c.client.CreateCompletion(c.ctx, req)
  216. if err != nil {
  217. return "", err
  218. }
  219. resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text)
  220. c.ChatContext.old = append(c.ChatContext.old, conversation{
  221. Role: c.ChatContext.humanRole,
  222. Prompt: question,
  223. })
  224. c.ChatContext.old = append(c.ChatContext.old, conversation{
  225. Role: c.ChatContext.aiRole,
  226. Prompt: resp.Choices[0].Text,
  227. })
  228. c.ChatContext.seqTimes++
  229. return resp.Choices[0].Text, nil
  230. }
  231. }
  232. func (c *ChatGPT) GenerateImage(ctx context.Context, prompt string) (string, error) {
  233. model := public.Config.Model
  234. imageModel := public.Config.ImageModel
  235. if isModelSupportedChatCompletions(model) {
  236. req := openai.ImageRequest{
  237. Prompt: prompt,
  238. Model: imageModel,
  239. Size: openai.CreateImageSize1024x1024,
  240. ResponseFormat: openai.CreateImageResponseFormatB64JSON,
  241. N: 1,
  242. User: c.userId,
  243. }
  244. respBase64, err := c.client.CreateImage(c.ctx, req)
  245. if err != nil {
  246. return "", err
  247. }
  248. imgBytes, err := base64.StdEncoding.DecodeString(respBase64.Data[0].B64JSON)
  249. if err != nil {
  250. return "", err
  251. }
  252. r := bytes.NewReader(imgBytes)
  253. // dall-e-3 返回的是 WebP 格式的图片,需要判断处理
  254. imgType := getImageTypeFromBase64(respBase64.Data[0].B64JSON)
  255. var imgData image.Image
  256. var imgErr error
  257. if imgType == "WebP" {
  258. imgData, imgErr = webp.Decode(r)
  259. } else {
  260. imgData, _, imgErr = image.Decode(r)
  261. }
  262. if imgErr != nil {
  263. return "", imgErr
  264. }
  265. imageName := time.Now().Format("20060102-150405") + ".png"
  266. clientId, _ := ctx.Value(public.DingTalkClientIdKeyName).(string)
  267. client := public.DingTalkClientManager.GetClientByOAuthClientID(clientId)
  268. mediaResult, uploadErr := &dingbot.MediaUploadResult{}, errors.New(fmt.Sprintf("unknown clientId: %s", clientId))
  269. if client != nil {
  270. mediaResult, uploadErr = client.UploadMedia(imgBytes, imageName, dingbot.MediaTypeImage, dingbot.MimeTypeImagePng)
  271. }
  272. err = os.MkdirAll("data/images", 0755)
  273. if err != nil {
  274. return "", err
  275. }
  276. file, err := os.Create("data/images/" + imageName)
  277. if err != nil {
  278. return "", err
  279. }
  280. defer file.Close()
  281. if err := png.Encode(file, imgData); err != nil {
  282. return "", err
  283. }
  284. if uploadErr == nil {
  285. return mediaResult.MediaID, nil
  286. } else {
  287. return public.Config.ServiceURL + "/images/" + imageName, nil
  288. }
  289. }
  290. return "", nil
  291. }
  292. func WithMaxSeqTimes(times int) ChatContextOption {
  293. return func(c *ChatContext) {
  294. c.SetMaxSeqTimes(times)
  295. }
  296. }
  297. // WithOldConversation 从文件中加载对话
  298. func WithOldConversation(userid string) ChatContextOption {
  299. return func(c *ChatContext) {
  300. _ = c.LoadConversation(userid)
  301. }
  302. }
  303. func WithMaintainSeqTimes(maintain bool) ChatContextOption {
  304. return func(c *ChatContext) {
  305. c.maintainSeqTimes = maintain
  306. }
  307. }