context.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. package chatgpt
  import (
  	"bytes"
  	"encoding/base64"
  	"encoding/gob"
  	"errors"
  	"image/png"
  	"os"
  	"strings"
  	"time"

  	"github.com/eryajf/chatgpt-dingtalk/public"
  	"github.com/pandodao/tokenizer-go"
  	openai "github.com/sashabaranov/go-openai"
  )
  // Default speaker names; NewContext starts every ChatContext with these.
  var (
  	DefaultAiRole    = "AI"
  	DefaultHumanRole = "Human"
  )

  // Default persona material for building a prompt background.
  var (
  	// DefaultCharacter lists personality adjectives for the assistant.
  	DefaultCharacter = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}
  	// DefaultBackground is a template with one %s placeholder, presumably
  	// filled with the character description via fmt.Sprintf (confirm at the
  	// call sites).
  	DefaultBackground = "The following is a conversation with AI assistant. The assistant is %s"
  	// DefaultPreset is an opening exchange template with two %s placeholders,
  	// presumably the human and AI role labels. It reads: "Hello, let's start a
  	// pleasant conversation!" / "I am the AI assistant, what is your question?"
  	DefaultPreset = "\n%s: 你好,让我们开始愉快的谈话!\n%s: 我是 AI assistant ,请问你有什么问题?"
  )
  // ChatContext holds the prompt framing and the running history of one
  // conversation with the model.
  type ChatContext struct {
  	background       string         // text placed at the very start of the prompt
  	preset           string         // preset exchange placed after the background
  	maxSeqTimes      int            // maximum number of exchanges
  	aiRole           *role          // the AI speaker
  	humanRole        *role          // the human speaker
  	old              []conversation // earlier turns, oldest first
  	restartSeq       string         // label that introduces a human turn
  	startSeq         string         // label that introduces an AI turn
  	seqTimes         int            // number of exchanges recorded so far
  	maintainSeqTimes bool           // drop the oldest exchange instead of failing once the limit is hit
  }

  // ChatContextOption configures a ChatContext when passed to NewContext.
  type ChatContextOption func(*ChatContext)

  // conversation is one turn of the history. Its fields are exported because
  // encoding/gob (used to persist the history) only serializes exported fields.
  type conversation struct {
  	Role   *role
  	Prompt string
  }

  // role names a speaker. Turns recorded in-process point at the context's own
  // *role, so renaming a role also relabels those turns.
  type role struct {
  	Name string
  }
  43. func NewContext(options ...ChatContextOption) *ChatContext {
  44. ctx := &ChatContext{
  45. aiRole: &role{Name: DefaultAiRole},
  46. humanRole: &role{Name: DefaultHumanRole},
  47. background: "",
  48. maxSeqTimes: 1000,
  49. preset: "",
  50. old: []conversation{},
  51. seqTimes: 0,
  52. restartSeq: "\n" + DefaultHumanRole + ": ",
  53. startSeq: "\n" + DefaultAiRole + ": ",
  54. maintainSeqTimes: false,
  55. }
  56. for _, option := range options {
  57. option(ctx)
  58. }
  59. return ctx
  60. }
  61. // PollConversation 移除最旧的一则对话
  62. func (c *ChatContext) PollConversation() {
  63. c.old = c.old[1:]
  64. c.seqTimes--
  65. }
  66. // ResetConversation 重置对话
  67. func (c *ChatContext) ResetConversation(userid string) {
  68. public.UserService.ClearUserSessionContext(userid)
  69. }
  70. // SaveConversation 保存对话
  71. func (c *ChatContext) SaveConversation(userid string) error {
  72. var buffer bytes.Buffer
  73. enc := gob.NewEncoder(&buffer)
  74. err := enc.Encode(c.old)
  75. if err != nil {
  76. return err
  77. }
  78. public.UserService.SetUserSessionContext(userid, buffer.String())
  79. return nil
  80. }
  81. // LoadConversation 加载对话
  82. func (c *ChatContext) LoadConversation(userid string) error {
  83. dec := gob.NewDecoder(strings.NewReader(public.UserService.GetUserSessionContext(userid)))
  84. err := dec.Decode(&c.old)
  85. if err != nil {
  86. return err
  87. }
  88. c.seqTimes = len(c.old)
  89. return nil
  90. }
  91. func (c *ChatContext) SetHumanRole(role string) {
  92. c.humanRole.Name = role
  93. c.restartSeq = "\n" + c.humanRole.Name + ": "
  94. }
  95. func (c *ChatContext) SetAiRole(role string) {
  96. c.aiRole.Name = role
  97. c.startSeq = "\n" + c.aiRole.Name + ": "
  98. }
  99. func (c *ChatContext) SetMaxSeqTimes(times int) {
  100. c.maxSeqTimes = times
  101. }
  102. func (c *ChatContext) GetMaxSeqTimes() int {
  103. return c.maxSeqTimes
  104. }
  105. func (c *ChatContext) SetBackground(background string) {
  106. c.background = background
  107. }
  108. func (c *ChatContext) SetPreset(preset string) {
  109. c.preset = preset
  110. }
  111. func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
  112. question = question + "."
  113. if tokenizer.MustCalToken(question) > c.maxQuestionLen {
  114. return "", OverMaxQuestionLength
  115. }
  116. if c.ChatContext.seqTimes >= c.ChatContext.maxSeqTimes {
  117. if c.ChatContext.maintainSeqTimes {
  118. c.ChatContext.PollConversation()
  119. } else {
  120. return "", OverMaxSequenceTimes
  121. }
  122. }
  123. var promptTable []string
  124. promptTable = append(promptTable, c.ChatContext.background)
  125. promptTable = append(promptTable, c.ChatContext.preset)
  126. for _, v := range c.ChatContext.old {
  127. if v.Role == c.ChatContext.humanRole {
  128. promptTable = append(promptTable, "\n"+v.Role.Name+": "+v.Prompt)
  129. } else {
  130. promptTable = append(promptTable, v.Role.Name+": "+v.Prompt)
  131. }
  132. }
  133. promptTable = append(promptTable, "\n"+c.ChatContext.restartSeq+question)
  134. prompt := strings.Join(promptTable, "\n")
  135. prompt += c.ChatContext.startSeq
  136. if tokenizer.MustCalToken(prompt) > c.maxText-c.maxAnswerLen {
  137. return "", OverMaxTextLength
  138. }
  139. model := public.Config.Model
  140. if model == openai.GPT3Dot5Turbo0301 ||
  141. model == openai.GPT3Dot5Turbo ||
  142. model == openai.GPT4 || model == openai.GPT40314 ||
  143. model == openai.GPT432K || model == openai.GPT432K0314 {
  144. req := openai.ChatCompletionRequest{
  145. Model: model,
  146. Messages: []openai.ChatCompletionMessage{
  147. {
  148. Role: "user",
  149. Content: prompt,
  150. },
  151. },
  152. MaxTokens: c.maxAnswerLen,
  153. Temperature: 0.6,
  154. User: c.userId,
  155. }
  156. resp, err := c.client.CreateChatCompletion(c.ctx, req)
  157. if err != nil {
  158. return "", err
  159. }
  160. resp.Choices[0].Message.Content = formatAnswer(resp.Choices[0].Message.Content)
  161. c.ChatContext.old = append(c.ChatContext.old, conversation{
  162. Role: c.ChatContext.humanRole,
  163. Prompt: question,
  164. })
  165. c.ChatContext.old = append(c.ChatContext.old, conversation{
  166. Role: c.ChatContext.aiRole,
  167. Prompt: resp.Choices[0].Message.Content,
  168. })
  169. c.ChatContext.seqTimes++
  170. return resp.Choices[0].Message.Content, nil
  171. } else {
  172. req := openai.CompletionRequest{
  173. Model: model,
  174. MaxTokens: c.maxAnswerLen,
  175. Prompt: prompt,
  176. Temperature: 0.6,
  177. User: c.userId,
  178. Stop: []string{c.ChatContext.aiRole.Name + ":", c.ChatContext.humanRole.Name + ":"},
  179. }
  180. resp, err := c.client.CreateCompletion(c.ctx, req)
  181. if err != nil {
  182. return "", err
  183. }
  184. resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text)
  185. c.ChatContext.old = append(c.ChatContext.old, conversation{
  186. Role: c.ChatContext.humanRole,
  187. Prompt: question,
  188. })
  189. c.ChatContext.old = append(c.ChatContext.old, conversation{
  190. Role: c.ChatContext.aiRole,
  191. Prompt: resp.Choices[0].Text,
  192. })
  193. c.ChatContext.seqTimes++
  194. return resp.Choices[0].Text, nil
  195. }
  196. }
  197. func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
  198. model := public.Config.Model
  199. if model == openai.GPT3Dot5Turbo0301 ||
  200. model == openai.GPT3Dot5Turbo ||
  201. model == openai.GPT4 || model == openai.GPT40314 ||
  202. model == openai.GPT432K || model == openai.GPT432K0314 {
  203. req := openai.ImageRequest{
  204. Prompt: prompt,
  205. Size: openai.CreateImageSize1024x1024,
  206. ResponseFormat: openai.CreateImageResponseFormatB64JSON,
  207. N: 1,
  208. User: c.userId,
  209. }
  210. respBase64, err := c.client.CreateImage(c.ctx, req)
  211. if err != nil {
  212. return "", err
  213. }
  214. imgBytes, err := base64.StdEncoding.DecodeString(respBase64.Data[0].B64JSON)
  215. if err != nil {
  216. return "", err
  217. }
  218. r := bytes.NewReader(imgBytes)
  219. imgData, err := png.Decode(r)
  220. if err != nil {
  221. return "", err
  222. }
  223. imageName := time.Now().Format("20060102-150405") + ".png"
  224. err = os.MkdirAll("data/images", 0755)
  225. if err != nil {
  226. return "", err
  227. }
  228. file, err := os.Create("data/images/" + imageName)
  229. if err != nil {
  230. return "", err
  231. }
  232. defer file.Close()
  233. if err := png.Encode(file, imgData); err != nil {
  234. return "", err
  235. }
  236. return public.Config.ServiceURL + "/images/" + imageName, nil
  237. }
  238. return "", nil
  239. }
  240. func WithMaxSeqTimes(times int) ChatContextOption {
  241. return func(c *ChatContext) {
  242. c.SetMaxSeqTimes(times)
  243. }
  244. }
  245. // WithOldConversation 从文件中加载对话
  246. func WithOldConversation(userid string) ChatContextOption {
  247. return func(c *ChatContext) {
  248. _ = c.LoadConversation(userid)
  249. }
  250. }
  251. func WithMaintainSeqTimes(maintain bool) ChatContextOption {
  252. return func(c *ChatContext) {
  253. c.maintainSeqTimes = maintain
  254. }
  255. }