main.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. package main
  2. import (
  3. "context"
  4. "fmt"
  5. "net/http"
  6. "os"
  7. "os/signal"
  8. "path/filepath"
  9. "strings"
  10. "time"
  11. "github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
  12. "github.com/eryajf/chatgpt-dingtalk/pkg/logger"
  13. "github.com/eryajf/chatgpt-dingtalk/pkg/process"
  14. "github.com/eryajf/chatgpt-dingtalk/public"
  15. "github.com/xgfone/ship/v5"
  16. )
  17. func init() {
  18. public.InitSvc()
  19. logger.InitLogger(public.Config.LogLevel)
  20. }
  21. func main() {
  22. Start()
  23. }
  24. func Start() {
  25. app := ship.Default()
  26. app.Route("/").POST(func(c *ship.Context) error {
  27. var msgObj dingbot.ReceiveMsg
  28. err := c.Bind(&msgObj)
  29. if err != nil {
  30. return ship.ErrBadRequest.New(fmt.Errorf("bind to receivemsg failed : %v", err))
  31. }
  32. // 先校验回调是否合法
  33. if !public.CheckRequest(c.GetReqHeader("timestamp"), c.GetReqHeader("sign")) {
  34. logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
  35. return nil
  36. }
  37. // 再校验回调参数是否有价值
  38. if msgObj.Text.Content == "" || msgObj.ChatbotUserID == "" {
  39. logger.Warning("从钉钉回调过来的内容为空,根据过往的经验,或许重新创建一下机器人,能解决这个问题")
  40. return ship.ErrBadRequest.New(fmt.Errorf("从钉钉回调过来的内容为空,根据过往的经验,或许重新创建一下机器人,能解决这个问题"))
  41. }
  42. // 去除问题的前后空格
  43. msgObj.Text.Content = strings.TrimSpace(msgObj.Text.Content)
  44. // 打印钉钉回调过来的请求明细,调试时打开
  45. fmt.Println("=======", logger.Logger.GetLevel().String())
  46. logger.Debug(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
  47. if public.Config.ChatType != "0" && msgObj.ConversationType != public.Config.ChatType {
  48. _, err = msgObj.ReplyToDingtalk(string(dingbot.TEXT), "抱歉,管理员禁用了这种聊天方式,请选择其他聊天方式与机器人对话!")
  49. if err != nil {
  50. logger.Warning(fmt.Errorf("send message error: %v", err))
  51. return err
  52. }
  53. return nil
  54. }
  55. if !public.JudgeGroup(msgObj.GetChatTitle()) && !public.JudgeUsers(msgObj.SenderNick) && !public.JudgeAdminUsers(msgObj.SenderNick) {
  56. _, err = msgObj.ReplyToDingtalk(string(dingbot.TEXT), "抱歉,您不在该机器人对话功能的白名单当中!")
  57. if err != nil {
  58. logger.Warning(fmt.Errorf("send message error: %v", err))
  59. return err
  60. }
  61. return nil
  62. }
  63. if len(msgObj.Text.Content) == 1 || msgObj.Text.Content == "帮助" {
  64. // 欢迎信息
  65. _, err := msgObj.ReplyToDingtalk(string(dingbot.MARKDOWN), public.Welcome)
  66. if err != nil {
  67. logger.Warning(fmt.Errorf("send message error: %v", err))
  68. return ship.ErrBadRequest.New(fmt.Errorf("send message error: %v", err))
  69. }
  70. } else {
  71. logger.Info(fmt.Sprintf("🙋 %s发起的问题: %#v", msgObj.SenderNick, msgObj.Text.Content))
  72. // 除去帮助之外的逻辑分流在这里处理
  73. switch {
  74. case strings.HasPrefix(msgObj.Text.Content, "#图片"):
  75. return process.ImageGenerate(&msgObj)
  76. case strings.HasPrefix(msgObj.Text.Content, "#查对话"):
  77. return process.SelectHistory(&msgObj)
  78. case strings.HasPrefix(msgObj.Text.Content, "#域名"):
  79. return process.DomainMsg(&msgObj)
  80. case strings.HasPrefix(msgObj.Text.Content, "#证书"):
  81. return process.DomainCertMsg(&msgObj)
  82. default:
  83. msgObj.Text.Content, err = process.GeneratePrompt(msgObj.Text.Content)
  84. // err不为空:提示词之后没有文本 -> 直接返回提示词所代表的内容
  85. if err != nil {
  86. _, err = msgObj.ReplyToDingtalk(string(dingbot.TEXT), msgObj.Text.Content)
  87. if err != nil {
  88. logger.Warning(fmt.Errorf("send message error: %v", err))
  89. return err
  90. }
  91. return nil
  92. }
  93. return process.ProcessRequest(&msgObj)
  94. }
  95. }
  96. return nil
  97. })
  98. // 解析生成后的图片
  99. app.Route("/images/:filename").GET(func(c *ship.Context) error {
  100. filename := c.Param("filename")
  101. root := "./data/images/"
  102. return c.File(filepath.Join(root, filename))
  103. })
  104. // 解析生成后的历史聊天
  105. app.Route("/history/:filename").GET(func(c *ship.Context) error {
  106. filename := c.Param("filename")
  107. root := "./data/chatHistory/"
  108. return c.File(filepath.Join(root, filename))
  109. })
  110. // 直接下载文件
  111. app.Route("/download/:filename").GET(func(c *ship.Context) error {
  112. filename := c.Param("filename")
  113. root := "./data/chatHistory/"
  114. return c.Attachment(filepath.Join(root, filename), "")
  115. })
  116. port := ":" + public.Config.Port
  117. srv := &http.Server{
  118. Addr: port,
  119. Handler: app,
  120. }
  121. // Initializing the server in a goroutine so that
  122. // it won't block the graceful shutdown handling below
  123. go func() {
  124. logger.Info("🚀 The HTTP Server is running on", port)
  125. if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
  126. logger.Fatal("listen: %s\n", err)
  127. }
  128. }()
  129. // Wait for interrupt signal to gracefully shutdown the server with
  130. // a timeout of 5 seconds.
  131. quit := make(chan os.Signal, 1)
  132. // kill (no param) default send syscall.SIGTERM
  133. // kill -2 is syscall.SIGINT
  134. // kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
  135. // signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
  136. signal.Notify(quit, os.Interrupt)
  137. <-quit
  138. logger.Info("Shutting down server...")
  139. // 5秒后强制退出
  140. ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
  141. defer cancel()
  142. if err := srv.Shutdown(ctx); err != nil {
  143. logger.Fatal("Server forced to shutdown:", err)
  144. }
  145. logger.Info("Server exiting!")
  146. }