Browse Source

Merge pull request #17 from eryajf/fix_timeout

二丫讲梵 2 years ago
parent
commit
25872fafd5
4 changed files with 26 additions and 40 deletions
  1. 2 2
      config.dev.json
  2. 5 10
      config/config.go
  3. 16 24
      gtp/gtp.go
  4. 3 4
      main.go

+ 2 - 2
config.dev.json

@@ -1,7 +1,7 @@
 {
     "api_key": "xxxxxxxxx",
-    "session_timeout": 60,
-    "max_tokens": 1024,
+    "session_timeout": 180,
+    "max_tokens": 2000,
     "model": "text-davinci-003",
     "temperature": 0.9,
     "session_clear_token": "清空会话"

+ 5 - 10
config/config.go

@@ -34,13 +34,7 @@ var once sync.Once
 func LoadConfig() *Configuration {
 	once.Do(func() {
 		// 从文件中读取
-		config = &Configuration{
-			SessionTimeout:    60,
-			MaxTokens:         512,
-			Model:             "text-davinci-003",
-			Temperature:       0.9,
-			SessionClearToken: "下一个问题",
-		}
+		config = &Configuration{}
 		f, err := os.Open("config.json")
 		if err != nil {
 			logger.Danger("open config err: %v", err)
@@ -55,7 +49,6 @@ func LoadConfig() *Configuration {
 		}
 
 		// 如果环境变量有配置,读取环境变量
-		// 有环境变量使用环境变量
 		ApiKey := os.Getenv("APIKEY")
 		SessionTimeout := os.Getenv("SESSION_TIMEOUT")
 		Model := os.Getenv("MODEL")
@@ -66,12 +59,14 @@ func LoadConfig() *Configuration {
 			config.ApiKey = ApiKey
 		}
 		if SessionTimeout != "" {
-			duration, err := time.ParseDuration(SessionTimeout)
+			duration, err := strconv.ParseInt(SessionTimeout, 10, 64)
 			if err != nil {
 				logger.Danger(fmt.Sprintf("config session timeout err: %v ,get is %v", err, SessionTimeout))
 				return
 			}
-			config.SessionTimeout = duration
+			config.SessionTimeout = time.Duration(duration) * time.Second
+		} else {
+			config.SessionTimeout = time.Duration(config.SessionTimeout) * time.Second
 		}
 		if Model != "" {
 			config.Model = Model

+ 16 - 24
gtp/gtp.go

@@ -1,13 +1,11 @@
-package gtp
+package gpt
 
 import (
 	"bytes"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"io/ioutil"
 	"net/http"
-	"time"
 
 	"github.com/eryajf/chatgpt-dingtalk/config"
 	"github.com/eryajf/chatgpt-dingtalk/public/logger"
@@ -34,13 +32,10 @@ type ChoiceItem struct {
 
 // ChatGPTRequestBody 响应体
 type ChatGPTRequestBody struct {
-	Model            string  `json:"model"`
-	Prompt           string  `json:"prompt"`
-	MaxTokens        uint    `json:"max_tokens"`
-	Temperature      float64 `json:"temperature"`
-	TopP             int     `json:"top_p"`
-	FrequencyPenalty int     `json:"frequency_penalty"`
-	PresencePenalty  int     `json:"presence_penalty"`
+	Model       string  `json:"model"`
+	Prompt      string  `json:"prompt"`
+	MaxTokens   uint    `json:"max_tokens"`
+	Temperature float64 `json:"temperature"`
 }
 
 // Completions gtp文本模型回复
@@ -51,13 +46,10 @@ type ChatGPTRequestBody struct {
 func Completions(msg string) (string, error) {
 	cfg := config.LoadConfig()
 	requestBody := ChatGPTRequestBody{
-		Model:            cfg.Model,
-		Prompt:           msg,
-		MaxTokens:        cfg.MaxTokens,
-		Temperature:      cfg.Temperature,
-		TopP:             1,
-		FrequencyPenalty: 0,
-		PresencePenalty:  0,
+		Model:       cfg.Model,
+		Prompt:      msg,
+		MaxTokens:   cfg.MaxTokens,
+		Temperature: cfg.Temperature,
 	}
 	requestData, err := json.Marshal(requestBody)
 	if err != nil {
@@ -69,23 +61,23 @@ func Completions(msg string) (string, error) {
 		return "", err
 	}
 
-	apiKey := config.LoadConfig().ApiKey
 	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("Authorization", "Bearer "+apiKey)
-	client := &http.Client{Timeout: 30 * time.Second}
+	req.Header.Set("Authorization", "Bearer "+cfg.ApiKey)
+	client := &http.Client{Timeout: cfg.SessionTimeout}
 	response, err := client.Do(req)
 	if err != nil {
 		return "", err
 	}
 	defer response.Body.Close()
-	if response.StatusCode != 200 {
-		body, _ := ioutil.ReadAll(response.Body)
-		return "", errors.New(fmt.Sprintf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details:  %v ", response.StatusCode, string(body)))
-	}
+
 	body, err := ioutil.ReadAll(response.Body)
 	if err != nil {
 		return "", err
 	}
+
+	if response.StatusCode != 200 {
+		return "", fmt.Errorf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details:  %v ", response.StatusCode, string(body))
+	}
 	logger.Info(fmt.Sprintf("response gtp json string : %v", string(body)))
 
 	gptResponseBody := &ChatGPTResponseBody{}

+ 3 - 4
main.go

@@ -7,7 +7,7 @@ import (
 	"net/http"
 	"strings"
 
-	"github.com/eryajf/chatgpt-dingtalk/gtp"
+	"github.com/eryajf/chatgpt-dingtalk/gpt"
 	"github.com/eryajf/chatgpt-dingtalk/public"
 	"github.com/eryajf/chatgpt-dingtalk/public/logger"
 	"github.com/eryajf/chatgpt-dingtalk/service"
@@ -33,7 +33,6 @@ func Start() {
 			return
 		}
 		// TODO: 校验请求
-		// fmt.Println(r.Header)
 		if len(data) == 0 {
 			logger.Warning("回调参数为空,以至于无法正常解析,请检查原因")
 			return
@@ -76,9 +75,9 @@ func ProcessRequest(rmsg public.ReceiveMsg) error {
 	} else {
 		requestText := getRequestText(rmsg)
 		// 获取问题的答案
-		reply, err := gtp.Completions(requestText)
+		reply, err := gpt.Completions(requestText)
 		if err != nil {
-			logger.Info("gtp request error: %v \n", err)
+			logger.Info("gpt request error: %v \n", err)
 			_, err = rmsg.ReplyText("机器人太累了,让她休息会儿,过一会儿再来请求。")
 			if err != nil {
 				logger.Warning("send message error: %v \n", err)