main.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. package main
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"syscall"
	"time"

	"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
	"github.com/eryajf/chatgpt-dingtalk/pkg/logger"
	"github.com/eryajf/chatgpt-dingtalk/pkg/process"
	"github.com/eryajf/chatgpt-dingtalk/public"
	"github.com/gin-gonic/gin"
	"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
	"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
)
  18. func init() {
  19. // 初始化加载配置,数据库,模板等
  20. public.InitSvc()
  21. // 指定日志等级
  22. logger.InitLogger(public.Config.LogLevel)
  23. }
  24. func main() {
  25. if public.Config.RunMode == "http" {
  26. StartHttp()
  27. } else {
  28. for _, credential := range public.Config.Credentials {
  29. StartStream(credential.ClientID, credential.ClientSecret)
  30. }
  31. logger.Info("🚀 The Server Is Running On Stream Mode")
  32. select {}
  33. }
  34. }
  35. type ChatReceiver struct {
  36. clientId string
  37. clientSecret string
  38. }
  39. func NewChatReceiver(clientId, clientSecret string) *ChatReceiver {
  40. return &ChatReceiver{
  41. clientId: clientId,
  42. clientSecret: clientSecret,
  43. }
  44. }
  45. // 启动为 stream 模式
  46. func StartStream(clientId, clientSecret string) {
  47. receiver := NewChatReceiver(clientId, clientSecret)
  48. cli := client.NewStreamClient(client.WithAppCredential(client.NewAppCredentialConfig(clientId, clientSecret)))
  49. //注册callback类型的处理函数
  50. cli.RegisterChatBotCallbackRouter(receiver.OnChatBotMessageReceived)
  51. err := cli.Start(context.Background())
  52. if err != nil {
  53. logger.Fatal("strar stream failed: %v\n", err)
  54. }
  55. defer cli.Close()
  56. }
  57. // OnChatBotMessageReceived 简单的应答机器人实现
  58. func (r *ChatReceiver) OnChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) {
  59. msgObj := dingbot.ReceiveMsg{
  60. ConversationID: data.ConversationId,
  61. AtUsers: []struct {
  62. DingtalkID string "json:\"dingtalkId\""
  63. }{},
  64. ChatbotUserID: data.ChatbotUserId,
  65. MsgID: data.MsgId,
  66. SenderNick: data.SenderNick,
  67. IsAdmin: data.IsAdmin,
  68. SenderStaffId: data.SenderStaffId,
  69. SessionWebhookExpiredTime: data.SessionWebhookExpiredTime,
  70. CreateAt: data.CreateAt,
  71. ConversationType: data.ConversationType,
  72. SenderID: data.SenderId,
  73. ConversationTitle: data.ConversationTitle,
  74. IsInAtList: data.IsInAtList,
  75. SessionWebhook: data.SessionWebhook,
  76. Text: dingbot.Text(data.Text),
  77. RobotCode: "",
  78. Msgtype: dingbot.MsgType(data.Msgtype),
  79. }
  80. clientId := r.clientId
  81. var c gin.Context
  82. c.Set(public.DingTalkClientIdKeyName, clientId)
  83. DoRequest(msgObj, &c)
  84. return []byte(""), nil
  85. }
  86. func StartHttp() {
  87. app := gin.Default()
  88. app.POST("/", func(c *gin.Context) {
  89. var msgObj dingbot.ReceiveMsg
  90. err := c.Bind(&msgObj)
  91. if err != nil {
  92. return
  93. }
  94. DoRequest(msgObj, c)
  95. })
  96. // 解析生成后的图片
  97. app.GET("/images/:filename", func(c *gin.Context) {
  98. filename := c.Param("filename")
  99. c.File("./data/images/" + filename)
  100. })
  101. // 解析生成后的历史聊天
  102. app.GET("/history/:filename", func(c *gin.Context) {
  103. filename := c.Param("filename")
  104. c.File("./data/chatHistory/" + filename)
  105. })
  106. // 直接下载文件
  107. app.GET("/download/:filename", func(c *gin.Context) {
  108. filename := c.Param("filename")
  109. c.Header("Content-Disposition", "attachment; filename="+filename)
  110. c.Header("Content-Type", "application/octet-stream")
  111. c.File("./data/chatHistory/" + filename)
  112. })
  113. // 服务器健康检测
  114. app.GET("/", func(c *gin.Context) {
  115. c.JSON(200, gin.H{
  116. "status": "ok",
  117. "message": "🚀 欢迎使用钉钉机器人 🤖",
  118. })
  119. })
  120. port := ":" + public.Config.Port
  121. srv := &http.Server{
  122. Addr: port,
  123. Handler: app,
  124. }
  125. // Initializing the server in a goroutine so that
  126. // it won't block the graceful shutdown handling below
  127. go func() {
  128. logger.Info("🚀 The HTTP Server is running on", port)
  129. if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
  130. logger.Fatal("listen: %s\n", err)
  131. }
  132. }()
  133. // Wait for interrupt signal to gracefully shutdown the server with
  134. // a timeout of 5 seconds.
  135. quit := make(chan os.Signal, 1)
  136. // kill (no param) default send syscall.SIGTERM
  137. // kill -2 is syscall.SIGINT
  138. // kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
  139. // signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
  140. signal.Notify(quit, os.Interrupt)
  141. <-quit
  142. logger.Info("Shutting down server...")
  143. // 5秒后强制退出
  144. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  145. defer cancel()
  146. if err := srv.Shutdown(ctx); err != nil {
  147. logger.Fatal("Server forced to shutdown:", err)
  148. }
  149. logger.Info("Server exiting!")
  150. }
  151. func DoRequest(msgObj dingbot.ReceiveMsg, c *gin.Context) {
  152. // 先校验回调是否合法
  153. // 如果是Outgoing机器人,判断是否在allow_outgoing_groups白名单内,如是(JudgeOutgoingGroup返回True)则跳过下面的逻辑,如不是则执行下面的逻辑(会返回失败)
  154. if public.Config.RunMode == "http" && (msgObj.RobotCode != "normal" || msgObj.RobotCode == "normal" && !public.JudgeOutgoingGroup(msgObj.ConversationID)) {
  155. clientId, checkOk := public.CheckRequestWithCredentials(c.GetHeader("timestamp"), c.GetHeader("sign"))
  156. if !checkOk {
  157. logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
  158. return
  159. }
  160. // 通过 context 传递 OAuth ClientID,用于后续流程中调用钉钉OpenAPI
  161. c.Set(public.DingTalkClientIdKeyName, clientId)
  162. }
  163. // 再校验回调参数是否有价值
  164. if msgObj.Text.Content == "" || msgObj.ChatbotUserID == "" {
  165. logger.Warning("从钉钉回调过来的内容为空,根据过往的经验,或许重新创建一下机器人,能解决这个问题")
  166. return
  167. }
  168. // 去除问题的前后空格
  169. msgObj.Text.Content = strings.TrimSpace(msgObj.Text.Content)
  170. if public.JudgeSensitiveWord(msgObj.Text.Content) {
  171. logger.Info(fmt.Sprintf("🙋 %s提问的问题中包含敏感词汇,userid:%#v,消息: %#v", msgObj.SenderNick, msgObj.SenderStaffId, msgObj.Text.Content))
  172. _, err := msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,您提问的问题中包含敏感词汇,请审核自己的对话内容之后再进行!**")
  173. if err != nil {
  174. logger.Warning(fmt.Errorf("send message error: %v", err))
  175. return
  176. }
  177. return
  178. }
  179. // 打印钉钉回调过来的请求明细,调试时打开
  180. logger.Debug(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
  181. if public.Config.ChatType != "0" && msgObj.ConversationType != public.Config.ChatType {
  182. logger.Info(fmt.Sprintf("🙋 %s使用了禁用的聊天方式", msgObj.SenderNick))
  183. _, err := msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,管理员禁用了这种聊天方式,请选择其他聊天方式与机器人对话!**")
  184. if err != nil {
  185. logger.Warning(fmt.Errorf("send message error: %v", err))
  186. return
  187. }
  188. return
  189. }
  190. // 查询群ID,发送指令后,可通过查看日志来获取
  191. if msgObj.ConversationType == "2" && msgObj.Text.Content == "群ID" {
  192. if msgObj.RobotCode == "normal" {
  193. logger.Info(fmt.Sprintf("🙋 outgoing机器人 在『%s』群的ConversationID为: %#v", msgObj.ConversationTitle, msgObj.ConversationID))
  194. } else {
  195. logger.Info(fmt.Sprintf("🙋 企业内部机器人 在『%s』群的ConversationID为: %#v", msgObj.ConversationTitle, msgObj.ConversationID))
  196. }
  197. return
  198. }
  199. // 不在允许群组,不在允许用户(包括在黑名单),满足任一条件,拒绝会话;管理员不受限制
  200. if msgObj.ConversationType == "2" && !public.JudgeGroup(msgObj.ConversationID) && !public.JudgeAdminUsers(msgObj.SenderStaffId) && msgObj.SenderStaffId != "" {
  201. logger.Info(fmt.Sprintf("🙋『%s』群组未被验证通过,群ID: %#v,userid:%#v, 昵称: %#v,消息: %#v", msgObj.ConversationTitle, msgObj.ConversationID, msgObj.SenderStaffId, msgObj.SenderNick, msgObj.Text.Content))
  202. _, err := msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,该群组未被认证通过,无法使用机器人对话功能。**\n>如需继续使用,请联系管理员申请访问权限。")
  203. if err != nil {
  204. logger.Warning(fmt.Errorf("send message error: %v", err))
  205. return
  206. }
  207. return
  208. } else if !public.JudgeUsers(msgObj.SenderStaffId) && !public.JudgeAdminUsers(msgObj.SenderStaffId) && msgObj.SenderStaffId != "" {
  209. logger.Info(fmt.Sprintf("🙋 %s身份信息未被验证通过,userid:%#v,消息: %#v", msgObj.SenderNick, msgObj.SenderStaffId, msgObj.Text.Content))
  210. _, err := msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,您的身份信息未被认证通过,无法使用机器人对话功能。**\n>如需继续使用,请联系管理员申请访问权限。")
  211. if err != nil {
  212. logger.Warning(fmt.Errorf("send message error: %v", err))
  213. return
  214. }
  215. return
  216. }
  217. if len(msgObj.Text.Content) == 0 || msgObj.Text.Content == "帮助" {
  218. // 欢迎信息
  219. _, err := msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), public.Config.Help)
  220. if err != nil {
  221. logger.Warning(fmt.Errorf("send message error: %v", err))
  222. return
  223. }
  224. } else {
  225. logger.Info(fmt.Sprintf("🙋 %s发起的问题: %#v", msgObj.SenderNick, msgObj.Text.Content))
  226. // 除去帮助之外的逻辑分流在这里处理
  227. switch {
  228. case strings.HasPrefix(msgObj.Text.Content, "#图片"):
  229. err := process.ImageGenerate(c, &msgObj)
  230. if err != nil {
  231. logger.Warning(fmt.Errorf("process request: %v", err))
  232. return
  233. }
  234. return
  235. case strings.HasPrefix(msgObj.Text.Content, "#查对话"):
  236. err := process.SelectHistory(&msgObj)
  237. if err != nil {
  238. logger.Warning(fmt.Errorf("process request: %v", err))
  239. return
  240. }
  241. return
  242. case strings.HasPrefix(msgObj.Text.Content, "#域名"):
  243. err := process.DomainMsg(&msgObj)
  244. if err != nil {
  245. logger.Warning(fmt.Errorf("process request: %v", err))
  246. return
  247. }
  248. return
  249. case strings.HasPrefix(msgObj.Text.Content, "#证书"):
  250. err := process.DomainCertMsg(&msgObj)
  251. if err != nil {
  252. logger.Warning(fmt.Errorf("process request: %v", err))
  253. return
  254. }
  255. return
  256. default:
  257. var err error
  258. msgObj.Text.Content, err = process.GeneratePrompt(msgObj.Text.Content)
  259. // err不为空:提示词之后没有文本 -> 直接返回提示词所代表的内容
  260. if err != nil {
  261. _, err = msgObj.ReplyToDingtalk(string(dingbot.TEXT), msgObj.Text.Content)
  262. if err != nil {
  263. logger.Warning(fmt.Errorf("send message error: %v", err))
  264. return
  265. }
  266. return
  267. }
  268. err = process.ProcessRequest(&msgObj)
  269. if err != nil {
  270. logger.Warning(fmt.Errorf("process request: %v", err))
  271. return
  272. }
  273. return
  274. }
  275. }
  276. }