diff --git a/tests/shared/utils/test_llm.py b/tests/shared/utils/test_llm.py index 44e00ca917c1..175408a16dd9 100644 --- a/tests/shared/utils/test_llm.py +++ b/tests/shared/utils/test_llm.py @@ -59,6 +59,49 @@ def test_tracker_as_readable_transcript_handles_tracker_with_events_and_max_turn assert tracker_as_readable_transcript(tracker, max_turns=1) == ("""AI: hi""") +def test_tracker_as_readable_transcript_and_discard_excess_turns_with_default_max_turns( + domain: Domain, +): + tracker = DialogueStateTracker(sender_id="test", slots=domain.slots) + tracker.update_with_events( + [ + UserUttered("A0"), + BotUttered("B1"), + UserUttered("C2"), + BotUttered("D3"), + UserUttered("E4"), + BotUttered("F5"), + UserUttered("G6"), + BotUttered("H7"), + UserUttered("I8"), + BotUttered("J9"), + UserUttered("K10"), + BotUttered("L11"), + UserUttered("M12"), + BotUttered("N13"), + UserUttered("O14"), + BotUttered("P15"), + UserUttered("Q16"), + BotUttered("R17"), + UserUttered("S18"), + BotUttered("T19"), + UserUttered("U20"), + BotUttered("V21"), + UserUttered("W22"), + BotUttered("X23"), + UserUttered("Y24"), + ], + domain, + ) + response = tracker_as_readable_transcript(tracker) + assert response == ( + """AI: F5\nUSER: G6\nAI: H7\nUSER: I8\nAI: J9\nUSER: K10\nAI: L11\n""" + """USER: M12\nAI: N13\nUSER: O14\nAI: P15\nUSER: Q16\nAI: R17\nUSER: S18\n""" + """AI: T19\nUSER: U20\nAI: V21\nUSER: W22\nAI: X23\nUSER: Y24""" + ) + assert response.count("\n") == 19 + + def test_sanitize_message_for_prompt_handles_none(): assert sanitize_message_for_prompt(None) == ""