chatgpt.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. package chatgpt
  2. import (
  3. "context"
  4. "net/http"
  5. "net/url"
  6. "time"
  7. gogpt "github.com/sashabaranov/go-gpt3"
  8. )
  9. type ChatGPT struct {
  10. client *gogpt.Client
  11. ctx context.Context
  12. userId string
  13. maxQuestionLen int
  14. maxText int
  15. maxAnswerLen int
  16. timeOut time.Duration // 超时时间, 0表示不超时
  17. doneChan chan struct{}
  18. cancel func()
  19. ChatContext *ChatContext
  20. }
  21. func New(apiKey, proxyUrl, userId string, timeOut time.Duration) *ChatGPT {
  22. var ctx context.Context
  23. var cancel func()
  24. if timeOut == 0 {
  25. ctx, cancel = context.WithCancel(context.Background())
  26. } else {
  27. ctx, cancel = context.WithTimeout(context.Background(), timeOut)
  28. }
  29. timeOutChan := make(chan struct{}, 1)
  30. go func() {
  31. <-ctx.Done()
  32. timeOutChan <- struct{}{} // 发送超时信号,或是提示结束,用于聊天机器人场景,配合GetTimeOutChan() 使用
  33. }()
  34. config := gogpt.DefaultConfig(apiKey)
  35. if proxyUrl != "" {
  36. config.HTTPClient.Transport = &http.Transport{
  37. // 设置代理
  38. Proxy: func(req *http.Request) (*url.URL, error) {
  39. return url.Parse(proxyUrl)
  40. }}
  41. }
  42. return &ChatGPT{
  43. client: gogpt.NewClientWithConfig(config),
  44. ctx: ctx,
  45. userId: userId,
  46. maxQuestionLen: 2048, // 最大问题长度
  47. maxAnswerLen: 2048, // 最大答案长度
  48. maxText: 4096, // 最大文本 = 问题 + 回答, 接口限制
  49. timeOut: timeOut,
  50. doneChan: timeOutChan,
  51. cancel: func() {
  52. cancel()
  53. },
  54. ChatContext: NewContext(),
  55. }
  56. }
  57. func (c *ChatGPT) Close() {
  58. c.cancel()
  59. }
  60. func (c *ChatGPT) GetDoneChan() chan struct{} {
  61. return c.doneChan
  62. }
  63. func (c *ChatGPT) SetMaxQuestionLen(maxQuestionLen int) int {
  64. if maxQuestionLen > c.maxText-c.maxAnswerLen {
  65. maxQuestionLen = c.maxText - c.maxAnswerLen
  66. }
  67. c.maxQuestionLen = maxQuestionLen
  68. return c.maxQuestionLen
  69. }
  70. // func (c *ChatGPT) Chat(question string) (answer string, err error) {
  71. // question = question + "."
  72. // if len(question) > c.maxQuestionLen {
  73. // return "", OverMaxQuestionLength
  74. // }
  75. // if len(question)+c.maxAnswerLen > c.maxText {
  76. // question = question[:c.maxText-c.maxAnswerLen]
  77. // }
  78. // req := gogpt.CompletionRequest{
  79. // Model: gogpt.GPT3TextDavinci003,
  80. // MaxTokens: c.maxAnswerLen,
  81. // Prompt: question,
  82. // Temperature: 0.9,
  83. // TopP: 1,
  84. // N: 1,
  85. // FrequencyPenalty: 0,
  86. // PresencePenalty: 0.5,
  87. // User: c.userId,
  88. // Stop: []string{},
  89. // }
  90. // resp, err := c.client.CreateCompletion(c.ctx, req)
  91. // if err != nil {
  92. // return "", err
  93. // }
  94. // return formatAnswer(resp.Choices[0].Text), err
  95. // }