context.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. package chatgpt
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/gob"
  6. "image/png"
  7. "os"
  8. "strings"
  9. "time"
  10. "github.com/eryajf/chatgpt-dingtalk/public"
  11. openai "github.com/sashabaranov/go-openai"
  12. )
  // Speaker labels used when the prompt is assembled from the conversation
  // history. NewContext uses them for the initial AI and human roles.
  var (
  	// DefaultAiRole is the default label for the assistant's turns.
  	DefaultAiRole = "AI"
  	// DefaultHumanRole is the default label for the user's turns.
  	DefaultHumanRole = "Human"
  )

  // DefaultCharacter lists personality adjectives. It is not referenced in
  // this file; presumably it is meant to fill the %s of DefaultBackground —
  // TODO confirm against callers.
  var DefaultCharacter = []string{"helpful", "creative", "clever", "friendly", "lovely", "talkative"}

  // DefaultBackground is a background-text template with a single %s verb.
  var DefaultBackground = "The following is a conversation with AI assistant. The assistant is %s"

  // DefaultPreset is a preset-dialogue template with two %s verbs, one per
  // speaker line (presumably human first, then AI — verify against callers).
  var DefaultPreset = "\n%s: 你好,让我们开始愉快的谈话!\n%s: 我是 AI assistant ,请问你有什么问题?"
  // ChatContext holds the prompt configuration and the running question/answer
  // history of one conversation.
  type ChatContext struct {
  	background       string         // background text, placed first in the prompt
  	preset           string         // preset dialogue, placed after the background
  	maxSeqTimes      int            // limit compared against seqTimes
  	aiRole           *role          // AI participant
  	humanRole        *role          // human participant
  	old              []conversation // recorded history entries
  	restartSeq       string         // marker placed before each new question ("\n<human>: ")
  	startSeq         string         // marker appended at the end of the prompt ("\n<ai>: ")
  	seqTimes         int            // exchange counter compared against maxSeqTimes
  	maintainSeqTimes bool           // when set, old history is dropped instead of failing at the limit
  }

  // ChatContextOption mutates a ChatContext during construction (see NewContext).
  type ChatContextOption func(*ChatContext)

  // conversation is one recorded history entry: who spoke and what was said.
  // Its fields are exported because the history is serialized with encoding/gob,
  // which only encodes exported fields.
  type conversation struct {
  	Role   *role
  	Prompt string
  }

  // role names a participant in the conversation.
  type role struct {
  	Name string
  }
  42. func NewContext(options ...ChatContextOption) *ChatContext {
  43. ctx := &ChatContext{
  44. aiRole: &role{Name: DefaultAiRole},
  45. humanRole: &role{Name: DefaultHumanRole},
  46. background: "",
  47. maxSeqTimes: 1000,
  48. preset: "",
  49. old: []conversation{},
  50. seqTimes: 0,
  51. restartSeq: "\n" + DefaultHumanRole + ": ",
  52. startSeq: "\n" + DefaultAiRole + ": ",
  53. maintainSeqTimes: false,
  54. }
  55. for _, option := range options {
  56. option(ctx)
  57. }
  58. return ctx
  59. }
  60. // PollConversation 移除最旧的一则对话
  61. func (c *ChatContext) PollConversation() {
  62. c.old = c.old[1:]
  63. c.seqTimes--
  64. }
  65. // ResetConversation 重置对话
  66. func (c *ChatContext) ResetConversation(userid string) {
  67. public.UserService.ClearUserSessionContext(userid)
  68. }
  69. // SaveConversation 保存对话
  70. func (c *ChatContext) SaveConversation(userid string) error {
  71. var buffer bytes.Buffer
  72. enc := gob.NewEncoder(&buffer)
  73. err := enc.Encode(c.old)
  74. if err != nil {
  75. return err
  76. }
  77. public.UserService.SetUserSessionContext(userid, buffer.String())
  78. return nil
  79. }
  80. // LoadConversation 加载对话
  81. func (c *ChatContext) LoadConversation(userid string) error {
  82. dec := gob.NewDecoder(strings.NewReader(public.UserService.GetUserSessionContext(userid)))
  83. err := dec.Decode(&c.old)
  84. if err != nil {
  85. return err
  86. }
  87. c.seqTimes = len(c.old)
  88. return nil
  89. }
  90. func (c *ChatContext) SetHumanRole(role string) {
  91. c.humanRole.Name = role
  92. c.restartSeq = "\n" + c.humanRole.Name + ": "
  93. }
  94. func (c *ChatContext) SetAiRole(role string) {
  95. c.aiRole.Name = role
  96. c.startSeq = "\n" + c.aiRole.Name + ": "
  97. }
  98. func (c *ChatContext) SetMaxSeqTimes(times int) {
  99. c.maxSeqTimes = times
  100. }
  101. func (c *ChatContext) GetMaxSeqTimes() int {
  102. return c.maxSeqTimes
  103. }
  104. func (c *ChatContext) SetBackground(background string) {
  105. c.background = background
  106. }
  107. func (c *ChatContext) SetPreset(preset string) {
  108. c.preset = preset
  109. }
  110. func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
  111. question = question + "."
  112. if len(question) > c.maxQuestionLen {
  113. return "", OverMaxQuestionLength
  114. }
  115. if c.ChatContext.seqTimes >= c.ChatContext.maxSeqTimes {
  116. if c.ChatContext.maintainSeqTimes {
  117. c.ChatContext.PollConversation()
  118. } else {
  119. return "", OverMaxSequenceTimes
  120. }
  121. }
  122. var promptTable []string
  123. promptTable = append(promptTable, c.ChatContext.background)
  124. promptTable = append(promptTable, c.ChatContext.preset)
  125. for _, v := range c.ChatContext.old {
  126. if v.Role == c.ChatContext.humanRole {
  127. promptTable = append(promptTable, "\n"+v.Role.Name+": "+v.Prompt)
  128. } else {
  129. promptTable = append(promptTable, v.Role.Name+": "+v.Prompt)
  130. }
  131. }
  132. promptTable = append(promptTable, "\n"+c.ChatContext.restartSeq+question)
  133. prompt := strings.Join(promptTable, "\n")
  134. prompt += c.ChatContext.startSeq
  135. if len(prompt) > c.maxText-c.maxAnswerLen {
  136. return "", OverMaxTextLength
  137. }
  138. model := public.Config.Model
  139. if model == openai.GPT3Dot5Turbo0301 ||
  140. model == openai.GPT3Dot5Turbo ||
  141. model == openai.GPT4 || model == openai.GPT40314 ||
  142. model == openai.GPT432K || model == openai.GPT432K0314 {
  143. req := openai.ChatCompletionRequest{
  144. Model: model,
  145. Messages: []openai.ChatCompletionMessage{
  146. {
  147. Role: "user",
  148. Content: prompt,
  149. },
  150. },
  151. MaxTokens: 3072,
  152. Temperature: 0.6,
  153. User: c.userId,
  154. }
  155. resp, err := c.client.CreateChatCompletion(c.ctx, req)
  156. if err != nil {
  157. return "", err
  158. }
  159. resp.Choices[0].Message.Content = formatAnswer(resp.Choices[0].Message.Content)
  160. c.ChatContext.old = append(c.ChatContext.old, conversation{
  161. Role: c.ChatContext.humanRole,
  162. Prompt: question,
  163. })
  164. c.ChatContext.old = append(c.ChatContext.old, conversation{
  165. Role: c.ChatContext.aiRole,
  166. Prompt: resp.Choices[0].Message.Content,
  167. })
  168. c.ChatContext.seqTimes++
  169. return resp.Choices[0].Message.Content, nil
  170. } else {
  171. req := openai.CompletionRequest{
  172. Model: model,
  173. MaxTokens: c.maxAnswerLen,
  174. Prompt: prompt,
  175. Temperature: 0.6,
  176. User: c.userId,
  177. Stop: []string{c.ChatContext.aiRole.Name + ":", c.ChatContext.humanRole.Name + ":"},
  178. }
  179. resp, err := c.client.CreateCompletion(c.ctx, req)
  180. if err != nil {
  181. return "", err
  182. }
  183. resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text)
  184. c.ChatContext.old = append(c.ChatContext.old, conversation{
  185. Role: c.ChatContext.humanRole,
  186. Prompt: question,
  187. })
  188. c.ChatContext.old = append(c.ChatContext.old, conversation{
  189. Role: c.ChatContext.aiRole,
  190. Prompt: resp.Choices[0].Text,
  191. })
  192. c.ChatContext.seqTimes++
  193. return resp.Choices[0].Text, nil
  194. }
  195. }
  196. func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
  197. model := public.Config.Model
  198. if model == openai.GPT3Dot5Turbo0301 ||
  199. model == openai.GPT3Dot5Turbo ||
  200. model == openai.GPT4 || model == openai.GPT40314 ||
  201. model == openai.GPT432K || model == openai.GPT432K0314 {
  202. req := openai.ImageRequest{
  203. Prompt: prompt,
  204. Size: openai.CreateImageSize1024x1024,
  205. ResponseFormat: openai.CreateImageResponseFormatB64JSON,
  206. N: 1,
  207. User: c.userId,
  208. }
  209. respBase64, err := c.client.CreateImage(c.ctx, req)
  210. if err != nil {
  211. return "", err
  212. }
  213. imgBytes, err := base64.StdEncoding.DecodeString(respBase64.Data[0].B64JSON)
  214. if err != nil {
  215. return "", err
  216. }
  217. r := bytes.NewReader(imgBytes)
  218. imgData, err := png.Decode(r)
  219. if err != nil {
  220. return "", err
  221. }
  222. imageName := time.Now().Format("20060102-150405") + ".png"
  223. err = os.MkdirAll("images", 0755)
  224. if err != nil {
  225. return "", err
  226. }
  227. file, err := os.Create("data/images/" + imageName)
  228. if err != nil {
  229. return "", err
  230. }
  231. defer file.Close()
  232. if err := png.Encode(file, imgData); err != nil {
  233. return "", err
  234. }
  235. return public.Config.ServiceURL + "/images/" + imageName, nil
  236. }
  237. return "", nil
  238. }
  239. func WithMaxSeqTimes(times int) ChatContextOption {
  240. return func(c *ChatContext) {
  241. c.SetMaxSeqTimes(times)
  242. }
  243. }
  244. // WithOldConversation 从文件中加载对话
  245. func WithOldConversation(userid string) ChatContextOption {
  246. return func(c *ChatContext) {
  247. _ = c.LoadConversation(userid)
  248. }
  249. }
  250. func WithMaintainSeqTimes(maintain bool) ChatContextOption {
  251. return func(c *ChatContext) {
  252. c.maintainSeqTimes = maintain
  253. }
  254. }