Browse Source

feat: 支持上传图片到钉钉平台,在图片生成流程中使用钉钉的图片 CDN 能力 (#225)

金喜@DingTalk 1 year ago
parent
commit
2eda9e8e22

+ 10 - 0
README.md

@@ -225,6 +225,7 @@ $ docker run -itd --name chatgpt -p 8090:8090 \
   -e SENSITIVE_WORDS="aa,bb" \
   -e AZURE_ON="false" -e AZURE_API_VERSION="" -e AZURE_RESOURCE_NAME="" \
   -e AZURE_DEPLOYMENT_NAME="" -e AZURE_OPENAI_TOKEN="" \
+  -e DINGTALK_CREDENTIALS="your_client_id1:secret1,your_client_id2:secret2" \
   -e HELP="欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/)
   ,觉得不错你可以来波素质三连."  \
   --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest
@@ -541,6 +542,15 @@ azure_resource_name: "xxxx"
 azure_deployment_name: "xxxx"
 azure_openai_token: "xxxx"
 
+# 钉钉应用鉴权凭据信息,支持配置多个应用。收到请求时会通过签名校验来识别消息来自哪个机器人应用
+# 设置 credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力,例如:上传图片到钉钉平台以提升图片体验,结合 Stream 模式简化服务部署
+# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
+# 建议采用 credentials 代替 app_secrets 配置项,以获得钉钉 OpenAPI 访问能力
+credentials:
+  -
+    client_id: "put-your-client-id-here"
+    client_secret: "put-your-client-secret-here"
+
 ```
 
 ## 常见问题

+ 7 - 0
config.example.yml

@@ -59,3 +59,10 @@ azure_resource_name: "xxxx"
 azure_deployment_name: "xxxx"
 azure_openai_token: "xxxx"
 
+# 钉钉应用鉴权凭据信息,支持配置多个应用。收到请求时会通过签名校验来识别消息来自哪个机器人应用
+# 设置 credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力,例如:上传图片到钉钉平台以提升图片体验,结合 Stream 模式简化服务部署
+# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
+#credentials:
+#  -
+#    client_id: "put-your-client-id-here"
+#    client_secret: "put-your-client-secret-here"

+ 19 - 0
config/config.go

@@ -14,6 +14,11 @@ import (
 	"gopkg.in/yaml.v2"
 )
 
+// Credential is one DingTalk application credential pair. It is populated
+// either from an entry of the `credentials` list in the YAML config or from
+// one "id:secret" item of the DINGTALK_CREDENTIALS environment variable.
+type Credential struct {
+	// ClientID is read from the `client_id` YAML key.
+	ClientID     string `yaml:"client_id"`
+	// ClientSecret is read from the `client_secret` YAML key.
+	ClientSecret string `yaml:"client_secret"`
+}
+
 // Configuration 项目配置
 type Configuration struct {
 	// 日志级别,info或者debug
@@ -62,6 +67,8 @@ type Configuration struct {
 	AzureResourceName   string `yaml:"azure_resource_name"`
 	AzureDeploymentName string `yaml:"azure_deployment_name"`
 	AzureOpenAIToken    string `yaml:"azure_openai_token"`
+	// 钉钉应用鉴权凭据
+	Credentials []Credential `yaml:"credentials"`
 }
 
 var config *Configuration
@@ -190,6 +197,18 @@ func LoadConfig() *Configuration {
 		if azureOpenaiToken != "" {
 			config.AzureOpenAIToken = azureOpenaiToken
 		}
+		credentials := os.Getenv("DINGTALK_CREDENTIALS")
+		if credentials != "" {
+			if config.Credentials == nil {
+				config.Credentials = []Credential{}
+			}
+			for _, idSecret := range strings.Split(credentials, ",") {
+				items := strings.SplitN(idSecret, ":", 2)
+				if len(items) == 2 {
+					config.Credentials = append(config.Credentials, Credential{ClientID: items[0], ClientSecret: items[1]})
+				}
+			}
+		}
 
 	})
 

+ 1 - 0
docker-compose.yml

@@ -37,6 +37,7 @@ services:
       AZURE_RESOURCE_NAME: "" # Azure OpenAi API 资源名称,比如 "openai"
       AZURE_DEPLOYMENT_NAME: "" # Azure OpenAi API 部署名称,比如 "openai"
       AZURE_OPENAI_TOKEN: "" # Azure token
+      DINGTALK_CREDENTIALS: "" # 钉钉应用访问凭证,比如 "client_id1:secret1,client_id2:secret2"
       HELP: "欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/),觉得不错你可以来波素质三连." # 帮助信息,放在配置文件,可供自定义
     volumes:
       - ./data:/app/data

+ 9 - 1
main.go

@@ -33,6 +33,14 @@ func Start() {
 			return
 		}
 		// 先校验回调是否合法
+		clientId, checkOk := public.CheckRequestWithCredentials(c.GetHeader("timestamp"), c.GetHeader("sign"))
+		if !checkOk {
+			logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
+			return
+		}
+		// 通过 context 传递 OAuth ClientID,用于后续流程中调用钉钉OpenAPI
+		c.Set(public.DingTalkClientIdKeyName, clientId)
+		// 为了兼容存量老用户,暂时保留 public.CheckRequest 方法,将来升级到 Stream 模式后,建议去除该方法,采用上面的 CheckRequestWithCredentials
 		if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" {
 			logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
 			return
@@ -114,7 +122,7 @@ func Start() {
 			// 除去帮助之外的逻辑分流在这里处理
 			switch {
 			case strings.HasPrefix(msgObj.Text.Content, "#图片"):
-				err := process.ImageGenerate(&msgObj)
+				err := process.ImageGenerate(c, &msgObj)
 				if err != nil {
 					logger.Warning(fmt.Errorf("process request: %v", err))
 					return

+ 17 - 3
pkg/chatgpt/context.go

@@ -2,8 +2,12 @@ package chatgpt
 
 import (
 	"bytes"
+	"context"
 	"encoding/base64"
 	"encoding/gob"
+	"errors"
+	"fmt"
+	"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
 	"github.com/pandodao/tokenizer-go"
 	"image/png"
 	"os"
@@ -218,7 +222,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
 		return resp.Choices[0].Text, nil
 	}
 }
-func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
+func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) {
 	model := public.Config.Model
 	if model == openai.GPT3Dot5Turbo0301 ||
 		model == openai.GPT3Dot5Turbo ||
@@ -247,6 +251,13 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
 		}
 
 		imageName := time.Now().Format("20060102-150405") + ".png"
+		clientId, _ := ctx.Value(public.DingTalkClientIdKeyName).(string)
+		client := public.DingTalkClientManager.GetClientByOAuthClientID(clientId)
+		mediaResult, uploadErr := &dingbot.MediaUploadResult{}, errors.New(fmt.Sprintf("unknown clientId: %s", clientId))
+		if client != nil {
+			mediaResult, uploadErr = client.UploadMedia(imgBytes, imageName, dingbot.MediaTypeImage, dingbot.MimeTypeImagePng)
+		}
+
 		err = os.MkdirAll("data/images", 0755)
 		if err != nil {
 			return "", err
@@ -260,8 +271,11 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
 		if err := png.Encode(file, imgData); err != nil {
 			return "", err
 		}
-
-		return public.Config.ServiceURL + "/images/" + imageName, nil
+		if uploadErr == nil {
+			return mediaResult.MediaID, nil
+		} else {
+			return public.Config.ServiceURL + "/images/" + imageName, nil
+		}
 	}
 	return "", nil
 }

+ 3 - 2
pkg/chatgpt/export.go

@@ -1,6 +1,7 @@
 package chatgpt
 
 import (
+	"context"
 	"time"
 
 	"github.com/avast/retry-go"
@@ -58,7 +59,7 @@ func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error
 }
 
 // ImageQa 生成图片
-func ImageQa(question, userId string) (answer string, err error) {
+func ImageQa(ctx context.Context, question, userId string) (answer string, err error) {
 	chat := New(userId)
 	defer chat.Close()
 	// 定义一个重试策略
@@ -70,7 +71,7 @@ func ImageQa(question, userId string) (answer string, err error) {
 	// 使用重试策略进行重试
 	err = retry.Do(
 		func() error {
-			answer, err = chat.GenreateImage(question)
+			answer, err = chat.GenreateImage(ctx, question)
 			if err != nil {
 				return err
 			}

+ 213 - 0
pkg/dingbot/client.go

@@ -0,0 +1,213 @@
+package dingbot
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"github.com/eryajf/chatgpt-dingtalk/config"
+	"io"
+	"mime/multipart"
+	"net/http"
+	url2 "net/url"
+	"sync"
+	"time"
+)
+
+// Media type values; UploadMedia sends the chosen one as the "type" form field.
+// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
+const (
+	MediaTypeImage string = "image"
+	MediaTypeVoice string = "voice"
+	MediaTypeVideo string = "video"
+	MediaTypeFile  string = "file"
+)
+// MIME type values intended for the mimeType argument of UploadMedia.
+const (
+	MimeTypeImagePng string = "image/png"
+)
+
+// MediaUploadResult is the JSON response body of the media upload endpoint as
+// decoded by UploadMedia. A non-zero ErrorCode is treated as a failure there.
+type MediaUploadResult struct {
+	ErrorCode    int64  `json:"errcode"`
+	ErrorMessage string `json:"errmsg"`
+	// MediaID identifies the uploaded media; callers use it in place of an image URL.
+	MediaID      string `json:"media_id"`
+	CreatedAt    int64  `json:"created_at"`
+	Type         string `json:"type"`
+}
+
+// OAuthTokenResult is the JSON response body of the gettoken endpoint as
+// decoded by getAccessTokenFromDingTalk. A non-zero ErrorCode is treated as a
+// failure there.
+type OAuthTokenResult struct {
+	ErrorCode    int    `json:"errcode"`
+	ErrorMessage string `json:"errmsg"`
+	AccessToken  string `json:"access_token"`
+	// ExpiresIn is added to the current Unix time (seconds) to compute the
+	// cached token's expiry in GetAccessToken.
+	ExpiresIn    int    `json:"expires_in"`
+}
+
+// DingTalkClientInterface is the set of DingTalk OpenAPI operations exposed to
+// the rest of the program; *DingTalkClient implements it.
+type DingTalkClientInterface interface {
+	// GetAccessToken returns an access token for the client's credential.
+	GetAccessToken() (string, error)
+	// UploadMedia uploads content under the given filename and media type and
+	// returns the decoded upload response.
+	UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error)
+}
+
+// DingTalkClientManagerInterface looks up a DingTalk client by OAuth client ID;
+// *DingTalkClientManager implements it.
+type DingTalkClientManagerInterface interface {
+	// GetClientByOAuthClientID returns the client registered for clientId, or
+	// nil when there is none.
+	GetClientByOAuthClientID(clientId string) DingTalkClientInterface
+}
+
+// DingTalkClient calls DingTalk OpenAPI on behalf of a single credential and
+// caches the access token between calls.
+type DingTalkClient struct {
+	Credential  config.Credential
+	// AccessToken is the most recently fetched token ("" until the first fetch).
+	AccessToken string
+	// expireAt is the Unix time in seconds at which AccessToken expires.
+	expireAt    int64
+	// mutex guards AccessToken and expireAt.
+	mutex       sync.Mutex
+}
+
+// DingTalkClientManager owns one DingTalkClient per configured credential.
+type DingTalkClientManager struct {
+	Credentials []config.Credential
+	// Clients maps a credential's ClientID to its client.
+	Clients     map[string]*DingTalkClient
+	// mutex is held while Clients is read in GetClientByOAuthClientID.
+	mutex       sync.Mutex
+}
+
+// NewDingTalkClient returns a client bound to credential. No access token is
+// fetched here; the token cache starts empty.
+func NewDingTalkClient(credential config.Credential) *DingTalkClient {
+	client := new(DingTalkClient)
+	client.Credential = credential
+	return client
+}
+
+// NewDingTalkClientManager builds a manager holding one client per entry of
+// conf.Credentials, keyed by ClientID. When two entries share a ClientID the
+// later one wins.
+//
+// A nil conf is accepted and yields a manager with no clients, so callers can
+// always perform lookups on the result.
+func NewDingTalkClientManager(conf *config.Configuration) *DingTalkClientManager {
+	manager := &DingTalkClientManager{
+		Clients: make(map[string]*DingTalkClient),
+	}
+	if conf == nil {
+		// Nothing to register; do not dereference conf.
+		return manager
+	}
+	manager.Credentials = conf.Credentials
+	// Ranging over a nil slice is a no-op, so no separate nil check is needed.
+	for _, credential := range conf.Credentials {
+		manager.Clients[credential.ClientID] = NewDingTalkClient(credential)
+	}
+	return manager
+}
+
+// GetClientByOAuthClientID looks up the client registered under clientId while
+// holding the manager's mutex. It returns nil when no such client exists.
+func (m *DingTalkClientManager) GetClientByOAuthClientID(clientId string) DingTalkClientInterface {
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+
+	found, exists := m.Clients[clientId]
+	if !exists {
+		// Return a literal nil: returning the zero *DingTalkClient through the
+		// interface would make the result compare as non-nil to callers.
+		return nil
+	}
+	return found
+}
+
+// GetAccessToken returns the cached access token while it still has more than
+// 60 seconds of validity left; otherwise it requests a new token through
+// getAccessTokenFromDingTalk, stores it with its expiry, and returns it.
+//
+// The mutex only protects the cached fields; it is released while the remote
+// request is in flight.
+func (c *DingTalkClient) GetAccessToken() (string, error) {
+	// Cache lookup. The 60-second margin avoids handing out a token that
+	// expires right as the caller uses it.
+	c.mutex.Lock()
+	cached := ""
+	if c.expireAt > 0 && c.AccessToken != "" && time.Now().Unix()+60 < c.expireAt {
+		cached = c.AccessToken
+	}
+	c.mutex.Unlock()
+
+	if cached != "" {
+		return cached, nil
+	}
+
+	fresh, err := c.getAccessTokenFromDingTalk()
+	if err != nil {
+		return "", err
+	}
+
+	// Refresh the cache with the newly issued token.
+	c.mutex.Lock()
+	c.AccessToken = fresh.AccessToken
+	c.expireAt = time.Now().Unix() + int64(fresh.ExpiresIn)
+	c.mutex.Unlock()
+
+	return fresh.AccessToken, nil
+}
+
+// UploadMedia uploads content to DingTalk as a multipart form (file field
+// "media", plus a "type" field carrying mediaType) and returns the decoded
+// response.
+//
+// An error is returned when no access token can be obtained, when the form
+// cannot be built, when the HTTP request fails, when the response body cannot
+// be read or decoded as JSON, or when DingTalk reports a non-zero errcode (the
+// error text is then DingTalk's errmsg).
+//
+// NOTE(review): mimeType is currently not sent; the file part uses the default
+// content type chosen by multipart.Writer.CreateFormFile.
+func (c *DingTalkClient) UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error) {
+	// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
+	accessToken, err := c.GetAccessToken()
+	if err != nil {
+		return nil, err
+	}
+	if len(accessToken) == 0 {
+		return nil, errors.New("empty access token")
+	}
+
+	// Build the multipart body; every step is checked so a truncated form is
+	// never sent.
+	body := &bytes.Buffer{}
+	writer := multipart.NewWriter(body)
+	part, err := writer.CreateFormFile("media", filename)
+	if err != nil {
+		return nil, err
+	}
+	if _, err = part.Write(content); err != nil {
+		return nil, err
+	}
+	if err = writer.WriteField("type", mediaType); err != nil {
+		return nil, err
+	}
+	// Close writes the trailing boundary and must happen before sending.
+	if err = writer.Close(); err != nil {
+		return nil, err
+	}
+
+	// Create a new HTTP request to upload the media file
+	url := fmt.Sprintf("https://oapi.dingtalk.com/media/upload?access_token=%s", url2.QueryEscape(accessToken))
+	req, err := http.NewRequest("POST", url, body)
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", writer.FormDataContentType())
+
+	// Send the HTTP request and parse the response
+	client := &http.Client{
+		Timeout: time.Second * 60,
+	}
+	res, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer res.Body.Close()
+
+	// Read the body first, then decode it; both errors are reported.
+	bodyBytes, err := io.ReadAll(res.Body)
+	if err != nil {
+		return nil, err
+	}
+	media := &MediaUploadResult{}
+	if err = json.Unmarshal(bodyBytes, media); err != nil {
+		return nil, fmt.Errorf("decode media upload response: %w", err)
+	}
+	if media.ErrorCode != 0 {
+		return nil, errors.New(media.ErrorMessage)
+	}
+	return media, nil
+}
+
+// getAccessTokenFromDingTalk requests a new access token from the gettoken
+// endpoint using the client's ClientID/ClientSecret as appkey/appsecret.
+// It returns the decoded response, or an error when the request fails, the
+// body cannot be read or decoded, or DingTalk reports a non-zero errcode.
+func (c *DingTalkClient) getAccessTokenFromDingTalk() (*OAuthTokenResult, error) {
+	// OpenAPI doc: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token
+	query := url2.Values{
+		"appkey":    {c.Credential.ClientID},
+		"appsecret": {c.Credential.ClientSecret},
+	}
+	endpoint := "https://oapi.dingtalk.com/gettoken" + "?" + query.Encode()
+
+	req, err := http.NewRequest("GET", endpoint, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	httpClient := http.Client{Timeout: 60 * time.Second}
+	resp, err := httpClient.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	raw, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+
+	result := new(OAuthTokenResult)
+	if err = json.Unmarshal(raw, result); err != nil {
+		return nil, err
+	}
+	if result.ErrorCode != 0 {
+		// Surface DingTalk's own message as the error text.
+		return nil, errors.New(result.ErrorMessage)
+	}
+	return result, nil
+}

+ 53 - 0
pkg/dingbot/client_test.go

@@ -0,0 +1,53 @@
+package dingbot
+
+import (
+	"bytes"
+	"github.com/eryajf/chatgpt-dingtalk/config"
+	"image"
+	"image/color"
+	"image/png"
+	"os"
+	"testing"
+)
+
+func TestUploadMedia_Pass_WithValidConfig(t *testing.T) {
+	// 设置了钉钉 ClientID 和 ClientSecret 的环境变量才执行以下测试,用于快速验证钉钉图片上传能力
+	clientId, clientSecret := os.Getenv("DINGTALK_CLIENT_ID_FOR_TEST"), os.Getenv("DINGTALK_CLIENT_SECRET_FOR_TEST")
+	if len(clientId) <= 0 || len(clientSecret) <= 0 {
+		return
+	}
+	credentials := []config.Credential{
+		config.Credential{
+			ClientID:     clientId,
+			ClientSecret: clientSecret,
+		},
+	}
+	client := NewDingTalkClientManager(&config.Configuration{Credentials: credentials}).GetClientByOAuthClientID(clientId)
+	var imageContent []byte
+	{
+		// 生成一张用于测试的图片
+		img := image.NewRGBA(image.Rect(0, 0, 200, 100))
+		blue := color.RGBA{0, 0, 255, 255}
+		for x := 0; x < img.Bounds().Dx(); x++ {
+			for y := 0; y < img.Bounds().Dy(); y++ {
+				img.Set(x, y, blue)
+			}
+		}
+		buf := new(bytes.Buffer)
+		err := png.Encode(buf, img)
+		if err != nil {
+			return
+		}
+		// get the byte array from the buffer
+		imageContent = buf.Bytes()
+	}
+	result, err := client.UploadMedia(imageContent, "filename.png", "image", "image/png")
+	if err != nil {
+		t.Errorf("upload media failed, err=%s", err.Error())
+		return
+	}
+	if result.MediaID == "" {
+		t.Errorf("upload media failed, empty media id")
+		return
+	}
+}

+ 3 - 2
pkg/process/image.go

@@ -1,6 +1,7 @@
 package process
 
 import (
+	"context"
 	"fmt"
 	"github.com/eryajf/chatgpt-dingtalk/public"
 	"strings"
@@ -12,7 +13,7 @@ import (
 )
 
 // ImageGenerate openai生成图片
-func ImageGenerate(rmsg *dingbot.ReceiveMsg) error {
+func ImageGenerate(ctx context.Context, rmsg *dingbot.ReceiveMsg) error {
 	if public.Config.AzureOn {
 		_, err := rmsg.ReplyToDingtalk(string(dingbot.
 			MARKDOWN), "azure 模式下暂不支持图片创作功能")
@@ -32,7 +33,7 @@ func ImageGenerate(rmsg *dingbot.ReceiveMsg) error {
 	if err != nil {
 		logger.Error("往MySQL新增数据失败,错误信息:", err)
 	}
-	reply, err := chatgpt.ImageQa(rmsg.Text.Content, rmsg.GetSenderIdentifier())
+	reply, err := chatgpt.ImageQa(ctx, rmsg.Text.Content, rmsg.GetSenderIdentifier())
 	if err != nil {
 		logger.Info(fmt.Errorf("gpt request error: %v", err))
 		_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err))

+ 6 - 0
public/public.go

@@ -4,12 +4,16 @@ import (
 	"github.com/eryajf/chatgpt-dingtalk/config"
 	"github.com/eryajf/chatgpt-dingtalk/pkg/cache"
 	"github.com/eryajf/chatgpt-dingtalk/pkg/db"
+	"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
 	"github.com/sashabaranov/go-openai"
 )
 
 var UserService cache.UserServiceInterface
 var Config *config.Configuration
 var Prompt *[]config.Prompt
+var DingTalkClientManager dingbot.DingTalkClientManagerInterface
+
+const DingTalkClientIdKeyName = "DingTalkClientId"
 
 func InitSvc() {
 	// 加载配置
@@ -18,6 +22,8 @@ func InitSvc() {
 	Prompt = config.LoadPrompt()
 	// 初始化缓存
 	UserService = cache.NewUserService()
+	// 初始化钉钉开放平台的客户端,用于访问上传图片等能力
+	DingTalkClientManager = dingbot.NewDingTalkClientManager(Config)
 	// 初始化数据库
 	db.InitDB()
 	// 暂时不在初始化时获取余额

+ 17 - 0
public/tools.go

@@ -124,6 +124,23 @@ func GetReadTime(t time.Time) string {
 	return t.Format("2006-01-02 15:04:05")
 }
 
+// CheckRequestWithCredentials verifies a callback signature against every
+// configured credential and reports which application signed the request.
+//
+// ts and sg are the request's "timestamp" and "sign" header values. For each
+// credential the expected signature is
+// base64(HMAC-SHA256(key=ClientSecret, msg=ts+"\n"+ClientSecret)).
+//
+// Returns:
+//   - ("", true) when no credentials are configured (verification disabled);
+//   - (ClientID, true) for the first credential whose signature equals sg;
+//   - ("", false) when credentials are configured but none matches.
+func CheckRequestWithCredentials(ts, sg string) (clientId string, pass bool) {
+	credentials := Config.Credentials
+	// len of a nil slice is 0, so this covers both nil and empty.
+	if len(credentials) == 0 {
+		return "", true
+	}
+	for _, credential := range credentials {
+		stringToSign := fmt.Sprintf("%s\n%s", ts, credential.ClientSecret)
+		mac := hmac.New(sha256.New, []byte(credential.ClientSecret))
+		// hash.Hash.Write is documented to never return an error.
+		_, _ = mac.Write([]byte(stringToSign))
+		expected := base64.StdEncoding.EncodeToString(mac.Sum(nil))
+		// sg comes from an untrusted request header: compare in constant time
+		// so the check does not leak how many leading bytes matched.
+		if hmac.Equal([]byte(expected), []byte(sg)) {
+			return credential.ClientID, true
+		}
+	}
+	return "", false
+}
+
 func CheckRequest(ts, sg string) bool {
 	appSecrets := Config.AppSecrets
 	// 如果没有指定或者outgoing类型机器人下使用,则默认不做校验

+ 76 - 0
public/tools_test.go

@@ -0,0 +1,76 @@
+package public
+
+import (
+	"github.com/eryajf/chatgpt-dingtalk/config"
+	"testing"
+)
+
+// With a nil credential list, verification is disabled: any signature passes
+// and no client ID is reported.
+func TestCheckRequestWithCredentials_Pass_WithNilConfig(t *testing.T) {
+	Config = &config.Configuration{Credentials: nil}
+
+	id, ok := CheckRequestWithCredentials("ts", "sg")
+	switch {
+	case !ok:
+		t.Fatal("pass should be true, but false")
+	case id != "":
+		t.Fatal("client id should be empty")
+	}
+}
+
+// With an empty (non-nil) credential list, verification is disabled: any
+// signature passes and no client ID is reported.
+func TestCheckRequestWithCredentials_Pass_WithEmptyConfig(t *testing.T) {
+	Config = &config.Configuration{Credentials: []config.Credential{}}
+
+	id, ok := CheckRequestWithCredentials("ts", "sg")
+	switch {
+	case !ok:
+		t.Fatal("pass should be true, but false")
+	case id != "":
+		t.Fatal("client id should be empty")
+	}
+}
+
+func TestCheckRequestWithCredentials_Pass_WithValidConfig(t *testing.T) {
+	Config = &config.Configuration{
+		Credentials: []config.Credential{
+			config.Credential{
+				ClientID:     "client-id-for-test",
+				ClientSecret: "client-secret-for-test",
+			},
+		},
+	}
+	clientId, pass := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=")
+	if !pass {
+		t.Errorf("pass should be true, but false")
+		return
+	}
+	if clientId != "client-id-for-test" {
+		t.Errorf("client id should be \"%s\", but \"%s\"", "client-id-for-test", clientId)
+		return
+	}
+}
+
+// A signature that matches none of the configured secrets is rejected and no
+// client ID is reported.
+func TestCheckRequestWithCredentials_Failed_WithInvalidConfig(t *testing.T) {
+	Config = &config.Configuration{
+		Credentials: []config.Credential{
+			{ClientID: "client-id-for-test", ClientSecret: "invalid-client-secret-for-test"},
+		},
+	}
+
+	gotID, ok := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=")
+	if ok {
+		t.Fatal("pass should be false, but true")
+	}
+	if gotID != "" {
+		t.Fatal("client id should be empty")
+	}
+}