context.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. package chatgpt
  2. import (
  3. "bytes"
  4. "encoding/gob"
  5. "fmt"
  6. "strings"
  7. "github.com/eryajf/chatgpt-dingtalk/public"
  8. gogpt "github.com/sashabaranov/go-gpt3"
  9. )
  10. var (
  11. DefaultAiRole = "AI"
  12. DefaultHumanRole = "Human"
  13. DefaultCharacter = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}
  14. DefaultBackground = "The following is a conversation with AI assistant. The assistant is %s"
  15. DefaultPreset = "\n%s: 你好,让我们开始愉快的谈话!\n%s: 我是 AI assistant ,请问你有什么问题?"
  16. )
  17. type (
  18. ChatContext struct {
  19. background string // 对话背景
  20. preset string // 预设对话
  21. maxSeqTimes int // 最大对话次数
  22. aiRole *role // AI角色
  23. humanRole *role // 人类角色
  24. old []conversation // 旧对话
  25. restartSeq string // 重新开始对话的标识
  26. startSeq string // 开始对话的标识
  27. seqTimes int // 对话次数
  28. maintainSeqTimes bool // 是否维护对话次数 (自动移除旧对话)
  29. }
  30. ChatContextOption func(*ChatContext)
  31. conversation struct {
  32. Role *role
  33. Prompt string
  34. }
  35. role struct {
  36. Name string
  37. }
  38. )
  39. func NewContext(options ...ChatContextOption) *ChatContext {
  40. ctx := &ChatContext{
  41. aiRole: &role{Name: DefaultAiRole},
  42. humanRole: &role{Name: DefaultHumanRole},
  43. background: fmt.Sprintf(DefaultBackground, strings.Join(DefaultCharacter, ", ")+"."),
  44. maxSeqTimes: 1000,
  45. preset: fmt.Sprintf(DefaultPreset, DefaultHumanRole, DefaultAiRole),
  46. old: []conversation{},
  47. seqTimes: 0,
  48. restartSeq: "\n" + DefaultHumanRole + ": ",
  49. startSeq: "\n" + DefaultAiRole + ": ",
  50. maintainSeqTimes: false,
  51. }
  52. for _, option := range options {
  53. option(ctx)
  54. }
  55. return ctx
  56. }
  57. // PollConversation 移除最旧的一则对话
  58. func (c *ChatContext) PollConversation() {
  59. c.old = c.old[1:]
  60. c.seqTimes--
  61. }
  62. // ResetConversation 重置对话
  63. func (c *ChatContext) ResetConversation(userid string) {
  64. public.UserService.ClearUserSessionContext(userid)
  65. }
  66. // SaveConversation 保存对话
  67. func (c *ChatContext) SaveConversation(userid string) error {
  68. var buffer bytes.Buffer
  69. enc := gob.NewEncoder(&buffer)
  70. err := enc.Encode(c.old)
  71. if err != nil {
  72. return err
  73. }
  74. public.UserService.SetUserSessionContext(userid, buffer.String())
  75. return nil
  76. }
  77. // LoadConversation 加载对话
  78. func (c *ChatContext) LoadConversation(userid string) error {
  79. dec := gob.NewDecoder(strings.NewReader(public.UserService.GetUserSessionContext(userid)))
  80. err := dec.Decode(&c.old)
  81. if err != nil {
  82. return err
  83. }
  84. c.seqTimes = len(c.old)
  85. return nil
  86. }
  87. func (c *ChatContext) SetHumanRole(role string) {
  88. c.humanRole.Name = role
  89. c.restartSeq = "\n" + c.humanRole.Name + ": "
  90. }
  91. func (c *ChatContext) SetAiRole(role string) {
  92. c.aiRole.Name = role
  93. c.startSeq = "\n" + c.aiRole.Name + ": "
  94. }
  95. func (c *ChatContext) SetMaxSeqTimes(times int) {
  96. c.maxSeqTimes = times
  97. }
  98. func (c *ChatContext) GetMaxSeqTimes() int {
  99. return c.maxSeqTimes
  100. }
  101. func (c *ChatContext) SetBackground(background string) {
  102. c.background = background
  103. }
  104. func (c *ChatContext) SetPreset(preset string) {
  105. c.preset = preset
  106. }
  107. func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
  108. question = question + "."
  109. if len(question) > c.maxQuestionLen {
  110. return "", OverMaxQuestionLength
  111. }
  112. if c.ChatContext.seqTimes >= c.ChatContext.maxSeqTimes {
  113. if c.ChatContext.maintainSeqTimes {
  114. c.ChatContext.PollConversation()
  115. } else {
  116. return "", OverMaxSequenceTimes
  117. }
  118. }
  119. var promptTable []string
  120. promptTable = append(promptTable, c.ChatContext.background)
  121. promptTable = append(promptTable, c.ChatContext.preset)
  122. for _, v := range c.ChatContext.old {
  123. if v.Role == c.ChatContext.humanRole {
  124. promptTable = append(promptTable, "\n"+v.Role.Name+": "+v.Prompt)
  125. } else {
  126. promptTable = append(promptTable, v.Role.Name+": "+v.Prompt)
  127. }
  128. }
  129. promptTable = append(promptTable, "\n"+c.ChatContext.restartSeq+question)
  130. prompt := strings.Join(promptTable, "\n")
  131. prompt += c.ChatContext.startSeq
  132. if len(prompt) > c.maxText-c.maxAnswerLen {
  133. return "", OverMaxTextLength
  134. }
  135. if public.Config.Model == gogpt.GPT3Dot5Turbo0301 || public.Config.Model == gogpt.GPT3Dot5Turbo {
  136. req := gogpt.ChatCompletionRequest{
  137. Model: public.Config.Model,
  138. Messages: []gogpt.ChatCompletionMessage{
  139. {
  140. Role: "user",
  141. Content: prompt,
  142. },
  143. }}
  144. resp, err := c.client.CreateChatCompletion(c.ctx, req)
  145. if err != nil {
  146. return "", err
  147. }
  148. resp.Choices[0].Message.Content = formatAnswer(resp.Choices[0].Message.Content)
  149. c.ChatContext.old = append(c.ChatContext.old, conversation{
  150. Role: c.ChatContext.humanRole,
  151. Prompt: question,
  152. })
  153. c.ChatContext.old = append(c.ChatContext.old, conversation{
  154. Role: c.ChatContext.aiRole,
  155. Prompt: resp.Choices[0].Message.Content,
  156. })
  157. c.ChatContext.seqTimes++
  158. return resp.Choices[0].Message.Content, nil
  159. } else {
  160. req := gogpt.CompletionRequest{
  161. Model: public.Config.Model,
  162. MaxTokens: c.maxAnswerLen,
  163. Prompt: prompt,
  164. Temperature: 0.9,
  165. TopP: 1,
  166. N: 1,
  167. FrequencyPenalty: 0,
  168. PresencePenalty: 0.5,
  169. User: c.userId,
  170. Stop: []string{c.ChatContext.aiRole.Name + ":", c.ChatContext.humanRole.Name + ":"},
  171. }
  172. resp, err := c.client.CreateCompletion(c.ctx, req)
  173. if err != nil {
  174. return "", err
  175. }
  176. resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text)
  177. c.ChatContext.old = append(c.ChatContext.old, conversation{
  178. Role: c.ChatContext.humanRole,
  179. Prompt: question,
  180. })
  181. c.ChatContext.old = append(c.ChatContext.old, conversation{
  182. Role: c.ChatContext.aiRole,
  183. Prompt: resp.Choices[0].Text,
  184. })
  185. c.ChatContext.seqTimes++
  186. return resp.Choices[0].Text, nil
  187. }
  188. }
  189. func WithMaxSeqTimes(times int) ChatContextOption {
  190. return func(c *ChatContext) {
  191. c.SetMaxSeqTimes(times)
  192. }
  193. }
  194. // WithOldConversation 从文件中加载对话
  195. func WithOldConversation(userid string) ChatContextOption {
  196. return func(c *ChatContext) {
  197. _ = c.LoadConversation(userid)
  198. }
  199. }
  200. func WithMaintainSeqTimes(maintain bool) ChatContextOption {
  201. return func(c *ChatContext) {
  202. c.maintainSeqTimes = maintain
  203. }
  204. }