|
@@ -6,11 +6,12 @@ import (
|
|
|
"net/url"
|
|
|
"time"
|
|
|
|
|
|
- gogpt "github.com/sashabaranov/go-gpt3"
|
|
|
+ "github.com/eryajf/chatgpt-dingtalk/public"
|
|
|
+ openai "github.com/sashabaranov/go-openai"
|
|
|
)
|
|
|
|
|
|
type ChatGPT struct {
|
|
|
- client *gogpt.Client
|
|
|
+ client *openai.Client
|
|
|
ctx context.Context
|
|
|
userId string
|
|
|
maxQuestionLen int
|
|
@@ -23,13 +24,15 @@ type ChatGPT struct {
|
|
|
ChatContext *ChatContext
|
|
|
}
|
|
|
|
|
|
-func New(apiKey, proxyUrl, userId string, timeOut time.Duration) *ChatGPT {
|
|
|
+func New(userId string) *ChatGPT {
|
|
|
var ctx context.Context
|
|
|
var cancel func()
|
|
|
- if timeOut == 0 {
|
|
|
+
|
|
|
+ // public.Config.BaseURL, public.Config.ApiKey, public.Config.HttpProxy
|
|
|
+ if public.Config.SessionTimeout == 0 {
|
|
|
ctx, cancel = context.WithCancel(context.Background())
|
|
|
} else {
|
|
|
- ctx, cancel = context.WithTimeout(context.Background(), timeOut)
|
|
|
+ ctx, cancel = context.WithTimeout(context.Background(), public.Config.SessionTimeout)
|
|
|
}
|
|
|
timeOutChan := make(chan struct{}, 1)
|
|
|
go func() {
|
|
@@ -37,22 +40,26 @@ func New(apiKey, proxyUrl, userId string, timeOut time.Duration) *ChatGPT {
|
|
|
timeOutChan <- struct{}{} // 发送超时信号,或是提示结束,用于聊天机器人场景,配合GetTimeOutChan() 使用
|
|
|
}()
|
|
|
|
|
|
- config := gogpt.DefaultConfig(apiKey)
|
|
|
- if proxyUrl != "" {
|
|
|
+ config := openai.DefaultConfig(public.Config.ApiKey)
|
|
|
+ if public.Config.HttpProxy != "" {
|
|
|
config.HTTPClient.Transport = &http.Transport{
|
|
|
// 设置代理
|
|
|
Proxy: func(req *http.Request) (*url.URL, error) {
|
|
|
- return url.Parse(proxyUrl)
|
|
|
+ return url.Parse(public.Config.HttpProxy)
|
|
|
}}
|
|
|
}
|
|
|
+ if public.Config.BaseURL != "" {
|
|
|
+ config.BaseURL = public.Config.BaseURL + "/v1"
|
|
|
+ }
|
|
|
+
|
|
|
return &ChatGPT{
|
|
|
- client: gogpt.NewClientWithConfig(config),
|
|
|
+ client: openai.NewClientWithConfig(config),
|
|
|
ctx: ctx,
|
|
|
userId: userId,
|
|
|
maxQuestionLen: 2048, // 最大问题长度
|
|
|
maxAnswerLen: 2048, // 最大答案长度
|
|
|
maxText: 4096, // 最大文本 = 问题 + 回答, 接口限制
|
|
|
- timeOut: timeOut,
|
|
|
+ timeOut: public.Config.SessionTimeout,
|
|
|
doneChan: timeOutChan,
|
|
|
cancel: func() {
|
|
|
cancel()
|
|
@@ -75,30 +82,3 @@ func (c *ChatGPT) SetMaxQuestionLen(maxQuestionLen int) int {
|
|
|
c.maxQuestionLen = maxQuestionLen
|
|
|
return c.maxQuestionLen
|
|
|
}
|
|
|
-
|
|
|
-// func (c *ChatGPT) Chat(question string) (answer string, err error) {
|
|
|
-// question = question + "."
|
|
|
-// if len(question) > c.maxQuestionLen {
|
|
|
-// return "", OverMaxQuestionLength
|
|
|
-// }
|
|
|
-// if len(question)+c.maxAnswerLen > c.maxText {
|
|
|
-// question = question[:c.maxText-c.maxAnswerLen]
|
|
|
-// }
|
|
|
-// req := gogpt.CompletionRequest{
|
|
|
-// Model: gogpt.GPT3TextDavinci003,
|
|
|
-// MaxTokens: c.maxAnswerLen,
|
|
|
-// Prompt: question,
|
|
|
-// Temperature: 0.9,
|
|
|
-// TopP: 1,
|
|
|
-// N: 1,
|
|
|
-// FrequencyPenalty: 0,
|
|
|
-// PresencePenalty: 0.5,
|
|
|
-// User: c.userId,
|
|
|
-// Stop: []string{},
|
|
|
-// }
|
|
|
-// resp, err := c.client.CreateCompletion(c.ctx, req)
|
|
|
-// if err != nil {
|
|
|
-// return "", err
|
|
|
-// }
|
|
|
-// return formatAnswer(resp.Choices[0].Text), err
|
|
|
-// }
|