gpt.go 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. package gpt
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "io/ioutil"
  7. "net/http"
  8. "github.com/eryajf/chatgpt-dingtalk/config"
  9. "github.com/eryajf/chatgpt-dingtalk/public/logger"
  10. )
  // BASEURL is the root of the OpenAI v1 REST API. Endpoint names are joined
  // to it by plain string concatenation (for example BASEURL+"completions"),
  // which is why the value keeps its trailing slash.
  const (
  	BASEURL = "https://api.openai.com/v1/"
  )
  // ChatGPTResponseBody is the JSON payload returned by the OpenAI completions
  // endpoint. Completions decodes every 200 response into this type.
  type ChatGPTResponseBody struct {
  	// ID is the identifier the API assigns to this completion.
  	ID string `json:"id"`
  	// Object is the API's object-type tag for the payload.
  	Object string `json:"object"`
  	// Created is presumably a Unix timestamp in seconds — TODO confirm.
  	Created int `json:"created"`
  	// Model names the model that produced the reply.
  	Model string `json:"model"`
  	// Choices holds the generated completions. It may be empty.
  	Choices []ChoiceItem `json:"choices"`
  	// Usage is kept as a generic map. With encoding/json, its numeric values
  	// decode as float64.
  	Usage map[string]interface{} `json:"usage"`
  }

  // ChoiceItem is one generated completion inside ChatGPTResponseBody.Choices.
  type ChoiceItem struct {
  	// Text is the generated completion text.
  	Text string `json:"text"`
  	// Index is the index the API reports for this choice.
  	Index int `json:"index"`
  	// NOTE(review): the OpenAI API documents logprobs as either null or an
  	// object. An int field decodes only the null case (left as 0). That holds
  	// today because ChatGPTRequestBody has no logprobs option. If logprobs were
  	// ever requested, json.Unmarshal would fail on this field.
  	Logprobs int `json:"logprobs"`
  	// FinishReason is the API's reason for ending generation (e.g. "stop").
  	FinishReason string `json:"finish_reason"`
  }
  // ChatGPTRequestBody is the JSON payload sent to the OpenAI completions
  // endpoint. No field uses omitempty, so zero values (for example a
  // temperature of 0) are always sent explicitly.
  type ChatGPTRequestBody struct {
  	// Model names the completion model to use, e.g. "text-davinci-003".
  	Model string `json:"model"`
  	// Prompt is the text sent to the model.
  	Prompt string `json:"prompt"`
  	// MaxTokens is sent as max_tokens, the API's cap on generated tokens.
  	MaxTokens uint `json:"max_tokens"`
  	// Temperature is sent as the API's sampling temperature.
  	Temperature float64 `json:"temperature"`
  }
  34. // Completions gtp文本模型回复
  35. //curl https://api.openai.com/v1/completions
  36. //-H "Content-Type: application/json"
  37. //-H "Authorization: Bearer your chatGPT key"
  38. //-d '{"model": "text-davinci-003", "prompt": "give me good song", "temperature": 0, "max_tokens": 7}'
  39. func Completions(msg string) (string, error) {
  40. cfg := config.LoadConfig()
  41. requestBody := ChatGPTRequestBody{
  42. Model: cfg.Model,
  43. Prompt: msg,
  44. MaxTokens: cfg.MaxTokens,
  45. Temperature: cfg.Temperature,
  46. }
  47. requestData, err := json.Marshal(requestBody)
  48. if err != nil {
  49. return "", err
  50. }
  51. logger.Info(fmt.Sprintf("request gtp json string : %v", string(requestData)))
  52. req, err := http.NewRequest("POST", BASEURL+"completions", bytes.NewBuffer(requestData))
  53. if err != nil {
  54. return "", err
  55. }
  56. req.Header.Set("Content-Type", "application/json")
  57. req.Header.Set("Authorization", "Bearer "+cfg.ApiKey)
  58. client := &http.Client{Timeout: cfg.SessionTimeout}
  59. response, err := client.Do(req)
  60. if err != nil {
  61. return "", err
  62. }
  63. defer response.Body.Close()
  64. body, err := ioutil.ReadAll(response.Body)
  65. if err != nil {
  66. return "", err
  67. }
  68. if response.StatusCode != 200 {
  69. return "", fmt.Errorf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details: %v ", response.StatusCode, string(body))
  70. }
  71. logger.Info(fmt.Sprintf("response gtp json string : %v", string(body)))
  72. gptResponseBody := &ChatGPTResponseBody{}
  73. err = json.Unmarshal(body, gptResponseBody)
  74. if err != nil {
  75. return "", err
  76. }
  77. var reply string
  78. if len(gptResponseBody.Choices) > 0 {
  79. reply = gptResponseBody.Choices[0].Text
  80. }
  81. return reply, nil
  82. }