chatgpt.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package chatgpt
  2. import (
  3. "context"
  4. "time"
  5. gogpt "github.com/sashabaranov/go-gpt3"
  6. )
  7. type ChatGPT struct {
  8. client *gogpt.Client
  9. ctx context.Context
  10. userId string
  11. maxQuestionLen int
  12. maxText int
  13. maxAnswerLen int
  14. timeOut time.Duration // 超时时间, 0表示不超时
  15. doneChan chan struct{}
  16. cancel func()
  17. ChatContext *ChatContext
  18. }
  19. func New(ApiKey, UserId string, timeOut time.Duration) *ChatGPT {
  20. var ctx context.Context
  21. var cancel func()
  22. if timeOut == 0 {
  23. ctx, cancel = context.WithCancel(context.Background())
  24. } else {
  25. ctx, cancel = context.WithTimeout(context.Background(), timeOut)
  26. }
  27. timeOutChan := make(chan struct{}, 1)
  28. go func() {
  29. <-ctx.Done()
  30. timeOutChan <- struct{}{} // 发送超时信号,或是提示结束,用于聊天机器人场景,配合GetTimeOutChan() 使用
  31. }()
  32. return &ChatGPT{
  33. client: gogpt.NewClient(ApiKey),
  34. ctx: ctx,
  35. userId: UserId,
  36. maxQuestionLen: 2048, // 最大问题长度
  37. maxAnswerLen: 2048, // 最大答案长度
  38. maxText: 4096, // 最大文本 = 问题 + 回答, 接口限制
  39. timeOut: timeOut,
  40. doneChan: timeOutChan,
  41. cancel: func() {
  42. cancel()
  43. },
  44. ChatContext: NewContext(),
  45. }
  46. }
  47. func (c *ChatGPT) Close() {
  48. c.cancel()
  49. }
  50. func (c *ChatGPT) GetDoneChan() chan struct{} {
  51. return c.doneChan
  52. }
  53. func (c *ChatGPT) SetMaxQuestionLen(maxQuestionLen int) int {
  54. if maxQuestionLen > c.maxText-c.maxAnswerLen {
  55. maxQuestionLen = c.maxText - c.maxAnswerLen
  56. }
  57. c.maxQuestionLen = maxQuestionLen
  58. return c.maxQuestionLen
  59. }
  60. func (c *ChatGPT) Chat(question string) (answer string, err error) {
  61. question = question + "."
  62. if len(question) > c.maxQuestionLen {
  63. return "", OverMaxQuestionLength
  64. }
  65. if len(question)+c.maxAnswerLen > c.maxText {
  66. question = question[:c.maxText-c.maxAnswerLen]
  67. }
  68. req := gogpt.CompletionRequest{
  69. Model: gogpt.GPT3TextDavinci003,
  70. MaxTokens: c.maxAnswerLen,
  71. Prompt: question,
  72. Temperature: 0.9,
  73. TopP: 1,
  74. N: 1,
  75. FrequencyPenalty: 0,
  76. PresencePenalty: 0.5,
  77. User: c.userId,
  78. Stop: []string{},
  79. }
  80. resp, err := c.client.CreateCompletion(c.ctx, req)
  81. if err != nil {
  82. return "", err
  83. }
  84. return formatAnswer(resp.Choices[0].Text), err
  85. }