Forráskód Böngészése

feat: 增加了生成图片的能力,以及其他一系列优化 (#115)

Co-authored-by: eryajf <eryajf@users.noreply.github.com>
二丫讲梵 2 éve
szülő
commit
7219b2b324
13 módosított fájl, 325 hozzáadás és 130 törlés
  1. 12 0
      .github/workflows/TOC.yml
  2. 75 33
      README.md
  3. 2 1
      config.dev.json
  4. 9 0
      config/config.go
  5. 3 2
      docker-compose.yml
  6. 1 0
      go.mod
  7. 2 0
      go.sum
  8. 29 38
      main.go
  9. 53 4
      pkg/chatgpt/context.go
  10. 25 0
      pkg/chatgpt/export.go
  11. 38 27
      pkg/process/process_request.go
  12. 52 24
      public/dingtalk.go
  13. 24 1
      public/public.go

+ 12 - 0
.github/workflows/TOC.yml

@@ -0,0 +1,12 @@
+on: push
+name: Automatic Generation TOC
+jobs:
+  generateTOC:
+    name: TOC Generator
+    runs-on: ubuntu-latest
+    steps:
+      - uses: technote-space/toc-generator@v4
+        with:
+          TOC_TITLE: "**目录**"
+          MAX_HEADER_LEVEL: 3
+          FOLDING: true

A különbségek nem kerülnek megjelenítésre, a fájl túl nagy
+ 75 - 33
README.md


+ 2 - 1
config.dev.json

@@ -5,5 +5,6 @@
     "session_timeout": 600,
     "http_proxy": "",
     "default_mode": "单聊",
-    "max_request": 0
+    "max_request": 0,
+    "service_url": "http://chat.eryajf.net"
 }

+ 9 - 0
config/config.go

@@ -27,6 +27,8 @@ type Configuration struct {
 	HttpProxy string `json:"http_proxy"`
 	// 用户单日最大请求次数
 	MaxRequest int `json:"max_request"`
+	// 指定服务的地址,就是钉钉机器人配置的回调地址,比如: http://chat.eryajf.net
+	ServiceURL string `json:"service_url"`
 }
 
 var config *Configuration
