context_test.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package chatgpt
  2. import (
  3. "os"
  4. "testing"
  5. "time"
  6. "github.com/joho/godotenv"
  7. )
  8. func TestOfflineContext(t *testing.T) {
  9. key := os.Getenv("CHATGPT_API_KEY")
  10. if key == "" {
  11. t.Skip("CHATGPT_API_KEY is not set")
  12. }
  13. cli := New(key, "", "user1", time.Second*30)
  14. reply, err := cli.ChatWithContext("我叫老三,你是?")
  15. if err != nil {
  16. t.Fatal(err)
  17. }
  18. t.Logf("我叫老三,你是? => %s", reply)
  19. err = cli.ChatContext.SaveConversation("test.conversation")
  20. if err != nil {
  21. t.Fatalf("储存对话记录失败: %v", err)
  22. }
  23. cli.ChatContext.ResetConversation("")
  24. reply, err = cli.ChatWithContext("你知道我是谁吗?")
  25. if err != nil {
  26. t.Fatal(err)
  27. }
  28. t.Logf("你知道我是谁吗? => %s", reply)
  29. // assert.NotContains(t, reply, "老三")
  30. err = cli.ChatContext.LoadConversation("test.conversation")
  31. if err != nil {
  32. t.Fatalf("读取对话记录失败: %v", err)
  33. }
  34. reply, err = cli.ChatWithContext("你知道我是谁吗?")
  35. if err != nil {
  36. t.Fatal(err)
  37. }
  38. t.Logf("你知道我是谁吗? => %s", reply)
  39. // AI 理应知道他叫老三
  40. // assert.Contains(t, reply, "老三")
  41. }
  42. func TestMaintainContext(t *testing.T) {
  43. key := os.Getenv("CHATGPT_API_KEY")
  44. if key == "" {
  45. t.Skip("CHATGPT_API_KEY is not set")
  46. }
  47. cli := New(key, "", "user1", time.Second*30)
  48. cli.ChatContext = NewContext(
  49. WithMaxSeqTimes(1),
  50. WithMaintainSeqTimes(true),
  51. )
  52. reply, err := cli.ChatWithContext("我叫老三,你是?")
  53. if err != nil {
  54. t.Fatal(err)
  55. }
  56. t.Logf("我叫老三,你是? => %s", reply)
  57. reply, err = cli.ChatWithContext("你知道我是谁吗?")
  58. if err != nil {
  59. t.Fatal(err)
  60. }
  61. t.Logf("你知道我是谁吗? => %s", reply)
  62. // 对话次数已经超过 1 次,因此最先前的对话已被移除,AI 理应不知道他叫老三
  63. // assert.NotContains(t, reply, "老三")
  64. }
  65. func init() {
  66. // 本地加载适用于本地测试,如果要在github进行测试,可以透过传入 secrets 到环境参数
  67. _ = godotenv.Load(".env.local")
  68. }