|
@@ -1,19 +1,26 @@
|
|
|
package gpt
|
|
|
|
|
|
import (
|
|
|
- "bytes"
|
|
|
"encoding/json"
|
|
|
"fmt"
|
|
|
- "io/ioutil"
|
|
|
- "net/http"
|
|
|
+ "time"
|
|
|
|
|
|
"github.com/eryajf/chatgpt-dingtalk/config"
|
|
|
"github.com/eryajf/chatgpt-dingtalk/public/logger"
|
|
|
+ "github.com/go-resty/resty/v2"
|
|
|
)
|
|
|
|
|
|
const BASEURL = "https://api.openai.com/v1/"
|
|
|
|
|
|
-// ChatGPTResponseBody 请求体
|
|
|
+// ChatGPTRequestBody 请求体
|
|
|
+type ChatGPTRequestBody struct {
|
|
|
+ Model string `json:"model"`
|
|
|
+ Prompt string `json:"prompt"`
|
|
|
+ MaxTokens uint `json:"max_tokens"`
|
|
|
+ Temperature float64 `json:"temperature"`
|
|
|
+}
|
|
|
+
|
|
|
+// ChatGPTResponseBody 响应体
|
|
|
type ChatGPTResponseBody struct {
|
|
|
ID string `json:"id"`
|
|
|
Object string `json:"object"`
|
|
@@ -30,14 +37,6 @@ type ChoiceItem struct {
|
|
|
FinishReason string `json:"finish_reason"`
|
|
|
}
|
|
|
|
|
|
-// ChatGPTRequestBody 响应体
|
|
|
-type ChatGPTRequestBody struct {
|
|
|
- Model string `json:"model"`
|
|
|
- Prompt string `json:"prompt"`
|
|
|
- MaxTokens uint `json:"max_tokens"`
|
|
|
- Temperature float64 `json:"temperature"`
|
|
|
-}
|
|
|
-
|
|
|
// Completions gtp文本模型回复
|
|
|
//curl https://api.openai.com/v1/completions
|
|
|
//-H "Content-Type: application/json"
|
|
@@ -51,41 +50,29 @@ func Completions(msg string) (string, error) {
|
|
|
MaxTokens: cfg.MaxTokens,
|
|
|
Temperature: cfg.Temperature,
|
|
|
}
|
|
|
- requestData, err := json.Marshal(requestBody)
|
|
|
- if err != nil {
|
|
|
- return "", err
|
|
|
- }
|
|
|
- logger.Info(fmt.Sprintf("request gtp json string : %v", string(requestData)))
|
|
|
- req, err := http.NewRequest("POST", BASEURL+"completions", bytes.NewBuffer(requestData))
|
|
|
- if err != nil {
|
|
|
- return "", err
|
|
|
- }
|
|
|
|
|
|
- req.Header.Set("Content-Type", "application/json")
|
|
|
- req.Header.Set("Authorization", "Bearer "+cfg.ApiKey)
|
|
|
- client := &http.Client{Timeout: cfg.SessionTimeout}
|
|
|
- response, err := client.Do(req)
|
|
|
- if err != nil {
|
|
|
- return "", err
|
|
|
- }
|
|
|
- defer response.Body.Close()
|
|
|
+ client := resty.New().
|
|
|
+ SetRetryCount(2).
|
|
|
+ SetRetryWaitTime(1*time.Second).
|
|
|
+ SetTimeout(cfg.SessionTimeout).
|
|
|
+ SetHeader("Content-Type", "application/json").
|
|
|
+ SetHeader("Authorization", "Bearer "+cfg.ApiKey)
|
|
|
|
|
|
- body, err := ioutil.ReadAll(response.Body)
|
|
|
+ rsp, err := client.R().SetBody(requestBody).Post(BASEURL + "completions")
|
|
|
if err != nil {
|
|
|
- return "", err
|
|
|
+ return "", fmt.Errorf("request openai failed, err : %v", err)
|
|
|
}
|
|
|
-
|
|
|
- if response.StatusCode != 200 {
|
|
|
- return "", fmt.Errorf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details: %v ", response.StatusCode, string(body))
|
|
|
+ if rsp.StatusCode() != 200 {
|
|
|
+ return "", fmt.Errorf("gtp api status code not equals 200, code is %d ,details: %v ", rsp.StatusCode(), string(rsp.Body()))
|
|
|
+ } else {
|
|
|
+ logger.Info(fmt.Sprintf("response gtp json string : %v", string(rsp.Body())))
|
|
|
}
|
|
|
- logger.Info(fmt.Sprintf("response gtp json string : %v", string(body)))
|
|
|
|
|
|
gptResponseBody := &ChatGPTResponseBody{}
|
|
|
- err = json.Unmarshal(body, gptResponseBody)
|
|
|
+ err = json.Unmarshal(rsp.Body(), gptResponseBody)
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
-
|
|
|
var reply string
|
|
|
if len(gptResponseBody.Choices) > 0 {
|
|
|
reply = gptResponseBody.Choices[0].Text
|