From ac219a1531b4d987a18667aa36d33cd09f9b2885 Mon Sep 17 00:00:00 2001 From: vcidst Date: Mon, 4 Mar 2024 16:30:00 +0100 Subject: [PATCH] add test to tracker store --- tests/core/test_tracker_stores.py | 32 ++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/core/test_tracker_stores.py b/tests/core/test_tracker_stores.py index fb4a891b097e..39122e664731 100644 --- a/tests/core/test_tracker_stores.py +++ b/tests/core/test_tracker_stores.py @@ -614,6 +614,36 @@ async def test_sql_additional_events_with_session_start(domain: Domain): assert isinstance(additional_events[0], UserUttered) +async def test_tracker_store_retrieve_ordered_by_id( + domain: Domain, +): + tracker_store_kwargs = {"host": "sqlite:///"} + tracker_store = SQLTrackerStore(domain, **tracker_store_kwargs) + events = [ + SessionStarted(timestamp=1), + UserUttered("Hola", {"name": "greet"}, timestamp=2), + BotUttered("Hi", timestamp=2), + UserUttered("How are you?", {"name": "greet"}, timestamp=2), + BotUttered("I am good, whats up", timestamp=2), + UserUttered("Ciao", {"name": "greet"}, timestamp=2), + BotUttered("Bye", timestamp=2), + ] + sender_id = "test_sql_tracker_store_events_order" + tracker = DialogueStateTracker.from_events(sender_id, events) + await tracker_store.save(tracker) + + # Save other tracker to ensure that we don't run into problems with other senders + other_tracker = DialogueStateTracker.from_events("other-sender", [SessionStarted()]) + await tracker_store.save(other_tracker) + + # Retrieve tracker with events since latest SessionStarted + tracker = await tracker_store.retrieve(sender_id) + + assert len(tracker.events) == 7 + # assert the order of events is same as the order in which they were added + assert all((event == tracker.events[i] for i, event in enumerate(events))) + + @pytest.mark.parametrize( "tracker_store_type,tracker_store_kwargs", [(MockedMongoTrackerStore, {}), (SQLTrackerStore, {"host": "sqlite:///"})], @@ -641,7 +671,7 @@ async def test_tracker_store_retrieve_with_session_started_events( # Retrieve tracker with events since latest SessionStarted tracker = await tracker_store.retrieve(sender_id) - assert len(tracker.events) == 2 + assert len(tracker.events) == 3 assert all((event == tracker.events[i] for i, event in enumerate(events[2:])))