gtp.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package gtp
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "io/ioutil"
  8. "log"
  9. "net/http"
  10. "time"
  11. "github.com/eryajf/chatgpt-dingtalk/config"
  12. "github.com/eryajf/chatgpt-dingtalk/public/logger"
  13. )
  // BASEURL is the root URL of the OpenAI REST API. Endpoint names such as
// "completions" are appended to it to build request URLs.
const (
	BASEURL = "https://api.openai.com/v1/"
)
  // ChatGPTResponseBody is the JSON body returned by the OpenAI completions
// endpoint. Completions decodes the HTTP response into this struct and reads
// the text of its first choice.
type ChatGPTResponseBody struct {
	ID     string `json:"id"`
	Object string `json:"object"`
	// Created is decoded as an int; presumably a Unix timestamp — confirm.
	Created int    `json:"created"`
	Model   string `json:"model"`
	// Choices holds the generated completions; Completions uses only Choices[0].
	Choices []ChoiceItem `json:"choices"`
	// Usage is kept as a generic map and is not inspected in this file
	// (presumably token counts — verify against the API reference).
	Usage map[string]interface{} `json:"usage"`
}

// ChoiceItem is one entry of ChatGPTResponseBody.Choices.
type ChoiceItem struct {
	// Text is the generated completion text returned to the caller.
	Text  string `json:"text"`
	Index int    `json:"index"`
	// NOTE(review): Completions never requests logprobs, so the API is expected
	// to send null here, which decodes into an int as 0. If logprobs are ever
	// requested, the API presumably returns an object, and json.Unmarshal into
	// an int field would fail — change the type before enabling that.
	Logprobs     int    `json:"logprobs"`
	FinishReason string `json:"finish_reason"`
}
  // ChatGPTRequestBody is the JSON request body that Completions marshals and
// POSTs to the OpenAI "completions" endpoint.
type ChatGPTRequestBody struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	MaxTokens   uint    `json:"max_tokens"`
	Temperature float64 `json:"temperature"`
	// NOTE(review): the OpenAI API documents top_p, frequency_penalty and
	// presence_penalty as numbers. As ints, these fields cannot carry fractional
	// values such as 0.9. They are left as int to keep the exported field types
	// unchanged for existing callers.
	TopP             int `json:"top_p"`
	FrequencyPenalty int `json:"frequency_penalty"`
	PresencePenalty  int `json:"presence_penalty"`
}
  40. // Completions gtp文本模型回复
  41. //curl https://api.openai.com/v1/completions
  42. //-H "Content-Type: application/json"
  43. //-H "Authorization: Bearer your chatGPT key"
  44. //-d '{"model": "text-davinci-003", "prompt": "give me good song", "temperature": 0, "max_tokens": 7}'
  45. func Completions(msg string) (string, error) {
  46. cfg := config.LoadConfig()
  47. requestBody := ChatGPTRequestBody{
  48. Model: cfg.Model,
  49. Prompt: msg,
  50. MaxTokens: cfg.MaxTokens,
  51. Temperature: cfg.Temperature,
  52. TopP: 1,
  53. FrequencyPenalty: 0,
  54. PresencePenalty: 0,
  55. }
  56. requestData, err := json.Marshal(requestBody)
  57. if err != nil {
  58. return "", err
  59. }
  60. logger.Info(fmt.Sprintf("request gtp json string : %v", string(requestData)))
  61. req, err := http.NewRequest("POST", BASEURL+"completions", bytes.NewBuffer(requestData))
  62. if err != nil {
  63. return "", err
  64. }
  65. apiKey := config.LoadConfig().ApiKey
  66. req.Header.Set("Content-Type", "application/json")
  67. req.Header.Set("Authorization", "Bearer "+apiKey)
  68. client := &http.Client{Timeout: 30 * time.Second}
  69. response, err := client.Do(req)
  70. if err != nil {
  71. return "", err
  72. }
  73. defer response.Body.Close()
  74. if response.StatusCode != 200 {
  75. body, _ := ioutil.ReadAll(response.Body)
  76. return "", errors.New(fmt.Sprintf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details: %v ", response.StatusCode, string(body)))
  77. }
  78. body, err := ioutil.ReadAll(response.Body)
  79. if err != nil {
  80. return "", err
  81. }
  82. logger.Info(fmt.Sprintf("response gtp json string : %v", string(body)))
  83. gptResponseBody := &ChatGPTResponseBody{}
  84. log.Println(string(body))
  85. err = json.Unmarshal(body, gptResponseBody)
  86. if err != nil {
  87. return "", err
  88. }
  89. var reply string
  90. if len(gptResponseBody.Choices) > 0 {
  91. reply = gptResponseBody.Choices[0].Text
  92. }
  93. logger.Info(fmt.Sprintf("gpt response text: %s ", reply))
  94. return reply, nil
  95. }