gtp.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package gtp
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io/ioutil"
  8. "net/http"
  9. "time"
  10. "github.com/eryajf/chatgpt-dingtalk/config"
  11. "github.com/eryajf/chatgpt-dingtalk/public/logger"
  12. )
  // BASEURL is the root URL of the OpenAI v1 REST API. It ends with a slash so
  // that an endpoint name can be appended to it directly.
  const (
  	BASEURL = "https://api.openai.com/v1/"
  )
  14. // ChatGPTResponseBody 请求体
  15. type ChatGPTResponseBody struct {
  16. ID string `json:"id"`
  17. Object string `json:"object"`
  18. Created int `json:"created"`
  19. Model string `json:"model"`
  20. Choices []ChoiceItem `json:"choices"`
  21. Usage map[string]interface{} `json:"usage"`
  22. }
  // ChoiceItem is one element of the "choices" array in a completions
  // response.
  type ChoiceItem struct {
  	// Text is decoded from the "text" key.
  	Text string `json:"text"`
  	// Index is decoded from the "index" key.
  	Index int `json:"index"`
  	// Logprobs is decoded from the "logprobs" key; a JSON null leaves it 0.
  	// NOTE(review): the API may send an object here when log probabilities
  	// are requested, which would not decode into an int - confirm.
  	Logprobs int `json:"logprobs"`
  	// FinishReason is decoded from the "finish_reason" key.
  	FinishReason string `json:"finish_reason"`
  }
  // ChatGPTRequestBody is the JSON payload sent in a completions request.
  // Fields are marshalled in declaration order.
  type ChatGPTRequestBody struct {
  	// Model is sent as "model".
  	Model string `json:"model"`
  	// Prompt is sent as "prompt".
  	Prompt string `json:"prompt"`
  	// MaxTokens is sent as "max_tokens".
  	MaxTokens uint `json:"max_tokens"`
  	// Temperature is sent as "temperature".
  	Temperature float64 `json:"temperature"`
  	// TopP is sent as "top_p".
  	// NOTE(review): declared as int, so fractional values cannot be
  	// expressed.
  	TopP int `json:"top_p"`
  	// FrequencyPenalty is sent as "frequency_penalty".
  	FrequencyPenalty int `json:"frequency_penalty"`
  	// PresencePenalty is sent as "presence_penalty".
  	PresencePenalty int `json:"presence_penalty"`
  }
  39. // Completions gtp文本模型回复
  40. //curl https://api.openai.com/v1/completions
  41. //-H "Content-Type: application/json"
  42. //-H "Authorization: Bearer your chatGPT key"
  43. //-d '{"model": "text-davinci-003", "prompt": "give me good song", "temperature": 0, "max_tokens": 7}'
  44. func Completions(msg string) (string, error) {
  45. cfg := config.LoadConfig()
  46. requestBody := ChatGPTRequestBody{
  47. Model: cfg.Model,
  48. Prompt: msg,
  49. MaxTokens: cfg.MaxTokens,
  50. Temperature: cfg.Temperature,
  51. TopP: 1,
  52. FrequencyPenalty: 0,
  53. PresencePenalty: 0,
  54. }
  55. requestData, err := json.Marshal(requestBody)
  56. if err != nil {
  57. return "", err
  58. }
  59. logger.Info(fmt.Sprintf("request gtp json string : %v", string(requestData)))
  60. req, err := http.NewRequest("POST", BASEURL+"completions", bytes.NewBuffer(requestData))
  61. if err != nil {
  62. return "", err
  63. }
  64. apiKey := config.LoadConfig().ApiKey
  65. req.Header.Set("Content-Type", "application/json")
  66. req.Header.Set("Authorization", "Bearer "+apiKey)
  67. client := &http.Client{Timeout: 30 * time.Second}
  68. response, err := client.Do(req)
  69. if err != nil {
  70. return "", err
  71. }
  72. defer response.Body.Close()
  73. if response.StatusCode != 200 {
  74. body, _ := ioutil.ReadAll(response.Body)
  75. return "", errors.New(fmt.Sprintf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details: %v ", response.StatusCode, string(body)))
  76. }
  77. body, err := ioutil.ReadAll(response.Body)
  78. if err != nil {
  79. return "", err
  80. }
  81. logger.Info(fmt.Sprintf("response gtp json string : %v", string(body)))
  82. gptResponseBody := &ChatGPTResponseBody{}
  83. err = json.Unmarshal(body, gptResponseBody)
  84. if err != nil {
  85. return "", err
  86. }
  87. var reply string
  88. if len(gptResponseBody.Choices) > 0 {
  89. reply = gptResponseBody.Choices[0].Text
  90. }
  91. return reply, nil
  92. }