Skip to content

Commit

Permalink
add test to tracker store
Browse files Browse the repository at this point in the history
  • Loading branch information
vcidst committed Mar 4, 2024
1 parent 76c64fe commit ac219a1
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion tests/core/test_tracker_stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,36 @@ async def test_sql_additional_events_with_session_start(domain: Domain):
assert isinstance(additional_events[0], UserUttered)


async def test_tracker_store_retrieve_ordered_by_id(
domain: Domain,
):
tracker_store_kwargs = {"host": "sqlite:///"}
tracker_store = SQLTrackerStore(domain, **tracker_store_kwargs)
events = [
SessionStarted(timestamp=1),
UserUttered("Hola", {"name": "greet"}, timestamp=2),
BotUttered("Hi", timestamp=2),
UserUttered("How are you?", {"name": "greet"}, timestamp=2),
BotUttered("I am good, whats up", timestamp=2),
UserUttered("Ciao", {"name": "greet"}, timestamp=2),
BotUttered("Bye", timestamp=2),
]
sender_id = "test_sql_tracker_store_events_order"
tracker = DialogueStateTracker.from_events(sender_id, events)
await tracker_store.save(tracker)

# Save other tracker to ensure that we don't run into problems with other senders
other_tracker = DialogueStateTracker.from_events("other-sender", [SessionStarted()])
await tracker_store.save(other_tracker)

# Retrieve tracker with events since latest SessionStarted
tracker = await tracker_store.retrieve(sender_id)

assert len(tracker.events) == 7
# assert the order of events is same as the order in which they were added
assert all((event == tracker.events[i] for i, event in enumerate(events)))


@pytest.mark.parametrize(
"tracker_store_type,tracker_store_kwargs",
[(MockedMongoTrackerStore, {}), (SQLTrackerStore, {"host": "sqlite:///"})],
Expand Down Expand Up @@ -641,7 +671,7 @@ async def test_tracker_store_retrieve_with_session_started_events(
# Retrieve tracker with events since latest SessionStarted
tracker = await tracker_store.retrieve(sender_id)

assert len(tracker.events) == 2
assert len(tracker.events) == 3
assert all((event == tracker.events[i] for i, event in enumerate(events[2:])))


Expand Down

0 comments on commit ac219a1

Please sign in to comment.