main.go 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package main
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "os"
  7. "os/signal"
  8. "strings"
  9. "time"
  10. "github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
  11. "github.com/eryajf/chatgpt-dingtalk/pkg/logger"
  12. "github.com/eryajf/chatgpt-dingtalk/pkg/process"
  13. "github.com/eryajf/chatgpt-dingtalk/public"
  14. "github.com/gin-gonic/gin"
  15. )
  16. func init() {
  17. public.InitSvc()
  18. logger.InitLogger(public.Config.LogLevel)
  19. }
  20. func main() {
  21. Start()
  22. }
  23. func Start() {
  24. app := gin.Default()
  25. app.POST("/", func(c *gin.Context) {
  26. var msgObj dingbot.ReceiveMsg
  27. err := c.Bind(&msgObj)
  28. if err != nil {
  29. return
  30. }
  31. // 先校验回调是否合法
  32. clientId, checkOk := public.CheckRequestWithCredentials(c.GetHeader("timestamp"), c.GetHeader("sign"))
  33. if !checkOk {
  34. logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
  35. return
  36. }
  37. // 通过 context 传递 OAuth ClientID,用于后续流程中调用钉钉OpenAPI
  38. c.Set(public.DingTalkClientIdKeyName, clientId)
  39. // 为了兼容存量老用户,暂时保留 public.CheckRequest 方法,将来升级到 Stream 模式后,建议去除该方法,采用上面的 CheckRequestWithCredentials
  40. if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" {
  41. logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
  42. return
  43. } else if !public.JudgeOutgoingGroup(msgObj.ConversationID) && msgObj.SenderStaffId == "" {
  44. logger.Warning("该请求不合法,可能是未经允许的普通群outgoing机器人调用所致,请知悉!")
  45. return
  46. }
  47. // 再校验回调参数是否有价值
  48. if msgObj.Text.Content == "" || msgObj.ChatbotUserID == "" {
  49. logger.Warning("从钉钉回调过来的内容为空,根据过往的经验,或许重新创建一下机器人,能解决这个问题")
  50. return
  51. }
  52. // 去除问题的前后空格
  53. msgObj.Text.Content = strings.TrimSpace(msgObj.Text.Content)
  54. if public.JudgeSensitiveWord(msgObj.Text.Content) {
  55. logger.Info(fmt.Sprintf("🙋 %s提问的问题中包含敏感词汇,userid:%#v,消息: %#v", msgObj.SenderNick, msgObj.SenderStaffId, msgObj.Text.Content))
  56. _, err = msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,您提问的问题中包含敏感词汇,请审核自己的对话内容之后再进行!**")
  57. if err != nil {
  58. logger.Warning(fmt.Errorf("send message error: %v", err))
  59. return
  60. }
  61. return
  62. }
  63. // 打印钉钉回调过来的请求明细,调试时打开
  64. logger.Debug(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
  65. if public.Config.ChatType != "0" && msgObj.ConversationType != public.Config.ChatType {
  66. logger.Info(fmt.Sprintf("🙋 %s使用了禁用的聊天方式", msgObj.SenderNick))
  67. _, err = msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,管理员禁用了这种聊天方式,请选择其他聊天方式与机器人对话!**")
  68. if err != nil {
  69. logger.Warning(fmt.Errorf("send message error: %v", err))
  70. return
  71. }
  72. return
  73. }
  74. // 查询群ID,发送指令后,可通过查看日志来获取
  75. if msgObj.ConversationType == "2" && msgObj.Text.Content == "群ID" {
  76. if msgObj.RobotCode == "normal" {
  77. logger.Info(fmt.Sprintf("🙋 outgoing机器人 在『%s』群的ConversationID为: %#v", msgObj.ConversationTitle, msgObj.ConversationID))
  78. } else {
  79. logger.Info(fmt.Sprintf("🙋 企业内部机器人 在『%s』群的ConversationID为: %#v", msgObj.ConversationTitle, msgObj.ConversationID))
  80. }
  81. //_, err = msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), msgObj.ConversationID)
  82. if err != nil {
  83. logger.Warning(fmt.Errorf("send message error: %v", err))
  84. return
  85. }
  86. return
  87. }
  88. // 不在允许群组,不在允许用户(包括在黑名单),满足任一条件,拒绝会话;管理员不受限制
  89. if msgObj.ConversationType == "2" && !public.JudgeGroup(msgObj.ConversationID) && !public.JudgeAdminUsers(msgObj.SenderStaffId) && msgObj.SenderStaffId != "" {
  90. logger.Info(fmt.Sprintf("🙋『%s』群组未被验证通过,群ID: %#v,userid:%#v, 昵称: %#v,消息: %#v", msgObj.ConversationTitle, msgObj.ConversationID, msgObj.SenderStaffId, msgObj.SenderNick, msgObj.Text.Content))
  91. _, err = msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,该群组未被认证通过,无法使用机器人对话功能。**\n>如需继续使用,请联系管理员申请访问权限。")
  92. if err != nil {
  93. logger.Warning(fmt.Errorf("send message error: %v", err))
  94. return
  95. }
  96. return
  97. } else if !public.JudgeUsers(msgObj.SenderStaffId) && !public.JudgeAdminUsers(msgObj.SenderStaffId) && msgObj.SenderStaffId != "" {
  98. logger.Info(fmt.Sprintf("🙋 %s身份信息未被验证通过,userid:%#v,消息: %#v", msgObj.SenderNick, msgObj.SenderStaffId, msgObj.Text.Content))
  99. _, err = msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), "**🤷 抱歉,您的身份信息未被认证通过,无法使用机器人对话功能。**\n>如需继续使用,请联系管理员申请访问权限。")
  100. if err != nil {
  101. logger.Warning(fmt.Errorf("send message error: %v", err))
  102. return
  103. }
  104. return
  105. }
  106. if len(msgObj.Text.Content) == 0 || msgObj.Text.Content == "帮助" {
  107. // 欢迎信息
  108. _, err := msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), public.Config.Help)
  109. if err != nil {
  110. logger.Warning(fmt.Errorf("send message error: %v", err))
  111. return
  112. }
  113. } else {
  114. logger.Info(fmt.Sprintf("🙋 %s发起的问题: %#v", msgObj.SenderNick, msgObj.Text.Content))
  115. // 除去帮助之外的逻辑分流在这里处理
  116. switch {
  117. case strings.HasPrefix(msgObj.Text.Content, "#图片"):
  118. err := process.ImageGenerate(c, &msgObj)
  119. if err != nil {
  120. logger.Warning(fmt.Errorf("process request: %v", err))
  121. return
  122. }
  123. return
  124. case strings.HasPrefix(msgObj.Text.Content, "#查对话"):
  125. err := process.SelectHistory(&msgObj)
  126. if err != nil {
  127. logger.Warning(fmt.Errorf("process request: %v", err))
  128. return
  129. }
  130. return
  131. case strings.HasPrefix(msgObj.Text.Content, "#域名"):
  132. err := process.DomainMsg(&msgObj)
  133. if err != nil {
  134. logger.Warning(fmt.Errorf("process request: %v", err))
  135. return
  136. }
  137. return
  138. case strings.HasPrefix(msgObj.Text.Content, "#证书"):
  139. err := process.DomainCertMsg(&msgObj)
  140. if err != nil {
  141. logger.Warning(fmt.Errorf("process request: %v", err))
  142. return
  143. }
  144. return
  145. default:
  146. msgObj.Text.Content, err = process.GeneratePrompt(msgObj.Text.Content)
  147. // err不为空:提示词之后没有文本 -> 直接返回提示词所代表的内容
  148. if err != nil {
  149. _, err = msgObj.ReplyToDingtalk(string(dingbot.TEXT), msgObj.Text.Content)
  150. if err != nil {
  151. logger.Warning(fmt.Errorf("send message error: %v", err))
  152. return
  153. }
  154. return
  155. }
  156. err := process.ProcessRequest(&msgObj)
  157. if err != nil {
  158. logger.Warning(fmt.Errorf("process request: %v", err))
  159. return
  160. }
  161. return
  162. }
  163. }
  164. })
  165. // 解析生成后的图片
  166. app.GET("/images/:filename", func(c *gin.Context) {
  167. filename := c.Param("filename")
  168. c.File("./data/images/" + filename)
  169. })
  170. // 解析生成后的历史聊天
  171. app.GET("/history/:filename", func(c *gin.Context) {
  172. filename := c.Param("filename")
  173. c.File("./data/chatHistory/" + filename)
  174. })
  175. // 直接下载文件
  176. app.GET("/download/:filename", func(c *gin.Context) {
  177. filename := c.Param("filename")
  178. c.Header("Content-Disposition", "attachment; filename="+filename)
  179. c.Header("Content-Type", "application/octet-stream")
  180. c.File("./data/chatHistory/" + filename)
  181. })
  182. // 服务器健康检测
  183. app.GET("/", func(c *gin.Context) {
  184. c.JSON(200, gin.H{
  185. "status": "ok",
  186. "message": "🚀 欢迎使用钉钉机器人 🤖",
  187. })
  188. })
  189. port := ":" + public.Config.Port
  190. srv := &http.Server{
  191. Addr: port,
  192. Handler: app,
  193. }
  194. // Initializing the server in a goroutine so that
  195. // it won't block the graceful shutdown handling below
  196. go func() {
  197. logger.Info("🚀 The HTTP Server is running on", port)
  198. if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
  199. logger.Fatal("listen: %s\n", err)
  200. }
  201. }()
  202. // Wait for interrupt signal to gracefully shutdown the server with
  203. // a timeout of 5 seconds.
  204. quit := make(chan os.Signal, 1)
  205. // kill (no param) default send syscall.SIGTERM
  206. // kill -2 is syscall.SIGINT
  207. // kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
  208. // signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
  209. signal.Notify(quit, os.Interrupt)
  210. <-quit
  211. logger.Info("Shutting down server...")
  212. // 5秒后强制退出
  213. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  214. defer cancel()
  215. if err := srv.Shutdown(ctx); err != nil {
  216. logger.Fatal("Server forced to shutdown:", err)
  217. }
  218. logger.Info("Server exiting!")
  219. }