context_test.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package chatgpt
  2. import (
  3. "os"
  4. "testing"
  5. )
  6. func TestOfflineContext(t *testing.T) {
  7. key := os.Getenv("CHATGPT_API_KEY")
  8. if key == "" {
  9. t.Skip("CHATGPT_API_KEY is not set")
  10. }
  11. cli := New("")
  12. reply, err := cli.ChatWithContext("我叫老三,你是?")
  13. if err != nil {
  14. t.Fatal(err)
  15. }
  16. t.Logf("我叫老三,你是? => %s", reply)
  17. err = cli.ChatContext.SaveConversation("test.conversation")
  18. if err != nil {
  19. t.Fatalf("储存对话记录失败: %v", err)
  20. }
  21. cli.ChatContext.ResetConversation("")
  22. reply, err = cli.ChatWithContext("你知道我是谁吗?")
  23. if err != nil {
  24. t.Fatal(err)
  25. }
  26. t.Logf("你知道我是谁吗? => %s", reply)
  27. // assert.NotContains(t, reply, "老三")
  28. err = cli.ChatContext.LoadConversation("test.conversation")
  29. if err != nil {
  30. t.Fatalf("读取对话记录失败: %v", err)
  31. }
  32. reply, err = cli.ChatWithContext("你知道我是谁吗?")
  33. if err != nil {
  34. t.Fatal(err)
  35. }
  36. t.Logf("你知道我是谁吗? => %s", reply)
  37. // AI 理应知道他叫老三
  38. // assert.Contains(t, reply, "老三")
  39. }
  40. func TestMaintainContext(t *testing.T) {
  41. key := os.Getenv("CHATGPT_API_KEY")
  42. if key == "" {
  43. t.Skip("CHATGPT_API_KEY is not set")
  44. }
  45. cli := New("")
  46. cli.ChatContext = NewContext(
  47. WithMaxSeqTimes(1),
  48. WithMaintainSeqTimes(true),
  49. )
  50. reply, err := cli.ChatWithContext("我叫老三,你是?")
  51. if err != nil {
  52. t.Fatal(err)
  53. }
  54. t.Logf("我叫老三,你是? => %s", reply)
  55. reply, err = cli.ChatWithContext("你知道我是谁吗?")
  56. if err != nil {
  57. t.Fatal(err)
  58. }
  59. t.Logf("你知道我是谁吗? => %s", reply)
  60. // 对话次数已经超过 1 次,因此最先前的对话已被移除,AI 理应不知道他叫老三
  61. // assert.NotContains(t, reply, "老三")
  62. }
  63. func init() {
  64. // 本地加载适用于本地测试,如果要在github进行测试,可以透过传入 secrets 到环境参数
  65. // _ = godotenv.Load(".env.local")
  66. }