Procházet zdrojové kódy

feat: 支持自定义分组以及用户白名单的能力支持定义系统管理员的能力 (#159)

二丫讲梵 před 2 roky
rodič
revize
fa8a008386
8 změnil soubory, kde provedl 123 přidání a 18 odebrání
  1. 13 2
      README.md
  2. 10 1
      config.example.yml
  3. 31 12
      config/config.go
  4. 3 0
      docker-compose.yml
  5. 8 0
      main.go
  6. 10 2
      pkg/db/chat.go
  7. 9 1
      pkg/process/process_request.go
  8. 39 0
      public/tools.go

+ 13 - 2
README.md

@@ -80,7 +80,9 @@
 - 🔗 自定义api域名:通过配置指定,解决国内服务器无法直接访问openai的问题
 - 🪜 添加代理:通过配置指定,通过给应用注入代理解决国内服务器无法访问的问题
 - 👐 默认模式:支持自定义默认的聊天模式,通过配置化指定
-- 📝 查询对话:通过发送`#查对话 username:xxx`查询xxx的对话历史,可在线预览,可下载到本地。
+- 📝 查询对话:通过发送`#查对话 username:xxx`查询xxx的对话历史,可在线预览,可下载到本地
+- 👹 白名单机制:通过配置指定,支持指定群组名称和用户名称作为白名单,从而实现可控范围与机器人对话
+- 💂‍♀️ 管理员机制:通过配置指定管理员,部分敏感操作,以及一些应用配置,管理员有权限进行操作
 
 ## 使用前提
 
@@ -144,10 +146,11 @@
 ```
 第一种:基于环境变量运行
 # 运行项目
-$ docker run -itd --name chatgpt -p 8090:8090 -v ./data:/app/data --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" -e MAX_REQUEST=0 -e PORT=8090 -e SERVICE_URL="你当前服务外网可访问的URL" -e CHAT_TYPE="0" --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest
+$ docker run -itd --name chatgpt -p 8090:8090 -v ./data:/app/data --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" -e MAX_REQUEST=0 -e PORT=8090 -e SERVICE_URL="你当前服务外网可访问的URL" -e CHAT_TYPE="0" -e ALLOW_GROUPS=a,b -e ALLOW_USERS=a,b ADMIN_USERS=a,b --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest
 ```
 
 `📢 注意:`如果使用docker部署,那么PORT参数不需要进行任何调整。
+`📢 注意:`ALLOW_GROUPS,ALLOW_USERS,ADMIN_USERS三个参数为数组,如果需要指定多个,可用英文逗号分割。
 `📢 注意:`如果服务器节点本身就在国外或者自定义了`BASE_URL`,那么就把`HTTP_PROXY`参数留空即可。
 `📢 注意:`如果使用docker部署,那么proxy地址可以直接使用如上方式部署,`host.docker.internal`会指向容器所在宿主机的IP,只需要更改端口为你的代理端口即可。参见:[Docker容器如何优雅地访问宿主机网络](https://wiki.eryajf.net/pages/674f53/)
 
@@ -352,6 +355,14 @@ port: "8090"
 service_url: "http://chat.eryajf.net"
 # 限定对话类型 0:不限 1:只能单聊 2:只能群聊
 chat_type: "0"
+# 哪些群组可以进行对话,如果留空,则表示允许所有群组,对话聊天是,如下三个满足其一即可通过校验
+allow_groups:
+  - "学无止境"
+# 哪些用户可以进行对话,如果留空,则表示允许所有用户
+allow_users:
+  - "xxx"
+# 指定哪些人为此系统的管理员,如果留空,则表示没有人是管理员
+admin_users:
 ```
 
 ## 常见问题

+ 10 - 1
config.example.yml

@@ -17,4 +17,13 @@ port: "8090"
 # 指定服务的地址,就是当前服务可供外网访问的地址(或者直接理解为你配置在钉钉回调那里的地址),用于生成图片时给钉钉做渲染
 service_url: "http://chat.eryajf.net"
 # 限定对话类型 0:不限 1:只能单聊 2:只能群聊
-chat_type: "0"
+chat_type: "0"
+# 哪些群组可以进行对话,如果留空,则表示允许所有群组
+allow_groups:
+  - "学无止境"
+  - "技术群"
+# 哪些用户可以进行对话,如果留空,则表示允许所有用户
+allow_users:
+  - "xxx"
+# 指定哪些人为此系统的管理员,如果留空,则表示没有人是管理员
+admin_users:

+ 31 - 12
config/config.go

@@ -6,6 +6,7 @@ import (
 	"log"
 	"os"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
@@ -35,6 +36,12 @@ type Configuration struct {
 	ServiceURL string `yaml:"service_url"`
 	// 限定对话类型 0:不限 1:单聊 2:群聊
 	ChatType string `yaml:"chat_type"`
+	// 哪些群组可以进行对话
+	AllowGroups []string `yaml:"allow_groups"`
+	// 哪些用户可以进行对话
+	AllowUsers []string `yaml:"allow_users"`
+	// 指定哪些人为此系统的管理员,必须指定,否则所有人都是
+	AdminUsers []string `yaml:"admin_users"`
 }
 
 var config *Configuration
@@ -56,21 +63,18 @@ func LoadConfig() *Configuration {
 
 		// 如果环境变量有配置,读取环境变量
 		apiKey := os.Getenv("APIKEY")
-		baseURL := os.Getenv("BASE_URL")
-		model := os.Getenv("MODEL")
-		sessionTimeout := os.Getenv("SESSION_TIMEOUT")
-		defaultMode := os.Getenv("DEFAULT_MODE")
-		httpProxy := os.Getenv("HTTP_PROXY")
-		maxRequest := os.Getenv("MAX_REQUEST")
-		port := os.Getenv("PORT")
-		serviceURL := os.Getenv("SERVICE_URL")
-		chatType := os.Getenv("CHAT_TYPE")
 		if apiKey != "" {
 			config.ApiKey = apiKey
 		}
+		baseURL := os.Getenv("BASE_URL")
 		if baseURL != "" {
 			config.BaseURL = baseURL
 		}
+		model := os.Getenv("MODEL")
+		if model != "" {
+			config.Model = model
+		}
+		sessionTimeout := os.Getenv("SESSION_TIMEOUT")
 		if sessionTimeout != "" {
 			duration, err := strconv.ParseInt(sessionTimeout, 10, 64)
 			if err != nil {
@@ -81,28 +85,43 @@ func LoadConfig() *Configuration {
 		} else {
 			config.SessionTimeout = time.Duration(config.SessionTimeout) * time.Second
 		}
+		defaultMode := os.Getenv("DEFAULT_MODE")
 		if defaultMode != "" {
 			config.DefaultMode = defaultMode
 		}
+		httpProxy := os.Getenv("HTTP_PROXY")
 		if httpProxy != "" {
 			config.HttpProxy = httpProxy
 		}
-		if model != "" {
-			config.Model = model
-		}
+		maxRequest := os.Getenv("MAX_REQUEST")
 		if maxRequest != "" {
 			newMR, _ := strconv.Atoi(maxRequest)
 			config.MaxRequest = newMR
 		}
+		port := os.Getenv("PORT")
 		if port != "" {
 			config.Port = port
 		}
+		serviceURL := os.Getenv("SERVICE_URL")
 		if serviceURL != "" {
 			config.ServiceURL = serviceURL
 		}
+		chatType := os.Getenv("CHAT_TYPE")
 		if chatType != "" {
 			config.ChatType = chatType
 		}
+		allowGroup := os.Getenv("ALLOW_GROUPS")
+		if allowGroup != "" {
+			config.AllowGroups = strings.Split(allowGroup, ",")
+		}
+		allowUsers := os.Getenv("ALLOW_USERS")
+		if allowUsers != "" {
+			config.AllowUsers = strings.Split(allowUsers, ",")
+		}
+		adminUsers := os.Getenv("ADMIN_USERS")
+		if adminUsers != "" {
+			config.AdminUsers = strings.Split(adminUsers, ",")
+		}
 	})
 	if config.Model == "" {
 		config.DefaultMode = "gpt-3.5-turbo"

+ 3 - 0
docker-compose.yml

@@ -16,6 +16,9 @@ services:
       PORT: 8090 # 指定服务启动端口,默认为 8090,容器化部署时,不需要调整,一般在二进制宿主机部署时,遇到端口冲突时使用
       SERVICE_URL: ""  # 指定服务的地址,就是当前服务可供外网访问的地址(或者直接理解为你配置在钉钉回调那里的地址),用于生成图片时给钉钉做渲染
       CHAT_TYPE: "0" # 限定对话类型 0:不限 1:只能单聊 2:只能群聊
+      ALLOW_GROUPS: "学无止境,技术群" # 哪些群组可以进行对话,如果留空,则表示允许所有群组,如果有多个,则用英文逗号分割,docker-compose的语法不支持变量的值为数组
+      ALLOW_USERS: "xxx" # 哪些用户可以进行对话,如果留空,则表示允许所有用户,如果有多个,则用英文逗号分割
+      ADMIN_USERS: "" # 指定哪些人为此系统的管理员,如果留空,则表示没有人是管理员
     volumes:
       - ./data:/app/data
     ports:

+ 8 - 0
main.go

@@ -49,6 +49,14 @@ func Start() {
 			}
 			return nil
 		}
+		if !public.JudgeGroup(msgObj.GetChatTitle()) && !public.JudgeUsers(msgObj.SenderNick) && !public.JudgeAdminUsers(msgObj.SenderNick) {
+			_, err = msgObj.ReplyToDingtalk(string(dingbot.TEXT), "抱歉,您不在该机器人对话功能的白名单当中!")
+			if err != nil {
+				logger.Warning(fmt.Errorf("send message error: %v", err))
+				return err
+			}
+			return nil
+		}
 		if len(msgObj.Text.Content) == 1 || msgObj.Text.Content == "帮助" {
 			// 欢迎信息
 			_, err := msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), public.Welcome)

+ 10 - 2
pkg/db/chat.go

@@ -1,6 +1,7 @@
 package db
 
 import (
+	"errors"
 	"fmt"
 	"strings"
 
@@ -48,13 +49,20 @@ func (c Chat) List(req ChatListReq) ([]*Chat, error) {
 
 	userName := strings.TrimSpace(req.Username)
 	if userName != "" {
-		db = db.Where("username LIKE ?", fmt.Sprintf("%%%s%%", userName))
+		db = db.Where("username = ?", fmt.Sprintf("%%%s%%", userName))
 	}
 	source := strings.TrimSpace(req.Source)
 	if source != "" {
-		db = db.Where("source LIKE ?", fmt.Sprintf("%%%s%%", source))
+		db = db.Where("source = ?", fmt.Sprintf("%%%s%%", source))
 	}
 
 	err := db.Find(&list).Error
 	return list, err
 }
+
+// Exist 判断资源是否存在
+func (c Chat) Exist(filter map[string]interface{}) bool {
+	var dataObj Chat
+	err := DB.Where(filter).First(&dataObj).Error
+	return !errors.Is(err, gorm.ErrRecordNotFound)
+}

+ 9 - 1
pkg/process/process_request.go

@@ -258,7 +258,7 @@ func ImageGenerate(rmsg *dingbot.ReceiveMsg) error {
 }
 func SelectHistory(rmsg *dingbot.ReceiveMsg) error {
 	name := strings.TrimSpace(strings.Split(rmsg.Text.Content, ":")[1])
-	if !rmsg.IsAdmin || name != rmsg.SenderNick {
+	if !rmsg.IsAdmin && name != rmsg.SenderNick && !public.JudgeAdminUsers(rmsg.SenderNick) {
 		_, err := rmsg.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,您没有权限查询其他人的对话记录!**")
 		if err != nil {
 			logger.Error(fmt.Errorf("send message error: %v", err))
@@ -268,6 +268,14 @@ func SelectHistory(rmsg *dingbot.ReceiveMsg) error {
 	}
 	// 获取数据列表
 	var chat db.Chat
+	if !chat.Exist(map[string]interface{}{"username": name}) {
+		_, err := rmsg.ReplyToDingtalk(string(dingbot.TEXT), "用户名错误,这个用户不存在,请核实之后再进行查询")
+		if err != nil {
+			logger.Error(fmt.Errorf("send message error: %v", err))
+			return err
+		}
+		return fmt.Errorf("用户名错误,这个用户不存在,请核实之后重新查询")
+	}
 	chats, err := chat.List(db.ChatListReq{
 		Username: name,
 	})

+ 39 - 0
public/tools.go

@@ -23,3 +23,42 @@ func WriteToFile(path string, data []byte) error {
 	}
 	return nil
 }
+
+// JudgeGroup 判断群聊名称是否在白名单
+func JudgeGroup(s string) bool {
+	if len(Config.AllowGroups) == 0 {
+		return true
+	}
+	for _, v := range Config.AllowGroups {
+		if v == s {
+			return true
+		}
+	}
+	return false
+}
+
+// JudgeUsers 判断用户名称是否在白名单
+func JudgeUsers(s string) bool {
+	if len(Config.AllowUsers) == 0 {
+		return true
+	}
+	for _, v := range Config.AllowUsers {
+		if v == s {
+			return true
+		}
+	}
+	return false
+}
+
+// JudgeAdminUsers 判断用户是否为系统管理员
+func JudgeAdminUsers(s string) bool {
+	if len(Config.AllowGroups) == 0 {
+		return false
+	}
+	for _, v := range Config.AdminUsers {
+		if v == s {
+			return true
+		}
+	}
+	return false
+}