|
@@ -1,13 +1,11 @@
|
|
|
-package gtp
|
|
|
+package gpt
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
"encoding/json"
|
|
|
- "errors"
|
|
|
"fmt"
|
|
|
"io/ioutil"
|
|
|
"net/http"
|
|
|
- "time"
|
|
|
|
|
|
"github.com/eryajf/chatgpt-dingtalk/config"
|
|
|
"github.com/eryajf/chatgpt-dingtalk/public/logger"
|
|
@@ -34,13 +32,10 @@ type ChoiceItem struct {
|
|
|
|
|
|
// ChatGPTRequestBody 响应体
|
|
|
type ChatGPTRequestBody struct {
|
|
|
- Model string `json:"model"`
|
|
|
- Prompt string `json:"prompt"`
|
|
|
- MaxTokens uint `json:"max_tokens"`
|
|
|
- Temperature float64 `json:"temperature"`
|
|
|
- TopP int `json:"top_p"`
|
|
|
- FrequencyPenalty int `json:"frequency_penalty"`
|
|
|
- PresencePenalty int `json:"presence_penalty"`
|
|
|
+ Model string `json:"model"`
|
|
|
+ Prompt string `json:"prompt"`
|
|
|
+ MaxTokens uint `json:"max_tokens"`
|
|
|
+ Temperature float64 `json:"temperature"`
|
|
|
}
|
|
|
|
|
|
// Completions gtp文本模型回复
|
|
@@ -51,13 +46,10 @@ type ChatGPTRequestBody struct {
|
|
|
func Completions(msg string) (string, error) {
|
|
|
cfg := config.LoadConfig()
|
|
|
requestBody := ChatGPTRequestBody{
|
|
|
- Model: cfg.Model,
|
|
|
- Prompt: msg,
|
|
|
- MaxTokens: cfg.MaxTokens,
|
|
|
- Temperature: cfg.Temperature,
|
|
|
- TopP: 1,
|
|
|
- FrequencyPenalty: 0,
|
|
|
- PresencePenalty: 0,
|
|
|
+ Model: cfg.Model,
|
|
|
+ Prompt: msg,
|
|
|
+ MaxTokens: cfg.MaxTokens,
|
|
|
+ Temperature: cfg.Temperature,
|
|
|
}
|
|
|
requestData, err := json.Marshal(requestBody)
|
|
|
if err != nil {
|
|
@@ -69,23 +61,23 @@ func Completions(msg string) (string, error) {
|
|
|
return "", err
|
|
|
}
|
|
|
|
|
|
- apiKey := config.LoadConfig().ApiKey
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
- req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
|
- client := &http.Client{Timeout: 30 * time.Second}
|
|
|
+ req.Header.Set("Authorization", "Bearer "+cfg.ApiKey)
|
|
|
+ client := &http.Client{Timeout: cfg.SessionTimeout}
|
|
|
response, err := client.Do(req)
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
defer response.Body.Close()
|
|
|
- if response.StatusCode != 200 {
|
|
|
- body, _ := ioutil.ReadAll(response.Body)
|
|
|
- return "", errors.New(fmt.Sprintf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details: %v ", response.StatusCode, string(body)))
|
|
|
- }
|
|
|
+
|
|
|
body, err := ioutil.ReadAll(response.Body)
|
|
|
if err != nil {
|
|
|
return "", err
|
|
|
}
|
|
|
+
|
|
|
+ if response.StatusCode != 200 {
|
|
|
+ return "", fmt.Errorf("请求GTP出错了,gtp api status code not equals 200,code is %d ,details: %v ", response.StatusCode, string(body))
|
|
|
+ }
|
|
|
logger.Info(fmt.Sprintf("response gtp json string : %v", string(body)))
|
|
|
|
|
|
gptResponseBody := &ChatGPTResponseBody{}
|