Browse Source

fix: 修复 session_timeout 配置项不生效的bug,优化相关缓存逻辑,同时余额、查对话功能权限优化 (#210)

Finly 2 years ago
parent
commit
fb5500611a
3 changed files with 45 additions and 23 deletions
  1. 6 1
      pkg/cache/user_base.go
  2. 1 5
      pkg/chatgpt/chatgpt.go
  3. 38 17
      pkg/process/process_request.go

+ 6 - 1
pkg/cache/user_base.go

@@ -3,6 +3,7 @@ package cache
 import (
 	"time"
 
+	"github.com/eryajf/chatgpt-dingtalk/config"
 	"github.com/patrickmn/go-cache"
 )
 
@@ -33,7 +34,11 @@ type UserService struct {
 	cache *cache.Cache
 }
 
+var Config *config.Configuration
+
 // NewUserService 创建新的业务层
 func NewUserService() UserServiceInterface {
-	return &UserService{cache: cache.New(time.Hour*2, time.Hour*5)}
+	// 加载配置
+	Config = config.LoadConfig()
+	return &UserService{cache: cache.New(Config.SessionTimeout, time.Hour*1)}
 }

+ 1 - 5
pkg/chatgpt/chatgpt.go

@@ -28,11 +28,7 @@ func New(userId string) *ChatGPT {
 	var ctx context.Context
 	var cancel func()
 
-	if public.Config.SessionTimeout == 0 {
-		ctx, cancel = context.WithCancel(context.Background())
-	} else {
-		ctx, cancel = context.WithTimeout(context.Background(), public.Config.SessionTimeout)
-	}
+	ctx, cancel = context.WithTimeout(context.Background(), 600)
 	timeOutChan := make(chan struct{}, 1)
 	go func() {
 		<-ctx.Done()

+ 38 - 17
pkg/process/process_request.go

@@ -3,6 +3,7 @@ package process
 import (
 	"fmt"
 	"strings"
+	"time"
 
 	"github.com/eryajf/chatgpt-dingtalk/pkg/db"
 	"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
@@ -15,16 +16,20 @@ import (
 func ProcessRequest(rmsg *dingbot.ReceiveMsg) error {
 	if CheckRequestTimes(rmsg) {
 		content := strings.TrimSpace(rmsg.Text.Content)
+		timeoutStr := ""
+		if content != public.Config.DefaultMode {
+			timeoutStr = fmt.Sprintf("\n\n>%s 后将恢复默认聊天模式:%s", FormatTimeDuation(public.Config.SessionTimeout), public.Config.DefaultMode)
+		}
 		switch content {
 		case "单聊":
 			public.UserService.SetUserMode(rmsg.GetSenderIdentifier(), content)
-			_, err := rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf("**[Concentrate] 现在进入与 %s 的单聊模式**", rmsg.SenderNick))
+			_, err := rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf("**[Concentrate] 现在进入与 %s 的单聊模式**%s", rmsg.SenderNick, timeoutStr))
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 			}
 		case "串聊":
 			public.UserService.SetUserMode(rmsg.GetSenderIdentifier(), content)
-			_, err := rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf("**[Concentrate] 现在进入与 %s 的串聊模式**", rmsg.SenderNick))
+			_, err := rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), fmt.Sprintf("**[Concentrate] 现在进入与 %s 的串聊模式**%s", rmsg.SenderNick, timeoutStr))
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 			}
@@ -62,25 +67,28 @@ func ProcessRequest(rmsg *dingbot.ReceiveMsg) error {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 			}
 		case "余额":
-			cacheMsg := public.UserService.GetUserMode("system_balance")
-			if cacheMsg == "" {
-				rst, err := public.GetBalance()
+			if public.JudgeAdminUsers(rmsg.SenderStaffId) {
+				cacheMsg := public.UserService.GetUserMode("system_balance")
+				if cacheMsg == "" {
+					rst, err := public.GetBalance()
+					if err != nil {
+						logger.Warning(fmt.Errorf("get balance error: %v", err))
+						return err
+					}
+					cacheMsg = rst
+				}
+				_, err := rmsg.ReplyToDingtalk(string(dingbot.TEXT), cacheMsg)
 				if err != nil {
-					logger.Warning(fmt.Errorf("get balance error: %v", err))
-					return err
+					logger.Warning(fmt.Errorf("send message error: %v", err))
 				}
-				cacheMsg = rst
-			}
-			// cacheMsg := "官方暂时改写了余额接口,因此暂不提供查询余额功能!2023-04-03"
-			_, err := rmsg.ReplyToDingtalk(string(dingbot.TEXT), cacheMsg)
-			if err != nil {
-				logger.Warning(fmt.Errorf("send message error: %v", err))
 			}
 		case "查对话":
-			msg := "使用如下指令进行查询:\n\n---\n\n**#查对话 username:张三**\n\n---\n\n需要注意格式必须严格与上边一致,否则将会查询失败\n\n只有程序系统管理员有权限查询,即config.yml中的admin_users指定的人员。"
-			_, err := rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), msg)
-			if err != nil {
-				logger.Warning(fmt.Errorf("send message error: %v", err))
+			if public.JudgeAdminUsers(rmsg.SenderStaffId) {
+				msg := "使用如下指令进行查询:\n\n---\n\n**#查对话 username:张三**\n\n---\n\n需要注意格式必须严格与上边一致,否则将会查询失败\n\n只有程序系统管理员有权限查询,即config.yml中的admin_users指定的人员。"
+				_, err := rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), msg)
+				if err != nil {
+					logger.Warning(fmt.Errorf("send message error: %v", err))
+				}
 			}
 		default:
 			if public.FirstCheck(rmsg) {
@@ -223,6 +231,19 @@ func Do(mode string, rmsg *dingbot.ReceiveMsg) error {
 	}
 	return nil
 }
+// FormatTimeDuation 格式化时间
+// 主要提示单聊/群聊切换时多久后恢复默认聊天模式
+func FormatTimeDuation(duration time.Duration) string {
+	minutes := int64(duration.Minutes())
+	seconds := int64(duration.Seconds()) - minutes*60
+	timeoutStr := ""
+	if seconds == 0 {
+		timeoutStr = fmt.Sprintf("%d分钟", minutes)
+	} else {
+		timeoutStr = fmt.Sprintf("%d分%d秒", minutes, seconds)
+	}
+	return timeoutStr
+}
 
 // FormatMarkdown 格式化Markdown
 // 主要修复ChatGPT返回多行代码块,钉钉会将代码块中的#当作Markdown语法里的标题来处理,这里进行下转义