浏览代码

Merge pull request #19 from eryajf/feat_retryreq

二丫讲梵 2 年之前
父节点
当前提交
564e257bc2
共有 4 个文件被更改,包括 36 次插入41 次删除
  1. 3 1
      go.mod
  2. 9 2
      go.sum
  3. 24 37
      gpt/gpt.go
  4. 0 1
      main.go

+ 3 - 1
go.mod

@@ -3,6 +3,8 @@ module github.com/eryajf/chatgpt-dingtalk
 go 1.17
 
 require (
-	github.com/eatmoreapple/openwechat v1.2.3
+	github.com/go-resty/resty/v2 v2.7.0
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 )
+
+require golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect

+ 9 - 2
go.sum

@@ -1,4 +1,11 @@
-github.com/eatmoreapple/openwechat v1.2.3 h1:8AO+nvXwHVTM/7Gk7y6IZ2/hjnILTLQztWmJnPhPB+k=
-github.com/eatmoreapple/openwechat v1.2.3/go.mod h1:61HOzTyvLobGdgWhL68jfGNwTJEv0mhQ1miCXQrvWU8=
+github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY=
+github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I=
 github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
 github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
+golang.org/x/net v0.0.0-20211029224645-99673261e6eb h1:pirldcYWx7rx7kE5r+9WsOXPXK0+WH5+uZ7uPmJ44uM=
+golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

+ 24 - 37
gpt/gpt.go

@@ -1,19 +1,26 @@
 package gpt
 
 import (
-	"bytes"
 	"encoding/json"
 	"fmt"
-	"io/ioutil"
-	"net/http"
+	"time"
 
 	"github.com/eryajf/chatgpt-dingtalk/config"
 	"github.com/eryajf/chatgpt-dingtalk/public/logger"
+	"github.com/go-resty/resty/v2"
 )
 
 const BASEURL = "https://api.openai.com/v1/"
 
-// ChatGPTResponseBody 请求体
+// ChatGPTRequestBody 请求体
+type ChatGPTRequestBody struct {
+	Model       string  `json:"model"`
+	Prompt      string  `json:"prompt"`
+	MaxTokens   uint    `json:"max_tokens"`
+	Temperature float64 `json:"temperature"`
+}
+
+// ChatGPTResponseBody 响应体
 type ChatGPTResponseBody struct {
 	ID      string                 `json:"id"`
 	Object  string                 `json:"object"`
@@ -30,14 +37,6 @@ type ChoiceItem struct {
 	FinishReason string `json:"finish_reason"`
 }
 
-// ChatGPTRequestBody 响应体
-type ChatGPTRequestBody struct {
-	Model       string  `json:"model"`
-	Prompt      string  `json:"prompt"`
-	MaxTokens   uint    `json:"max_tokens"`
-	Temperature float64 `json:"temperature"`
-}
-
 // Completions gtp文本模型回复
 //curl https://api.openai.com/v1/completions
 //-H "Content-Type: application/json"
@@ -51,41 +50,29 @@ func Completions(msg string) (string, error) {
 		MaxTokens:   cfg.MaxTokens,
 		Temperature: cfg.Temperature,
 	}
-	requestData, err := json.Marshal(requestBody)
-	if err != nil {
-		return "", err
-	}
-	logger.Info(fmt.Sprintf("request gtp json string : %v", string(requestData)))
-	req, err := http.NewRequest("POST", BASEURL+"completions", bytes.NewBuffer(requestData))
-	if err != nil {
-		return "", err
-	}
 
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Authorization", "Bearer "+cfg.ApiKey)
-	client := &http.Client{Timeout: cfg.SessionTimeout}
-	response, err := client.Do(req)
-	if err != nil {
-		return "", err
-	}
-	defer response.Body.Close()
+	client := resty.New().
+		SetRetryCount(2).
+		SetRetryWaitTime(1*time.Second).
+		SetTimeout(cfg.SessionTimeout).
+		SetHeader("Content-Type", "application/json").
+		SetHeader("Authorization", "Bearer "+cfg.ApiKey)
 
-	body, err := ioutil.ReadAll(response.Body)
+	rsp, err := client.R().SetBody(requestBody).Post(BASEURL + "completions")
 	if err != nil {
-		return "", err
+		return "", fmt.Errorf("request openai failed, err : %v", err)
 	}
-
-	if response.StatusCode != 200 {
-		return "", fmt.Errorf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details:  %v ", response.StatusCode, string(body))
+	if rsp.StatusCode() != 200 {
+		return "", fmt.Errorf("gtp api status code not equals 200, code is %d ,details:  %v ", rsp.StatusCode(), string(rsp.Body()))
+	} else {
+		logger.Info(fmt.Sprintf("response gtp json string : %v", string(rsp.Body())))
 	}
-	logger.Info(fmt.Sprintf("response gtp json string : %v", string(body)))
 
 	gptResponseBody := &ChatGPTResponseBody{}
-	err = json.Unmarshal(body, gptResponseBody)
+	err = json.Unmarshal(rsp.Body(), gptResponseBody)
 	if err != nil {
 		return "", err
 	}
-
 	var reply string
 	if len(gptResponseBody.Choices) > 0 {
 		reply = gptResponseBody.Choices[0].Text

+ 0 - 1
main.go

@@ -110,7 +110,6 @@ func getRequestText(rmsg public.ReceiveMsg) string {
 	// 1.去除空格以及换行
 	requestText := strings.TrimSpace(rmsg.Text.Content)
 	requestText = strings.Trim(rmsg.Text.Content, "\n")
-
 	// 2.替换掉当前用户名称
 	replaceText := "@" + rmsg.SenderNick
 	requestText = strings.TrimSpace(strings.ReplaceAll(rmsg.Text.Content, replaceText, ""))