@@ -57,6 +59,7 @@ func LoadConfig() *Configuration {
 		defaultMode := os.Getenv("DEFAULT_MODE")
 		httpProxy := os.Getenv("HTTP_PROXY")
 		maxRequest := os.Getenv("MAX_REQUEST")
+		serviceURL := os.Getenv("SERVICE_URL")
 		if apiKey != "" {
 			config.ApiKey = apiKey
 		}
@@ -86,6 +89,9 @@ func LoadConfig() *Configuration {
 			newMR, _ := strconv.Atoi(maxRequest)
 			config.MaxRequest = newMR
 		}
+		if serviceURL != "" {
+			config.ServiceURL = serviceURL
+		}
 	})
 	if config.Model == "" {
 		config.DefaultMode = "gpt-3.5-turbo"
@@ -96,5 +102,8 @@ func LoadConfig() *Configuration {
 	if config.ApiKey == "" {
 		logger.Danger("config err: api key required")
 	}
+	if config.ServiceURL == "" {
+		logger.Danger("config err: service url required")
+	}
 	return config
 }

+ 3 - 2
docker-compose.yml

@@ -7,12 +7,13 @@ services:
     restart: always
     environment:
       APIKEY: xxxxxx  # 你的 api_key
-      BASE_URL: xxxxxx  # 如果你想指定请求url的地址,可通过这个参数进行配置,默认为官方地址,不需要再添加 /v1
+      BASE_URL: xxxxxx  # 如果你想指定请求url的地址,可通过这个参数进行配置,不需要再添加 /v1,如果留空则默认为官方地址,
       MODEL: "gpt-3.5-turbo" # 指定模型
       SESSION_TIMEOUT: 600 # 超时时间
-      HTTP_PROXY: http://host.docker.internal:15777 # 配置代理,注意:host.docker.internal会解析到容器所在的宿主机IP,如果你的服务部署在宿主机,只需要更改端口即
+      HTTP_PROXY: http://host.docker.internal:15777 # 配置代理,注意:host.docker.internal会解析到容器所在的宿主机IP,因此只需要更改端口即可,另外如果服务器在国外,则这里留空即
       DEFAULT_MODE: "单聊" # 聊天模式
       MAX_REQUEST: 0 # 单人单日请求次数限制,默认为0,即不限制
+      SERVICE_URL: ""  # 指定服务的地址,就是钉钉机器人配置的回调地址,比如: http://chat.eryajf.net
     ports:
       - "8090:8090"
     extra_hosts:

+ 1 - 0
go.mod

@@ -6,6 +6,7 @@ require (
 	github.com/go-resty/resty/v2 v2.7.0
 	github.com/patrickmn/go-cache v2.1.0+incompatible
 	github.com/solywsh/chatgpt v0.0.14
+	github.com/xgfone/ship/v5 v5.3.1
 	gopkg.in/yaml.v2 v2.4.0
 )
 

+ 2 - 0
go.sum

@@ -18,6 +18,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
 github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/xgfone/ship/v5 v5.3.1 h1:e5qhMT6DSQOE6A/xDBL/Ftf28BGptNw6Etq+w+pme6E=
+github.com/xgfone/ship/v5 v5.3.1/go.mod h1:mGI+65lLL3kaOseMkWUYgy+OFl27WV2LY1NSsecu/9g=
 golang.org/x/net v0.0.0-20211029224645-99673261e6eb h1:pirldcYWx7rx7kE5r+9WsOXPXK0+WH5+uZ7uPmJ44uM=
 golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

+ 29 - 38
main.go

@@ -1,15 +1,14 @@
 package main
 
 import (
-	"encoding/json"
 	"fmt"
-	"io/ioutil"
-	"net/http"
+	"path/filepath"
 	"strings"
 
 	"github.com/eryajf/chatgpt-dingtalk/pkg/process"
 	"github.com/eryajf/chatgpt-dingtalk/public"
 	"github.com/eryajf/chatgpt-dingtalk/public/logger"
+	"github.com/xgfone/ship/v5"
 )
 
 func init() {
@@ -27,61 +26,53 @@ var Welcome string = `Commands:
 💵 余额 👉 查询剩余额度
 🚀 帮助 👉 显示帮助信息
 🌈 模板 👉 内置的prompt
+🎨 图片 👉 根据prompt生成图片
 =================================
 🚜 例:@我发送 空 或 帮助 将返回此帮助信息
 💪 Power By https://github.com/eryajf/chatgpt-dingtalk
 `
 
 func Start() {
-	// 定义一个处理器函数
-	handler := func(w http.ResponseWriter, r *http.Request) {
-		data, err := ioutil.ReadAll(r.Body)
+	app := ship.Default()
+	app.Route("/").POST(func(c *ship.Context) error {
+		var msgObj public.ReceiveMsg
+		err := c.Bind(&msgObj)
 		if err != nil {
-			http.Error(w, err.Error(), http.StatusBadRequest)
-			logger.Warning(fmt.Sprintf("read request body failed: %v\n", err.Error()))
-			return
-		}
-		if len(data) == 0 {
-			logger.Warning("回调参数为空,以至于无法正常解析,请检查原因")
-			return
-		}
-		var msgObj = new(public.ReceiveMsg)
-		err = json.Unmarshal(data, &msgObj)
-		if err != nil {
-			logger.Warning(fmt.Errorf("unmarshal request body failed: %v", err))
+			return ship.ErrBadRequest.New(fmt.Errorf("bind to receivemsg failed : %v", err))
 		}
 		if msgObj.Text.Content == "" || msgObj.ChatbotUserID == "" {
 			logger.Warning("从钉钉回调过来的内容为空,根据过往的经验,或许重新创建一下机器人,能解决这个问题")
-			return
+			return ship.ErrBadRequest.New(fmt.Errorf("从钉钉回调过来的内容为空,根据过往的经验,或许重新创建一下机器人,能解决这个问题"))
 		}
-
+		// 打印钉钉回调过来的请求明细
+		logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
 		// TODO: 校验请求
 		if len(msgObj.Text.Content) == 1 || strings.TrimSpace(msgObj.Text.Content) == "帮助" {
 			// 欢迎信息
-			_, err := msgObj.ReplyText(Welcome, msgObj.SenderStaffId)
+			_, err := msgObj.ReplyToDingtalk(string(public.TEXT), Welcome, msgObj.SenderStaffId)
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
+				return ship.ErrBadRequest.New(fmt.Errorf("send message error: %v", err))
 			}
 		} else {
-			msgObj.Text.Content = process.GeneratePrompt(strings.TrimSpace(msgObj.Text.Content))
-			logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj))
-			err = process.ProcessRequest(*msgObj)
-			if err != nil {
-				logger.Warning(fmt.Errorf("process request failed: %v", err))
+			// 除去帮助之外的逻辑分流在这里处理
+			switch {
+			case strings.HasPrefix(strings.TrimSpace(msgObj.Text.Content), "#图片"):
+				return process.ImageGenerate(&msgObj)
+			default:
+				msgObj.Text.Content = process.GeneratePrompt(strings.TrimSpace(msgObj.Text.Content))
+				return process.ProcessRequest(&msgObj)
 			}
 		}
-	}
-
-	// 创建一个新的 HTTP 服务器
-	server := &http.Server{
-		Addr:    ":8090",
-		Handler: http.HandlerFunc(handler),
-	}
+		return nil
+	})
+	// 解析生成后的图片
+	app.Route("/images/:filename").GET(func(c *ship.Context) error {
+		filename := c.Param("filename")
+		root := "./images/"
+		return c.File(filepath.Join(root, filename))
+	})
 
 	// 启动服务器
-	logger.Info("Start Listen On ", server.Addr)
-	err := server.ListenAndServe()
-	if err != nil {
-		logger.Danger(err)
-	}
+	ship.StartServer(":8090", app)
 }

