context_test.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package chatgpt
  2. import (
  3. "os"
  4. "testing"
  5. "time"
  6. "github.com/joho/godotenv"
  7. "github.com/stretchr/testify/assert"
  8. )
  9. func TestOfflineContext(t *testing.T) {
  10. key := os.Getenv("CHATGPT_API_KEY")
  11. if key == "" {
  12. t.Skip("CHATGPT_API_KEY is not set")
  13. }
  14. cli := New(key, "user1", time.Second*30)
  15. reply, err := cli.ChatWithContext("我叫老三,你是?")
  16. if err != nil {
  17. t.Fatal(err)
  18. }
  19. t.Logf("我叫老三,你是? => %s", reply)
  20. err = cli.ChatContext.SaveConversation("test.conversation")
  21. if err != nil {
  22. t.Fatalf("储存对话记录失败: %v", err)
  23. }
  24. cli.ChatContext.ResetConversation()
  25. reply, err = cli.ChatWithContext("你知道我是谁吗?")
  26. if err != nil {
  27. t.Fatal(err)
  28. }
  29. t.Logf("你知道我是谁吗? => %s", reply)
  30. assert.NotContains(t, reply, "老三")
  31. err = cli.ChatContext.LoadConversation("test.conversation")
  32. if err != nil {
  33. t.Fatalf("读取对话记录失败: %v", err)
  34. }
  35. reply, err = cli.ChatWithContext("你知道我是谁吗?")
  36. if err != nil {
  37. t.Fatal(err)
  38. }
  39. t.Logf("你知道我是谁吗? => %s", reply)
  40. // AI 理应知道他叫老三
  41. assert.Contains(t, reply, "老三")
  42. }
  43. func TestMaintainContext(t *testing.T) {
  44. key := os.Getenv("CHATGPT_API_KEY")
  45. if key == "" {
  46. t.Skip("CHATGPT_API_KEY is not set")
  47. }
  48. cli := New(key, "user1", time.Second*30)
  49. cli.ChatContext = NewContext(
  50. WithMaxSeqTimes(1),
  51. WithMaintainSeqTimes(true),
  52. )
  53. reply, err := cli.ChatWithContext("我叫老三,你是?")
  54. if err != nil {
  55. t.Fatal(err)
  56. }
  57. t.Logf("我叫老三,你是? => %s", reply)
  58. reply, err = cli.ChatWithContext("你知道我是谁吗?")
  59. if err != nil {
  60. t.Fatal(err)
  61. }
  62. t.Logf("你知道我是谁吗? => %s", reply)
  63. // 对话次数已经超过 1 次,因此最先前的对话已被移除,AI 理应不知道他叫老三
  64. assert.NotContains(t, reply, "老三")
  65. }
  66. func init() {
  67. // 本地加载适用于本地测试,如果要在github进行测试,可以透过传入 secrets 到环境参数
  68. _ = godotenv.Load(".env.local")
  69. }