main.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. package main
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io/ioutil"
  6. "net/http"
  7. "strings"
  8. "github.com/eryajf/chatgpt-dingtalk/config"
  9. "github.com/eryajf/chatgpt-dingtalk/public"
  10. "github.com/eryajf/chatgpt-dingtalk/public/logger"
  11. "github.com/solywsh/chatgpt"
  12. )
  13. func init() {
  14. public.InitSvc()
  15. }
  16. func main() {
  17. Start()
  18. }
  19. var Welcome string = `Commands:
  20. =================================
  21. 🙋 单聊 👉 单独聊天,缺省
  22. 📣 串聊 👉 带上下文聊天
  23. 🔃 重置 👉 重置带上下文聊天
  24. 🚀 帮助 👉 显示帮助信息
  25. =================================
  26. 🚜 例:@我发送 空 或 帮助 将返回此帮助信息
  27. 💪 Power By https://github.com/eryajf/chatgpt-dingtalk
  28. `
// Disabled help-menu entry, kept for reference ("balance: show the remaining API quota"):
// 💵 余额 👉 查看接口可调用额度
  30. func Start() {
  31. // 定义一个处理器函数
  32. handler := func(w http.ResponseWriter, r *http.Request) {
  33. data, err := ioutil.ReadAll(r.Body)
  34. if err != nil {
  35. http.Error(w, err.Error(), http.StatusBadRequest)
  36. logger.Warning(fmt.Sprintf("read request body failed: %v\n", err.Error()))
  37. return
  38. }
  39. if len(data) == 0 {
  40. logger.Warning("回调参数为空,以至于无法正常解析,请检查原因")
  41. return
  42. }
  43. var msgObj = new(public.ReceiveMsg)
  44. err = json.Unmarshal(data, &msgObj)
  45. if err != nil {
  46. logger.Warning(fmt.Errorf("unmarshal request body failed: %v", err))
  47. }
  48. if msgObj.Text.Content == "" || msgObj.ChatbotUserID == "" {
  49. logger.Warning("从钉钉回调过来的内容为空,根据过往的经验,或许重新创建一下机器人,能解决这个问题")
  50. return
  51. }
  52. // TODO: 校验请求
  53. if len(msgObj.Text.Content) == 1 || strings.TrimSpace(msgObj.Text.Content) == "帮助" {
  54. // 欢迎信息
  55. _, err := msgObj.ReplyText(Welcome, msgObj.SenderStaffId)
  56. if err != nil {
  57. logger.Warning(fmt.Errorf("send message error: %v", err))
  58. }
  59. } else {
  60. logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
  61. err = ProcessRequest(*msgObj)
  62. if err != nil {
  63. logger.Warning(fmt.Errorf("process request failed: %v", err))
  64. }
  65. }
  66. }
  67. // 创建一个新的 HTTP 服务器
  68. server := &http.Server{
  69. Addr: ":8090",
  70. Handler: http.HandlerFunc(handler),
  71. }
  72. // 启动服务器
  73. logger.Info("Start Listen On ", server.Addr)
  74. err := server.ListenAndServe()
  75. if err != nil {
  76. logger.Danger(err)
  77. }
  78. }
  79. func ProcessRequest(rmsg public.ReceiveMsg) error {
  80. content := strings.TrimSpace(rmsg.Text.Content)
  81. switch content {
  82. case "单聊":
  83. public.UserService.SetUserMode(rmsg.SenderStaffId, content)
  84. _, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈单聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
  85. if err != nil {
  86. logger.Warning(fmt.Errorf("send message error: %v", err))
  87. }
  88. case "串聊":
  89. public.UserService.SetUserMode(rmsg.SenderStaffId, content)
  90. _, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈串聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
  91. if err != nil {
  92. logger.Warning(fmt.Errorf("send message error: %v", err))
  93. }
  94. case "重置":
  95. public.UserService.ClearUserMode(rmsg.SenderStaffId)
  96. public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
  97. _, err := rmsg.ReplyText(fmt.Sprintf("=====已重置与👉%s👈的对话模式,可以开始新的对话=====", rmsg.SenderNick), rmsg.SenderStaffId)
  98. if err != nil {
  99. logger.Warning(fmt.Errorf("send message error: %v", err))
  100. }
  101. default:
  102. if public.FirstCheck(rmsg) {
  103. return Do("串聊", rmsg)
  104. } else {
  105. return Do("单聊", rmsg)
  106. }
  107. }
  108. return nil
  109. }
  110. func Do(mode string, rmsg public.ReceiveMsg) error {
  111. // 先把模式注入
  112. public.UserService.SetUserMode(rmsg.SenderStaffId, mode)
  113. switch mode {
  114. case "单聊":
  115. reply, err := SingleQa(rmsg.Text.Content, rmsg.SenderNick)
  116. if err != nil {
  117. logger.Info(fmt.Errorf("gpt request error: %v", err))
  118. if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
  119. public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
  120. _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
  121. if err != nil {
  122. logger.Warning(fmt.Errorf("send message error: %v", err))
  123. return err
  124. }
  125. } else {
  126. _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
  127. if err != nil {
  128. logger.Warning(fmt.Errorf("send message error: %v", err))
  129. return err
  130. }
  131. }
  132. }
  133. if reply == "" {
  134. logger.Warning(fmt.Errorf("get gpt result falied: %v", err))
  135. return nil
  136. } else {
  137. reply = strings.TrimSpace(reply)
  138. reply = strings.Trim(reply, "\n")
  139. // 回复@我的用户
  140. // fmt.Println("单聊结果是:", reply)
  141. _, err = rmsg.ReplyText(reply, rmsg.SenderStaffId)
  142. if err != nil {
  143. logger.Warning(fmt.Errorf("send message error: %v", err))
  144. return err
  145. }
  146. }
  147. case "串聊":
  148. cli, reply, err := ContextQa(rmsg.Text.Content, rmsg.SenderStaffId)
  149. if err != nil {
  150. logger.Info(fmt.Sprintf("gpt request error: %v", err))
  151. if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
  152. public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
  153. _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
  154. if err != nil {
  155. logger.Warning(fmt.Errorf("send message error: %v", err))
  156. return err
  157. }
  158. } else {
  159. _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
  160. if err != nil {
  161. logger.Warning(fmt.Errorf("send message error: %v", err))
  162. return err
  163. }
  164. }
  165. }
  166. if reply == "" {
  167. logger.Warning(fmt.Errorf("get gpt result falied: %v", err))
  168. return nil
  169. } else {
  170. reply = strings.TrimSpace(reply)
  171. reply = strings.Trim(reply, "\n")
  172. // 回复@我的用户
  173. _, err = rmsg.ReplyText(reply, rmsg.SenderStaffId)
  174. if err != nil {
  175. logger.Warning(fmt.Errorf("send message error: %v", err))
  176. return err
  177. }
  178. _ = cli.ChatContext.SaveConversation(rmsg.SenderStaffId)
  179. }
  180. default:
  181. }
  182. return nil
  183. }
  184. func SingleQa(question, userId string) (answer string, err error) {
  185. cfg := config.LoadConfig()
  186. chat := chatgpt.New(cfg.ApiKey, userId, cfg.SessionTimeout)
  187. defer chat.Close()
  188. return chat.ChatWithContext(question)
  189. }
  190. func ContextQa(question, userId string) (chat *chatgpt.ChatGPT, answer string, err error) {
  191. cfg := config.LoadConfig()
  192. chat = chatgpt.New(cfg.ApiKey, userId, cfg.SessionTimeout)
  193. if public.UserService.GetUserSessionContext(userId) != "" {
  194. err = chat.ChatContext.LoadConversation(userId)
  195. if err != nil {
  196. fmt.Printf("load station failed: %v\n", err)
  197. }
  198. }
  199. answer, err = chat.ChatWithContext(question)
  200. return
  201. }