+ 53 - 4
pkg/chatgpt/context.go

@@ -2,9 +2,12 @@ package chatgpt
 
 import (
 	"bytes"
+	"encoding/base64"
 	"encoding/gob"
-	"fmt"
+	"image/png"
+	"os"
 	"strings"
+	"time"
 
 	"github.com/eryajf/chatgpt-dingtalk/public"
 	openai "github.com/sashabaranov/go-openai"
@@ -52,9 +55,9 @@ func NewContext(options ...ChatContextOption) *ChatContext {
 	ctx := &ChatContext{
 		aiRole:           &role{Name: DefaultAiRole},
 		humanRole:        &role{Name: DefaultHumanRole},
-		background:       fmt.Sprintf(DefaultBackground, strings.Join(DefaultCharacter, ", ")+"."),
+		background:       "",
 		maxSeqTimes:      1000,
-		preset:           fmt.Sprintf(DefaultPreset, DefaultHumanRole, DefaultAiRole),
+		preset:           "",
 		old:              []conversation{},
 		seqTimes:         0,
 		restartSeq:       "\n" + DefaultHumanRole + ": ",
@@ -156,7 +159,6 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
 	if len(prompt) > c.maxText-c.maxAnswerLen {
 		return "", OverMaxTextLength
 	}
-
 	model := public.Config.Model
 	if model == openai.GPT3Dot5Turbo0301 ||
 		model == openai.GPT3Dot5Turbo ||
@@ -215,6 +217,53 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
 		return resp.Choices[0].Text, nil
 	}
 }
+func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
+	model := public.Config.Model
+	if model == openai.GPT3Dot5Turbo0301 ||
+		model == openai.GPT3Dot5Turbo ||
+		model == openai.GPT4 || model == openai.GPT40314 ||
+		model == openai.GPT432K || model == openai.GPT432K0314 {
+		req := openai.ImageRequest{
+			Prompt:         prompt,
+			Size:           openai.CreateImageSize1024x1024,
+			ResponseFormat: openai.CreateImageResponseFormatB64JSON,
+			N:              1,
+			User:           c.userId,
+		}
+		respBase64, err := c.client.CreateImage(c.ctx, req)
+		if err != nil {
+			return "", err
+		}
+		imgBytes, err := base64.StdEncoding.DecodeString(respBase64.Data[0].B64JSON)
+		if err != nil {
+			return "", err
+		}
+
+		r := bytes.NewReader(imgBytes)
+		imgData, err := png.Decode(r)
+		if err != nil {
+			return "", err
+		}
+
+		imageName := time.Now().Format("20060102-150405") + ".png"
+		err = os.MkdirAll("images", 0755)
+		if err != nil {
+			return "", err
+		}
+		file, err := os.Create("images/" + imageName)
+		if err != nil {
+			return "", err
+		}
+		defer file.Close()
+
+		if err := png.Encode(file, imgData); err != nil {
+			return "", err
+		}
+
+		return public.Config.ServiceURL + "/images/" + imageName, nil
+	}
+	return "", nil
+}
 
 func WithMaxSeqTimes(times int) ChatContextOption {
 	return func(c *ChatContext) {

+ 25 - 0
pkg/chatgpt/export.go

@@ -8,6 +8,7 @@ import (
 	"github.com/eryajf/chatgpt-dingtalk/public/logger"
 )
 
+// SingleQa 单聊
 func SingleQa(question, userId string) (answer string, err error) {
 	chat := New(userId)
 	defer chat.Close()
@@ -30,6 +31,7 @@ func SingleQa(question, userId string) (answer string, err error) {
 	return
 }
 
+// ContextQa 串聊
 func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error) {
 	chat = New(userId)
 	if public.UserService.GetUserSessionContext(userId) != "" {
@@ -54,3 +56,26 @@ func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error
 		retryStrategy...)
 	return
 }
+
+// ImageQa 生成图片
+func ImageQa(question, userId string) (answer string, err error) {
+	chat := New(userId)
+	defer chat.Close()
+	// 定义一个重试策略
+	retryStrategy := []retry.Option{
+		retry.Delay(100 * time.Millisecond),
+		retry.Attempts(3),
+		retry.LastErrorOnly(true),
+	}
+	// 使用重试策略进行重试
+	err = retry.Do(
+		func() error {
+			answer, err = chat.GenreateImage(question)
+			if err != nil {
+				return err
+			}
+			return nil
+		},
+		retryStrategy...)
+	return
+}

+ 38 - 27
pkg/process/process_request.go

@@ -11,26 +11,26 @@ import (
 )
 
 // ProcessRequest 分析处理请求逻辑
-func ProcessRequest(rmsg public.ReceiveMsg) error {
-	if CheckRequest(rmsg) {
+func ProcessRequest(rmsg *public.ReceiveMsg) error {
+	if public.CheckRequest(rmsg) {
 		content := strings.TrimSpace(rmsg.Text.Content)
 		switch content {
 		case "单聊":
 			public.UserService.SetUserMode(rmsg.SenderStaffId, content)
-			_, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈单聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
+			_, err := rmsg.ReplyToDingtalk(string(public.TEXT), fmt.Sprintf("=====现在进入与👉%s👈单聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 			}
 		case "串聊":
 			public.UserService.SetUserMode(rmsg.SenderStaffId, content)
-			_, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈串聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
+			_, err := rmsg.ReplyToDingtalk(string(public.TEXT), fmt.Sprintf("=====现在进入与👉%s👈串聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId)
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 			}
 		case "重置":
 			public.UserService.ClearUserMode(rmsg.SenderStaffId)
 			public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
-			_, err := rmsg.ReplyText(fmt.Sprintf("=====已重置与👉%s👈的对话模式,可以开始新的对话=====", rmsg.SenderNick), rmsg.SenderStaffId)
+			_, err := rmsg.ReplyToDingtalk(string(public.TEXT), fmt.Sprintf("=====已重置与👉%s👈的对话模式,可以开始新的对话=====", rmsg.SenderNick), rmsg.SenderStaffId)
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 			}
@@ -39,7 +39,12 @@ func ProcessRequest(rmsg public.ReceiveMsg) error {
 			for _, v := range *public.Prompt {
 				title = title + v.Title + " | "
 			}
-			_, err := rmsg.ReplyText(fmt.Sprintf("%s 您好,当前程序内置集成了这些prompt:\n====================================\n| %s \n====================================\n你可以选择某个prompt开头,然后进行对话。\n以周报为例,可发送 #周报 我本周用Go写了一个钉钉集成ChatGPT的聊天应用", rmsg.SenderNick, title), rmsg.SenderStaffId)
+			_, err := rmsg.ReplyToDingtalk(string(public.TEXT), fmt.Sprintf("%s 您好,当前程序内置集成了这些prompt:\n====================================\n| %s \n====================================\n你可以选择某个prompt开头,然后进行对话。\n以周报为例,可发送 #周报 我本周用Go写了一个钉钉集成ChatGPT的聊天应用", rmsg.SenderNick, title), rmsg.SenderStaffId)
+			if err != nil {
+				logger.Warning(fmt.Errorf("send message error: %v", err))
+			}
+		case "图片":
+			_, err := rmsg.ReplyToDingtalk(string(public.MARKDOWN), "发送以 **#图片** 开头的内容,将会触发绘画能力,图片生成之后,将会保存在程序根目录下的 **images目录** \n 如果你绘图没有思路,可以在这两个网站寻找灵感。\n - [https://lexica.art/](https://lexica.art/)\n- [https://www.clickprompt.org/zh-CN/](https://www.clickprompt.org/zh-CN/)", rmsg.SenderStaffId)
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 			}
@@ -56,7 +61,7 @@ func ProcessRequest(rmsg public.ReceiveMsg) error {
 				cacheMsg = fmt.Sprintf("💵 已用: 💲%v\n💵 剩余: 💲%v\n⏳ 有效时间: 从 %v 到 %v\n", fmt.Sprintf("%.2f", rst.TotalUsed), fmt.Sprintf("%.2f", rst.TotalAvailable), t1.Format("2006-01-02 15:04:05"), t2.Format("2006-01-02 15:04:05"))
 			}
 
-			_, err := rmsg.ReplyText(cacheMsg, rmsg.SenderStaffId)
+			_, err := rmsg.ReplyToDingtalk(string(public.TEXT), cacheMsg, rmsg.SenderStaffId)
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 			}
@@ -72,7 +77,7 @@ func ProcessRequest(rmsg public.ReceiveMsg) error {
 }
 
 // 执行处理请求
-func Do(mode string, rmsg public.ReceiveMsg) error {
+func Do(mode string, rmsg *public.ReceiveMsg) error {
 	// 先把模式注入
 	public.UserService.SetUserMode(rmsg.SenderStaffId, mode)
 	switch mode {
@@ -82,13 +87,13 @@ func Do(mode string, rmsg public.ReceiveMsg) error {
 			logger.Info(fmt.Errorf("gpt request error: %v", err))
 			if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
 				public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
-				_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
+				_, err = rmsg.ReplyToDingtalk(string(public.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
 				if err != nil {
 					logger.Warning(fmt.Errorf("send message error: %v", err))
 					return err
 				}
 			} else {
-				_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
+				_, err = rmsg.ReplyToDingtalk(string(public.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
 				if err != nil {
 					logger.Warning(fmt.Errorf("send message error: %v", err))
 					return err
@@ -103,7 +108,7 @@ func Do(mode string, rmsg public.ReceiveMsg) error {
 			reply = strings.Trim(reply, "\n")
 			// 回复@我的用户
 			// fmt.Println("单聊结果是:", reply)
-			_, err = rmsg.ReplyText(reply, rmsg.SenderStaffId)
+			_, err = rmsg.ReplyToDingtalk(string(public.TEXT), reply, rmsg.SenderStaffId)
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 				return err
@@ -115,13 +120,13 @@ func Do(mode string, rmsg public.ReceiveMsg) error {
 			logger.Info(fmt.Sprintf("gpt request error: %v", err))
 			if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") {
 				public.UserService.ClearUserSessionContext(rmsg.SenderStaffId)
-				_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
+				_, err = rmsg.ReplyToDingtalk(string(public.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId)
 				if err != nil {
 					logger.Warning(fmt.Errorf("send message error: %v", err))
 					return err
 				}
 			} else {
-				_, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
+				_, err = rmsg.ReplyToDingtalk(string(public.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
 				if err != nil {
 					logger.Warning(fmt.Errorf("send message error: %v", err))
 					return err
@@ -135,7 +140,7 @@ func Do(mode string, rmsg public.ReceiveMsg) error {
 			reply = strings.TrimSpace(reply)
 			reply = strings.Trim(reply, "\n")
 			// 回复@我的用户
-			_, err = rmsg.ReplyText(reply, rmsg.SenderStaffId)
+			_, err = rmsg.ReplyToDingtalk(string(public.TEXT), reply, rmsg.SenderStaffId)
 			if err != nil {
 				logger.Warning(fmt.Errorf("send message error: %v", err))
 				return err
@@ -148,22 +153,28 @@ func Do(mode string, rmsg public.ReceiveMsg) error {
 	return nil
 }
 
-// ProcessRequest 分析处理请求逻辑
-func CheckRequest(rmsg public.ReceiveMsg) bool {
-	if public.Config.MaxRequest == 0 {
-		return true
+func ImageGenerate(rmsg *public.ReceiveMsg) error {
+	reply, err := chatgpt.ImageQa(rmsg.Text.Content, rmsg.SenderStaffId)
+	if err != nil {
+		logger.Info(fmt.Errorf("gpt request error: %v", err))
+		_, err = rmsg.ReplyToDingtalk(string(public.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId)
+		if err != nil {
+			logger.Warning(fmt.Errorf("send message error: %v", err))
+			return err
+		}
 	}
-	count := public.UserService.GetUseRequestCount(rmsg.SenderStaffId)
-	// 判断访问次数是否超过限制
-	if count >= public.Config.MaxRequest {
-		logger.Info(fmt.Sprintf("亲爱的: %s,您今日请求次数已达上限,请明天再来,交互发问资源有限,请务必斟酌您的问题,给您带来不便,敬请谅解!", rmsg.SenderNick))
-		_, err := rmsg.ReplyText(fmt.Sprintf("一个好的问题,胜过十个好的答案!\n亲爱的: %s,您今日请求次数已达上限,请明天再来,交互发问资源有限,请务必斟酌您的问题,给您带来不便,敬请谅解!", rmsg.SenderNick), rmsg.SenderStaffId)
+	if reply == "" {
+		logger.Warning(fmt.Errorf("get gpt result falied: %v", err))
+		return nil
+	} else {
+		reply = strings.TrimSpace(reply)
+		reply = strings.Trim(reply, "\n")
+		// 回复@我的用户
+		_, err = rmsg.ReplyToDingtalk(string(public.MARKDOWN), fmt.Sprintf(">点击图片可旋转或放大。\n![](%s)", reply), rmsg.SenderStaffId)
 		if err != nil {
 			logger.Warning(fmt.Errorf("send message error: %v", err))
+			return err
 		}
-		return false
 	}
-	// 访问次数未超过限制,将计数加1
-	public.UserService.SetUseRequestCount(rmsg.SenderStaffId, count+1)
-	return true
+	return nil
 }

+ 52 - 24
public/dingtalk.go

@@ -12,45 +12,73 @@ type ReceiveMsg struct {
 	AtUsers        []struct {
 		DingtalkID string `json:"dingtalkId"`
 	} `json:"atUsers"`
-	ChatbotUserID             string `json:"chatbotUserId"`
-	MsgID                     string `json:"msgId"`
-	SenderNick                string `json:"senderNick"`
-	IsAdmin                   bool   `json:"isAdmin"`
-	SenderStaffId             string `json:"senderStaffId"`
-	SessionWebhookExpiredTime int64  `json:"sessionWebhookExpiredTime"`
-	CreateAt                  int64  `json:"createAt"`
-	ConversationType          string `json:"conversationType"`
-	SenderID                  string `json:"senderId"`
-	ConversationTitle         string `json:"conversationTitle"`
-	IsInAtList                bool   `json:"isInAtList"`
-	SessionWebhook            string `json:"sessionWebhook"`
-	Text                      Text   `json:"text"`
-	RobotCode                 string `json:"robotCode"`
-	Msgtype                   string `json:"msgtype"`
+	ChatbotUserID             string  `json:"chatbotUserId"`
+	MsgID                     string  `json:"msgId"`
+	SenderNick                string  `json:"senderNick"`
+	IsAdmin                   bool    `json:"isAdmin"`
+	SenderStaffId             string  `json:"senderStaffId"`
+	SessionWebhookExpiredTime int64   `json:"sessionWebhookExpiredTime"`
+	CreateAt                  int64   `json:"createAt"`
+	ConversationType          string  `json:"conversationType"`
+	SenderID                  string  `json:"senderId"`
+	ConversationTitle         string  `json:"conversationTitle"`
+	IsInAtList                bool    `json:"isInAtList"`
+	SessionWebhook            string  `json:"sessionWebhook"`
+	Text                      Text    `json:"text"`
+	RobotCode                 string  `json:"robotCode"`
+	Msgtype                   MsgType `json:"msgtype"`
 }
 
+// 消息类型
+type MsgType string
 
-// 发送的消息体
-type SendMsg struct {
-	Text    Text   `json:"text"`
-	Msgtype string `json:"msgtype"`
-	At 		At `json:"at"`
+const TEXT MsgType = "text"
+const MARKDOWN MsgType = "markdown"
+
+// Text 消息
+type TextMessage struct {
+	MsgType MsgType `json:"msgtype"`
+	At      *At     `json:"at"`
+	Text    *Text   `json:"text"`
 }
 
-// 消息内容
+// Text 消息内容
 type Text struct {
 	Content string `json:"content"`
 }
 
+// MarkDown 消息
+type MarkDownMessage struct {
+	MsgType  MsgType   `json:"msgtype"`
+	At       *At       `json:"at"`
+	MarkDown *MarkDown `json:"markdown"`
+}
+
+// MarkDown 消息内容
+type MarkDown struct {
+	Title string `json:"title"`
+	Text  string `json:"text"`
+}
+
 // at 内容
 type At struct {
 	AtUserIds []string `json:"atUserIds"`
+	AtMobiles []string `json:"atMobiles"`
+	IsAtAll   bool     `json:"isAtAll"`
 }
 
 // 发消息给钉钉
-func (r ReceiveMsg) ReplyText(msg string, staffId string) (statuscode int, err error) {
-	// 定义消息
-	msgtmp := &SendMsg{Text: Text{Content: msg}, Msgtype: "text", At: At{AtUserIds: []string{staffId}}}
+func (r ReceiveMsg) ReplyToDingtalk(msgType, msg, staffId string) (statuscode int, err error) {
+	var msgtmp interface{}
+	switch msgType {
+	case string(TEXT):
+		msgtmp = &TextMessage{Text: &Text{Content: msg}, MsgType: TEXT, At: &At{AtUserIds: []string{staffId}}}
+	case string(MARKDOWN):
+		msgtmp = &MarkDownMessage{MsgType: MARKDOWN, At: &At{AtUserIds: []string{staffId}}, MarkDown: &MarkDown{Title: "根据您提供的信息,为您生成图片如下", Text: msg}}
+	default:
+		msgtmp = &TextMessage{Text: &Text{Content: msg}, MsgType: TEXT, At: &At{AtUserIds: []string{staffId}}}
+	}
+
 	data, err := json.Marshal(msgtmp)
 	if err != nil {
 		return 0, err

+ 24 - 1
public/public.go

@@ -1,10 +1,12 @@
 package public
 
 import (
+	"fmt"
 	"strings"
 
 	"github.com/eryajf/chatgpt-dingtalk/config"
 	"github.com/eryajf/chatgpt-dingtalk/pkg/cache"
+	"github.com/eryajf/chatgpt-dingtalk/public/logger"
 )
 
 var UserService cache.UserServiceInterface
@@ -18,7 +20,7 @@ func InitSvc() {
 	_, _ = GetBalance()
 }
 
-func FirstCheck(rmsg ReceiveMsg) bool {
+func FirstCheck(rmsg *ReceiveMsg) bool {
 	lc := UserService.GetUserMode(rmsg.SenderStaffId)
 	if lc == "" {
 		if Config.DefaultMode == "串聊" {
@@ -32,3 +34,24 @@ func FirstCheck(rmsg ReceiveMsg) bool {
 	}
 	return false
 }
+
+// ProcessRequest 分析处理请求逻辑
+// 主要提供单日请求限额的功能
+func CheckRequest(rmsg *ReceiveMsg) bool {
+	if Config.MaxRequest == 0 {
+		return true
+	}
+	count := UserService.GetUseRequestCount(rmsg.SenderStaffId)
+	// 判断访问次数是否超过限制
+	if count >= Config.MaxRequest {
+		logger.Info(fmt.Sprintf("亲爱的: %s,您今日请求次数已达上限,请明天再来,交互发问资源有限,请务必斟酌您的问题,给您带来不便,敬请谅解!", rmsg.SenderNick))
+		_, err := rmsg.ReplyToDingtalk(string(TEXT), fmt.Sprintf("一个好的问题,胜过十个好的答案!\n亲爱的: %s,您今日请求次数已达上限,请明天再来,交互发问资源有限,请务必斟酌您的问题,给您带来不便,敬请谅解!", rmsg.SenderNick), rmsg.SenderStaffId)
+		if err != nil {
+			logger.Warning(fmt.Errorf("send message error: %v", err))
+		}
+		return false
+	}
+	// 访问次数未超过限制,将计数加1
+	UserService.SetUseRequestCount(rmsg.SenderStaffId, count+1)
+	return true
+}