context.go 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. package chatgpt
  2. import (
  3. "bytes"
  4. "encoding/gob"
  5. "fmt"
  6. "os"
  7. "strings"
  8. gogpt "github.com/sashabaranov/go-gpt3"
  9. )
  10. var (
  11. DefaultAiRole = "AI"
  12. DefaultHumanRole = "Human"
  13. DefaultCharacter = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}
  14. DefaultBackground = "The following is a conversation with AI assistant. The assistant is %s"
  15. DefaultPreset = "\n%s: 你好,让我们开始愉快的谈话!\n%s: 我是 AI assistant ,请问你有什么问题?"
  16. )
  17. type (
  18. ChatContext struct {
  19. background string // 对话背景
  20. preset string // 预设对话
  21. maxSeqTimes int // 最大对话次数
  22. aiRole *role // AI角色
  23. humanRole *role // 人类角色
  24. old []conversation // 旧对话
  25. restartSeq string // 重新开始对话的标识
  26. startSeq string // 开始对话的标识
  27. seqTimes int // 对话次数
  28. maintainSeqTimes bool // 是否维护对话次数 (自动移除旧对话)
  29. }
  30. ChatContextOption func(*ChatContext)
  31. conversation struct {
  32. Role *role
  33. Prompt string
  34. }
  35. role struct {
  36. Name string
  37. }
  38. )
  39. func NewContext(options ...ChatContextOption) *ChatContext {
  40. ctx := &ChatContext{
  41. aiRole: &role{Name: DefaultAiRole},
  42. humanRole: &role{Name: DefaultHumanRole},
  43. background: fmt.Sprintf(DefaultBackground, strings.Join(DefaultCharacter, ", ")+"."),
  44. maxSeqTimes: 1000,
  45. preset: fmt.Sprintf(DefaultPreset, DefaultHumanRole, DefaultAiRole),
  46. old: []conversation{},
  47. seqTimes: 0,
  48. restartSeq: "\n" + DefaultHumanRole + ": ",
  49. startSeq: "\n" + DefaultAiRole + ": ",
  50. maintainSeqTimes: false,
  51. }
  52. for _, option := range options {
  53. option(ctx)
  54. }
  55. return ctx
  56. }
  57. // PollConversation 移除最旧的一则对话
  58. func (c *ChatContext) PollConversation() {
  59. c.old = c.old[1:]
  60. c.seqTimes--
  61. }
  62. // ResetConversation 重置对话
  63. func (c *ChatContext) ResetConversation() {
  64. c.old = []conversation{}
  65. c.seqTimes = 0
  66. }
  67. // SaveConversation 保存对话
  68. func (c *ChatContext) SaveConversation(path string) error {
  69. var buffer bytes.Buffer
  70. enc := gob.NewEncoder(&buffer)
  71. err := enc.Encode(c.old)
  72. if err != nil {
  73. return err
  74. }
  75. return WriteToFile(path, buffer.Bytes())
  76. }
  77. // LoadConversation 加载对话
  78. func (c *ChatContext) LoadConversation(path string) error {
  79. data, err := os.ReadFile(path)
  80. if err != nil {
  81. return err
  82. }
  83. buffer := bytes.NewBuffer(data)
  84. dec := gob.NewDecoder(buffer)
  85. err = dec.Decode(&c.old)
  86. if err != nil {
  87. return err
  88. }
  89. c.seqTimes = len(c.old)
  90. return nil
  91. }
  92. func (c *ChatContext) SetHumanRole(role string) {
  93. c.humanRole.Name = role
  94. c.restartSeq = "\n" + c.humanRole.Name + ": "
  95. }
  96. func (c *ChatContext) SetAiRole(role string) {
  97. c.aiRole.Name = role
  98. c.startSeq = "\n" + c.aiRole.Name + ": "
  99. }
  100. func (c *ChatContext) SetMaxSeqTimes(times int) {
  101. c.maxSeqTimes = times
  102. }
  103. func (c *ChatContext) GetMaxSeqTimes() int {
  104. return c.maxSeqTimes
  105. }
  106. func (c *ChatContext) SetBackground(background string) {
  107. c.background = background
  108. }
  109. func (c *ChatContext) SetPreset(preset string) {
  110. c.preset = preset
  111. }
  112. func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
  113. question = question + "."
  114. if len(question) > c.maxQuestionLen {
  115. return "", OverMaxQuestionLength
  116. }
  117. if c.ChatContext.seqTimes >= c.ChatContext.maxSeqTimes {
  118. if c.ChatContext.maintainSeqTimes {
  119. c.ChatContext.PollConversation()
  120. } else {
  121. return "", OverMaxSequenceTimes
  122. }
  123. }
  124. var promptTable []string
  125. promptTable = append(promptTable, c.ChatContext.background)
  126. promptTable = append(promptTable, c.ChatContext.preset)
  127. for _, v := range c.ChatContext.old {
  128. if v.Role == c.ChatContext.humanRole {
  129. promptTable = append(promptTable, "\n"+v.Role.Name+": "+v.Prompt)
  130. } else {
  131. promptTable = append(promptTable, v.Role.Name+": "+v.Prompt)
  132. }
  133. }
  134. promptTable = append(promptTable, "\n"+c.ChatContext.restartSeq+question)
  135. prompt := strings.Join(promptTable, "\n")
  136. prompt += c.ChatContext.startSeq
  137. if len(prompt) > c.maxText-c.maxAnswerLen {
  138. return "", OverMaxTextLength
  139. }
  140. req := gogpt.CompletionRequest{
  141. Model: gogpt.GPT3TextDavinci003,
  142. MaxTokens: c.maxAnswerLen,
  143. Prompt: prompt,
  144. Temperature: 0.9,
  145. TopP: 1,
  146. N: 1,
  147. FrequencyPenalty: 0,
  148. PresencePenalty: 0.5,
  149. User: c.userId,
  150. Stop: []string{c.ChatContext.aiRole.Name + ":", c.ChatContext.humanRole.Name + ":"},
  151. }
  152. resp, err := c.client.CreateCompletion(c.ctx, req)
  153. if err != nil {
  154. return "", err
  155. }
  156. resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text)
  157. c.ChatContext.old = append(c.ChatContext.old, conversation{
  158. Role: c.ChatContext.humanRole,
  159. Prompt: question,
  160. })
  161. c.ChatContext.old = append(c.ChatContext.old, conversation{
  162. Role: c.ChatContext.aiRole,
  163. Prompt: resp.Choices[0].Text,
  164. })
  165. c.ChatContext.seqTimes++
  166. return resp.Choices[0].Text, nil
  167. }
  168. func WithMaxSeqTimes(times int) ChatContextOption {
  169. return func(c *ChatContext) {
  170. c.SetMaxSeqTimes(times)
  171. }
  172. }
  173. // WithOldConversation 从文件中加载对话
  174. func WithOldConversation(path string) ChatContextOption {
  175. return func(c *ChatContext) {
  176. _ = c.LoadConversation(path)
  177. }
  178. }
  179. func WithMaintainSeqTimes(maintain bool) ChatContextOption {
  180. return func(c *ChatContext) {
  181. c.maintainSeqTimes = maintain
  182. }
  183. }