diff --git a/rasa/dialogue_understanding/commands/start_flow_command.py b/rasa/dialogue_understanding/commands/start_flow_command.py index 11ab06bb162b..3604d46afff5 100644 --- a/rasa/dialogue_understanding/commands/start_flow_command.py +++ b/rasa/dialogue_understanding/commands/start_flow_command.py @@ -70,7 +70,7 @@ def run_command_on_tracker( "command_executor.skip_command.already_started_flow", command=self ) return [] - elif self.flow not in all_flows.non_pattern_flows(): + elif self.flow not in all_flows.user_flow_ids: structlogger.debug( "command_executor.skip_command.start_invalid_flow_id", command=self ) diff --git a/rasa/dialogue_understanding/generator/command_prompt_template.jinja2 b/rasa/dialogue_understanding/generator/command_prompt_template.jinja2 index bb46fdcdf663..039ef52dc360 100644 --- a/rasa/dialogue_understanding/generator/command_prompt_template.jinja2 +++ b/rasa/dialogue_understanding/generator/command_prompt_template.jinja2 @@ -14,8 +14,8 @@ Here is what happened previously in the conversation: === {% if current_flow != None %} -You are currently in the flow "{{ current_flow }}", which {{ current_flow.description }} -You have just asked the user for the slot "{{ collect }}"{% if collect_description %} ({{ collect_description }}){% endif %}. +You are currently in the flow "{{ current_flow }}". +You have just asked the user for the slot "{{ current_slot }}"{% if current_slot_description %} ({{ current_slot_description }}){% endif %}. {% if flow_slots|length > 0 %} Here are the slots of the currently active flow: diff --git a/rasa/dialogue_understanding/generator/llm_command_generator.py b/rasa/dialogue_understanding/generator/llm_command_generator.py index 27c2068f3de5..aad8e01e90d0 100644 --- a/rasa/dialogue_understanding/generator/llm_command_generator.py +++ b/rasa/dialogue_understanding/generator/llm_command_generator.py @@ -1,9 +1,10 @@ import importlib.resources import re -from typing import Dict, Any, Optional, List, Union +from typing import Dict, Any, List, Optional, Tuple, Union from jinja2 import Template import structlog + from rasa.dialogue_understanding.stack.utils import top_flow_frame from rasa.dialogue_understanding.generator import CommandGenerator from rasa.dialogue_understanding.commands import ( @@ -22,7 +23,12 @@ from rasa.engine.recipes.default_recipe import DefaultV1Recipe from rasa.engine.storage.resource import Resource from rasa.engine.storage.storage import ModelStorage -from rasa.shared.core.flows.flow import FlowStep, FlowsList, CollectInformationFlowStep +from rasa.shared.core.flows.flow import ( + Flow, + FlowStep, + FlowsList, + CollectInformationFlowStep, +) from rasa.shared.core.trackers import DialogueStateTracker from rasa.shared.core.slots import ( BooleanSlot, @@ -48,9 +54,6 @@ "rasa.dialogue_understanding.generator", "command_prompt_template.jinja2" ) -structlogger = structlog.get_logger() - - DEFAULT_LLM_CONFIG = { "_type": "openai", "request_timeout": 7, @@ -60,6 +63,8 @@ LLM_CONFIG_KEY = "llm" +structlogger = structlog.get_logger() + @DefaultV1Recipe.register( [ @@ -68,6 +73,8 @@ is_trainable=True, ) class LLMCommandGenerator(GraphComponent, CommandGenerator): + """An LLM-based command generator.""" + @staticmethod def get_default_config() -> Dict[str, Any]: """The component's default config (see parent class for full docstring).""" @@ -98,9 +105,6 @@ def create( """Creates a new untrained component (see parent class for full docstring).""" return cls(config, model_storage, resource) - def persist(self) -> None: - pass - 
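
Editor's note on the command_prompt_template.jinja2 hunk above: `current_flow` is filled with `top_flow.id`, a plain string, so the old `{{ current_flow.description }}` lookup hit Jinja2's undefined and (with default settings) rendered as empty text; the rewrite drops it and renames the slot variables to the names `render_template` now passes in. A minimal sketch of how the reworked fragment renders, with invented sample values:

    from jinja2 import Template

    # Reworked fragment of command_prompt_template.jinja2 (see hunk above).
    fragment = (
        'You are currently in the flow "{{ current_flow }}".\n'
        'You have just asked the user for the slot "{{ current_slot }}"'
        "{% if current_slot_description %} ({{ current_slot_description }}){% endif %}."
    )

    # Sample values standing in for what render_template() computes.
    print(
        Template(fragment).render(
            current_flow="transfer_money",
            current_slot="recipient",
            current_slot_description="the person receiving the money",
        )
    )
    # You are currently in the flow "transfer_money".
    # You have just asked the user for the slot "recipient" (the person receiving the money).
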
@classmethod def load( cls, @@ -113,36 +117,30 @@ def load( """Loads trained component (see parent class for full docstring).""" return cls(config, model_storage, resource) + def persist(self) -> None: + pass + def train(self, training_data: TrainingData) -> Resource: """Train the intent classifier on a data set.""" self.persist() return self._resource - def _generate_action_list_using_llm(self, prompt: str) -> Optional[str]: - """Use LLM to generate a response. - - Args: - prompt: the prompt to send to the LLM - - Returns: - generated text - """ - llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG) - - try: - return llm(prompt) - except Exception as e: - # unfortunately, langchain does not wrap LLM exceptions which means - # we have to catch all exceptions here - structlogger.error("llm_command_generator.llm.error", error=e) - return None - def predict_commands( self, message: Message, flows: FlowsList, tracker: Optional[DialogueStateTracker] = None, ) -> List[Command]: + """Predict commands using the LLM. + + Args: + message: The message from the user. + flows: The flows available to the user. + tracker: The tracker containing the current state of the conversation. + + Returns: + The commands generated by the llm. + """ if tracker is None or flows.is_empty(): # cannot do anything if there are no flows or no tracker return [] @@ -163,62 +161,77 @@ def predict_commands( return commands - @staticmethod - def is_none_value(value: str) -> bool: - return value in { - "[missing information]", - "[missing]", - "None", - "undefined", - "null", - } + def render_template( + self, message: Message, tracker: DialogueStateTracker, flows: FlowsList + ) -> str: + """Render the jinja template to create the prompt for the LLM. - @staticmethod - def clean_extracted_value(value: str) -> str: - """Clean up the extracted value from the llm.""" - # replace any combination of single quotes, double quotes, and spaces - # from the beginning and end of the string - return re.sub(r"^['\"\s]+|['\"\s]+$", "", value) + Args: + message: The current message from the user. + tracker: The tracker containing the current state of the conversation. + flows: The flows available to the user. - @classmethod - def coerce_slot_value( - cls, slot_value: str, slot_name: str, tracker: DialogueStateTracker - ) -> Union[str, bool, float, None]: - """Coerce the slot value to the correct type. + Returns: + The rendered prompt template. + """ + top_relevant_frame = top_flow_frame(DialogueStack.from_tracker(tracker)) + top_flow = top_relevant_frame.flow(flows) if top_relevant_frame else None + current_step = top_relevant_frame.step(flows) if top_relevant_frame else None - Tries to coerce the slot value to the correct type. If the - conversion fails, `None` is returned. 
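# --- Editorial sketch, not part of the patch: for orientation, the public
# entry point predict_commands() above roughly chains the pieces reorganized
# in this file (a hypothetical condensed flow, using only methods defined in
# this module):
#
#     prompt = self.render_template(message, tracker, flows)
#     actions = self._generate_action_list_using_llm(prompt)
#     commands = self.parse_commands(actions, tracker)
#
# render_template(), continued below, fills the variables consumed by the
# jinja2 prompt, e.g. "current_slot" and "current_slot_description" for the
# collect step that is currently being asked.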
+ flow_slots = self.prepare_current_flow_slots_for_template( + top_flow, current_step, tracker + ) + current_slot, current_slot_description = self.prepare_current_slot_for_template( + current_step + ) + current_conversation = tracker_as_readable_transcript(tracker) + latest_user_message = sanitize_message_for_prompt(message.get(TEXT)) + current_conversation += f"\nUSER: {latest_user_message}" + + inputs = { + "available_flows": self.prepare_flows_for_template(flows, tracker), + "current_conversation": current_conversation, + "flow_slots": flow_slots, + "current_flow": top_flow.id if top_flow is not None else None, + "current_slot": current_slot, + "current_slot_description": current_slot_description, + "user_message": latest_user_message, + } + + return Template(self.prompt_template).render(**inputs) + + def _generate_action_list_using_llm(self, prompt: str) -> Optional[str]: + """Use LLM to generate a response. Args: - value: the value to coerce - slot_name: the name of the slot - tracker: the tracker + prompt: The prompt to send to the LLM. Returns: - the coerced value or `None` if the conversion failed.""" - nullable_value = slot_value if not cls.is_none_value(slot_value) else None - if slot_name not in tracker.slots: - return nullable_value + The generated text. + """ + llm = llm_factory(self.config.get(LLM_CONFIG_KEY), DEFAULT_LLM_CONFIG) - slot = tracker.slots[slot_name] - if isinstance(slot, BooleanSlot): - try: - return bool_from_any(nullable_value) - except (ValueError, TypeError): - return None - elif isinstance(slot, FloatSlot): - try: - return float(nullable_value) - except (ValueError, TypeError): - return None - else: - return nullable_value + try: + return llm(prompt) + except Exception as e: + # unfortunately, langchain does not wrap LLM exceptions which means + # we have to catch all exceptions here + structlogger.error("llm_command_generator.llm.error", error=e) + return None @classmethod def parse_commands( cls, actions: Optional[str], tracker: DialogueStateTracker ) -> List[Command]: - """Parse the actions returned by the llm into intent and entities.""" + """Parse the actions returned by the llm into intent and entities. + + Args: + actions: The actions returned by the llm. + tracker: The tracker containing the current state of the conversation. + + Returns: + The parsed commands. 
+        """
         if not actions:
             return [ErrorCommand()]
 
@@ -231,13 +244,13 @@ def parse_commands(
         cancel_flow_re = re.compile(r"CancelFlow\(\)")
         chitchat_re = re.compile(r"ChitChat\(\)")
         knowledge_re = re.compile(r"SearchAndReply\(\)")
-        humand_handoff_re = re.compile(r"HumandHandoff\(\)")
+        humand_handoff_re = re.compile(r"HumanHandoff\(\)")
         clarify_re = re.compile(r"Clarify\(([a-zA-Z0-9_, ]+)\)")
 
         for action in actions.strip().splitlines():
-            if m := slot_set_re.search(action):
-                slot_name = m.group(1).strip()
-                slot_value = cls.clean_extracted_value(m.group(2))
+            if match := slot_set_re.search(action):
+                slot_name = match.group(1).strip()
+                slot_value = cls.clean_extracted_value(match.group(2))
                 # error case where the llm tries to start a flow using a slot set
                 if slot_name == "flow_name":
                     commands.append(StartFlowCommand(flow=slot_value))
@@ -248,8 +261,8 @@ def parse_commands(
                     commands.append(
                         SetSlotCommand(name=slot_name, value=typed_slot_value)
                     )
-            elif m := start_flow_re.search(action):
-                commands.append(StartFlowCommand(flow=m.group(1).strip()))
+            elif match := start_flow_re.search(action):
+                commands.append(StartFlowCommand(flow=match.group(1).strip()))
             elif cancel_flow_re.search(action):
                 commands.append(CancelFlowCommand())
             elif chitchat_re.search(action):
@@ -258,40 +271,97 @@ def parse_commands(
                 commands.append(KnowledgeAnswerCommand())
             elif humand_handoff_re.search(action):
                 commands.append(HumanHandoffCommand())
-            elif m := clarify_re.search(action):
-                options = [opt.strip() for opt in m.group(1).split(",")]
+            elif match := clarify_re.search(action):
+                options = [opt.strip() for opt in match.group(1).split(",")]
                 commands.append(ClarifyCommand(options))
 
         return commands
 
+    @staticmethod
+    def is_none_value(value: str) -> bool:
+        """Check if the value is a placeholder for a missing value."""
+        return value in {
+            "[missing information]",
+            "[missing]",
+            "None",
+            "undefined",
+            "null",
+        }
+
+    @staticmethod
+    def clean_extracted_value(value: str) -> str:
+        """Clean up the extracted value from the llm."""
+        # strip any combination of single quotes, double quotes, and spaces
+        # from the beginning and end of the string
+        return value.strip("'\" ")
+
+    @classmethod
+    def coerce_slot_value(
+        cls, slot_value: str, slot_name: str, tracker: DialogueStateTracker
+    ) -> Union[str, bool, float, None]:
+        """Coerce the slot value to the correct type.
+
+        Tries to coerce the slot value to the correct type. If the
+        conversion fails, `None` is returned.
+
+        Args:
+            slot_value: The value to coerce.
+            slot_name: The name of the slot.
+            tracker: The tracker containing the current state of the conversation.
+
+        Returns:
+            The coerced value or `None` if the conversion failed.
+        """
+        nullable_value = slot_value if not cls.is_none_value(slot_value) else None
+        if slot_name not in tracker.slots:
+            return nullable_value
+
+        slot = tracker.slots[slot_name]
+        if isinstance(slot, BooleanSlot):
+            try:
+                return bool_from_any(nullable_value)
+            except (ValueError, TypeError):
+                return None
+        elif isinstance(slot, FloatSlot):
+            try:
+                return float(nullable_value)
+            except (ValueError, TypeError):
+                return None
+        else:
+            return nullable_value
+
     @classmethod
-    def create_template_inputs(
+    def prepare_flows_for_template(
         cls, flows: FlowsList, tracker: DialogueStateTracker
     ) -> List[Dict[str, Any]]:
+        """Format data on available flows for insertion into the prompt template.
+
+        Args:
+            flows: The flows available to the user.
+            tracker: The tracker containing the current state of the conversation.
+ + Returns: + The inputs for the prompt template. + """ result = [] - for flow in flows.underlying_flows: - # TODO: check if we should filter more flows; e.g. flows that are - # linked to by other flows and that shouldn't be started directly. - # we might need a separate flag for that. - if not flow.is_rasa_default_flow(): - - slots_with_info = [ - {"name": q.collect, "description": q.description} - for q in flow.get_collect_steps() - if cls.is_extractable(q, tracker) - ] - result.append( - { - "name": flow.id, - "description": flow.description, - "slots": slots_with_info, - } - ) + for flow in flows.user_flows: + slots_with_info = [ + {"name": q.collect, "description": q.description} + for q in flow.get_collect_steps() + if cls.is_extractable(q, tracker) + ] + result.append( + { + "name": flow.id, + "description": flow.description, + "slots": slots_with_info, + } + ) return result @staticmethod def is_extractable( - q: CollectInformationFlowStep, + collect_step: CollectInformationFlowStep, tracker: DialogueStateTracker, current_step: Optional[FlowStep] = None, ) -> bool: @@ -299,25 +369,34 @@ def is_extractable( A collect slot can only be filled if the slot exist and either the collect has been asked already or the - slot has been filled already.""" - slot = tracker.slots.get(q.collect) + slot has been filled already. + + Args: + collect_step: The collect_information step. + tracker: The tracker containing the current state of the conversation. + current_step: The current step in the flow. + + Returns: + `True` if the slot can be filled, `False` otherwise. + """ + slot = tracker.slots.get(collect_step.collect) if slot is None: return False return ( # we can fill because this is a slot that can be filled ahead of time - not q.ask_before_filling + not collect_step.ask_before_filling # we can fill because the slot has been filled already or slot.has_been_set # we can fill because the is currently getting asked or ( current_step is not None and isinstance(current_step, CollectInformationFlowStep) - and current_step.collect == q.collect + and current_step.collect == collect_step.collect ) ) - def allowed_values_for_slot(self, slot: Slot) -> Optional[str]: + def allowed_values_for_slot(self, slot: Slot) -> Union[str, None]: """Get the allowed values for a slot.""" if isinstance(slot, BooleanSlot): return str([True, False]) @@ -327,59 +406,59 @@ def allowed_values_for_slot(self, slot: Slot) -> Optional[str]: return None @staticmethod - def slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: - """Get the slot value from the tracker.""" + def get_slot_value(tracker: DialogueStateTracker, slot_name: str) -> str: + """Get the slot value from the tracker. + + Args: + tracker: The tracker containing the current state of the conversation. + slot_name: The name of the slot. + + Returns: + The slot value as a string. 
+        """
         slot_value = tracker.get_slot(slot_name)
         if slot_value is None:
             return "undefined"
         else:
             return str(slot_value)
 
-    def render_template(
-        self, message: Message, tracker: DialogueStateTracker, flows: FlowsList
-    ) -> str:
-        flows_without_patterns = FlowsList(
-            [f for f in flows.underlying_flows if not f.is_handling_pattern()]
-        )
-        top_relevant_frame = top_flow_frame(DialogueStack.from_tracker(tracker))
-        top_flow = top_relevant_frame.flow(flows) if top_relevant_frame else None
-        current_step = top_relevant_frame.step(flows) if top_relevant_frame else None
+    def prepare_current_flow_slots_for_template(
+        self,
+        top_flow: Optional[Flow],
+        current_step: Optional[FlowStep],
+        tracker: DialogueStateTracker,
+    ) -> List[Dict[str, Any]]:
+        """Prepare the current flow slots for the template.
+
+        Args:
+            top_flow: The top flow, if any flow is active.
+            current_step: The current step in the flow, if any flow is active.
+            tracker: The tracker containing the current state of the conversation.
+
+        Returns:
+            The slots with values, types, allowed values and a description.
+        """
         if top_flow is not None:
             flow_slots = [
                 {
-                    "name": q.collect,
-                    "value": self.slot_value(tracker, q.collect),
-                    "type": tracker.slots[q.collect].type_name,
+                    "name": collect_step.collect,
+                    "value": self.get_slot_value(tracker, collect_step.collect),
+                    "type": tracker.slots[collect_step.collect].type_name,
                     "allowed_values": self.allowed_values_for_slot(
-                        tracker.slots[q.collect]
+                        tracker.slots[collect_step.collect]
                     ),
-                    "description": q.description,
+                    "description": collect_step.description,
                 }
-                for q in top_flow.get_collect_steps()
-                if self.is_extractable(q, tracker, current_step)
+                for collect_step in top_flow.get_collect_steps()
+                if self.is_extractable(collect_step, tracker, current_step)
             ]
         else:
             flow_slots = []
+        return flow_slots
 
-        collect, collect_description = (
+    def prepare_current_slot_for_template(
+        self, current_step: Optional[FlowStep]
+    ) -> Tuple[Union[str, None], Union[str, None]]:
+        """Prepare the name and description of the currently asked slot."""
+        return (
             (current_step.collect, current_step.description)
             if isinstance(current_step, CollectInformationFlowStep)
             else (None, None)
         )
-        current_conversation = tracker_as_readable_transcript(tracker)
-        latest_user_message = sanitize_message_for_prompt(message.get(TEXT))
-        current_conversation += f"\nUSER: {latest_user_message}"
-
-        inputs = {
-            "available_flows": self.create_template_inputs(
-                flows_without_patterns, tracker
-            ),
-            "current_conversation": current_conversation,
-            "flow_slots": flow_slots,
-            "current_flow": top_flow.id if top_flow is not None else None,
-            "collect": collect,
-            "collect_description": collect_description,
-            "user_message": latest_user_message,
-        }
-
-        return Template(self.prompt_template).render(**inputs)
diff --git a/rasa/shared/core/flows/flow.py b/rasa/shared/core/flows/flow.py
index 365e52537118..023c5f621349 100644
--- a/rasa/shared/core/flows/flow.py
+++ b/rasa/shared/core/flows/flow.py
@@ -179,6 +179,17 @@ def __init__(self, flows: List[Flow]) -> None:
         """
         self.underlying_flows = flows
 
+    def __iter__(self) -> Generator[Flow, None, None]:
+        """Iterates over the flows."""
+        yield from self.underlying_flows
+
+    def __eq__(self, other: Any) -> bool:
+        """Compares the flows."""
+        return (
+            isinstance(other, FlowsList)
+            and self.underlying_flows == other.underlying_flows
+        )
+
     def is_empty(self) -> bool:
         """Returns whether the flows list is empty."""
         return len(self.underlying_flows) == 0
@@ -254,15 +265,23 @@ def validate(self) -> None:
         for flow in self.underlying_flows:
             flow.validate()
 
-    def non_pattern_flows(self) -> 
List[str]: - """Get all flows that can be started. + @property + def user_flow_ids(self) -> List[str]: + """Get all ids of flows that can be started by a user. - Args: - all_flows: All flows. + Returns: + The ids of all flows that can be started by a user.""" + return [f.id for f in self.user_flows] + + @property + def user_flows(self) -> FlowsList: + """Get all flows that can be started by a user. Returns: - All flows that can be started.""" - return [f.id for f in self.underlying_flows if not f.is_handling_pattern()] + All flows that can be started by a user.""" + return FlowsList( + [f for f in self.underlying_flows if not f.is_rasa_default_flow] + ) @property def utterances(self) -> Set[str]: @@ -505,10 +524,6 @@ def _previously_asked_collect( return _previously_asked_collect(step_id or START_STEP, set()) - def is_handling_pattern(self) -> bool: - """Returns whether the flow is handling a pattern.""" - return self.id.startswith(RASA_DEFAULT_FLOW_PATTERN_PREFIX) - def get_trigger_intents(self) -> Set[str]: """Returns the trigger intents of the flow""" results: Set[str] = set() @@ -529,6 +544,7 @@ def is_user_triggerable(self) -> bool: """Test whether a user can trigger the flow with an intent.""" return len(self.get_trigger_intents()) > 0 + @property def is_rasa_default_flow(self) -> bool: """Test whether something is a rasa default flow.""" return self.id.startswith(RASA_DEFAULT_FLOW_PATTERN_PREFIX) diff --git a/tests/core/flows/test_flow.py b/tests/core/flows/test_flow.py index 9d4f6b7692c1..87f4e2551b86 100644 --- a/tests/core/flows/test_flow.py +++ b/tests/core/flows/test_flow.py @@ -1,10 +1,13 @@ -from rasa.shared.core.flows.flow import FlowsList -from rasa.shared.core.flows.yaml_flows_io import flows_from_str +import pytest + +from rasa.shared.core.flows.flow import Flow, FlowsList from rasa.shared.importers.importer import FlowSyncImporter +from tests.utilities import flows_from_str -def test_non_pattern_flows(): - all_flows = flows_from_str( +@pytest.fixture +def user_flows_and_patterns() -> FlowsList: + return flows_from_str( """ flows: foo: @@ -17,15 +20,11 @@ def test_non_pattern_flows(): action: action_listen """ ) - assert all_flows.non_pattern_flows() == ["foo"] - -def test_non_pattern_handles_empty(): - assert FlowsList(flows=[]).non_pattern_flows() == [] - -def test_non_pattern_flows_handles_patterns_only(): - all_flows = flows_from_str( +@pytest.fixture +def only_patterns() -> FlowsList: + return flows_from_str( """ flows: pattern_bar: @@ -34,7 +33,41 @@ def test_non_pattern_flows_handles_patterns_only(): action: action_listen """ ) - assert all_flows.non_pattern_flows() == [] + + +@pytest.fixture +def empty_flowlist() -> FlowsList: + return FlowsList(flows=[]) + + +def test_user_flow_ids(user_flows_and_patterns: FlowsList): + assert user_flows_and_patterns.user_flow_ids == ["foo"] + + +def test_user_flow_ids_handles_empty(empty_flowlist: FlowsList): + assert empty_flowlist.user_flow_ids == [] + + +def test_user_flow_ids_handles_patterns_only(only_patterns: FlowsList): + assert only_patterns.user_flow_ids == [] + + +def test_user_flows(user_flows_and_patterns: FlowsList): + user_flows = user_flows_and_patterns.user_flows + expected_user_flows = FlowsList( + [Flow.from_json("foo", {"steps": [{"id": "first", "action": "action_listen"}]})] + ) + assert user_flows == expected_user_flows + + +def test_user_flows_handles_empty(empty_flowlist: FlowsList): + assert empty_flowlist.user_flows == empty_flowlist + + +def test_user_flows_handles_patterns_only( + only_patterns: 
FlowsList, empty_flowlist: FlowsList +): + assert only_patterns.user_flows == empty_flowlist def test_collecting_flow_utterances(): diff --git a/tests/dialogue_understanding/generator/rendered_prompt.txt b/tests/dialogue_understanding/generator/rendered_prompt.txt new file mode 100644 index 000000000000..d9e9fd4c14e7 --- /dev/null +++ b/tests/dialogue_understanding/generator/rendered_prompt.txt @@ -0,0 +1,45 @@ +Your task is to analyze the current conversation context and generate a list of actions to start new business processes that we call flows, to extract slots, or respond to small talk and knowledge requests. + +These are the flows that can be started, with their description and slots: + +test_flow: some description + slot: test_slot + + +=== +Here is what happened previously in the conversation: +USER: Hello +AI: Hi +USER: some message + +=== + +You are currently not in any flow and so there are no active slots. +This means you can only set a slot if you first start a flow that requires that slot. + +If you start a flow, first start the flow and then optionally fill that flow's slots with information the user provided in their message. + +The user just said """some message""". + +=== +Based on this information generate a list of actions you want to take. Your job is to start flows and to fill slots where appropriate. Any logic of what happens afterwards is handled by the flow engine. These are your available actions: +* Slot setting, described by "SetSlot(slot_name, slot_value)". An example would be "SetSlot(recipient, Freddy)" +* Starting another flow, described by "StartFlow(flow_name)". An example would be "StartFlow(transfer_money)" +* Cancelling the current flow, described by "CancelFlow()" +* Clarifying which flow should be started. An example would be Clarify(list_contacts, add_contact, remove_contact) if the user just wrote "contacts" and there are multiple potential candidates. It also works with a single flow name to confirm you understood correctly, as in Clarify(transfer_money). +* Responding to knowledge-oriented user messages, described by "SearchAndReply()" +* Responding to a casual, non-task-oriented user message, described by "ChitChat()". +* Handing off to a human, in case the user seems frustrated or explicitly asks to speak to one, described by "HumanHandoff()". + +=== +Write out the actions you want to take, one per line, in the order they should take place. +Do not fill slots with abstract values or placeholders. +Only use information provided by the user. +Only start a flow if it's completely clear what the user wants. Imagine you were a person reading this message. If it's not 100% clear, clarify the next step. +Don't be overly confident. Take a conservative approach and clarify before proceeding. +If the user asks for two things which seem contradictory, clarify before starting a flow. +Strictly adhere to the provided action types listed above. +Focus on the last message and take it one step at a time. +Use the previous conversation steps only to aid understanding. 
+ +Your action list: \ No newline at end of file diff --git a/tests/dialogue_understanding/generator/test_llm_command_generator.py b/tests/dialogue_understanding/generator/test_llm_command_generator.py index d2cd06d266fe..56c50fdcf8dd 100644 --- a/tests/dialogue_understanding/generator/test_llm_command_generator.py +++ b/tests/dialogue_understanding/generator/test_llm_command_generator.py @@ -1,41 +1,468 @@ import uuid +from typing import Optional, Any +from unittest.mock import Mock, patch + import pytest from _pytest.tmpdir import TempPathFactory +from structlog.testing import capture_logs + from rasa.dialogue_understanding.generator.llm_command_generator import ( LLMCommandGenerator, ) +from rasa.dialogue_understanding.commands import ( + Command, + ErrorCommand, + SetSlotCommand, + CancelFlowCommand, + StartFlowCommand, + HumanHandoffCommand, + ChitChatAnswerCommand, + KnowledgeAnswerCommand, + ClarifyCommand, +) from rasa.engine.storage.local_model_storage import LocalModelStorage from rasa.engine.storage.resource import Resource from rasa.engine.storage.storage import ModelStorage +from rasa.shared.core.events import BotUttered, SlotSet, UserUttered +from rasa.shared.core.flows.flow import ( + CollectInformationFlowStep, + FlowsList, + SlotRejection, +) +from rasa.shared.core.slots import ( + Slot, + BooleanSlot, + CategoricalSlot, + FloatSlot, + TextSlot, +) +from rasa.shared.core.trackers import DialogueStateTracker +from rasa.shared.nlu.training_data.message import Message +from tests.utilities import flows_from_str + + +EXPECTED_PROMPT_PATH = "./tests/dialogue_understanding/generator/rendered_prompt.txt" + + +class TestLLMCommandGenerator: + """Tests for the LLMCommandGenerator.""" + + @pytest.fixture + def command_generator(self): + """Create an LLMCommandGenerator.""" + return LLMCommandGenerator.create( + config={}, resource=Mock(), model_storage=Mock(), execution_context=Mock() + ) + + @pytest.fixture + def flows(self) -> FlowsList: + """Create a FlowsList.""" + return flows_from_str( + """ + flows: + test_flow: + steps: + - id: first_step + action: action_listen + """ + ) + + @pytest.fixture(scope="session") + def resource(self) -> Resource: + return Resource(uuid.uuid4().hex) + + @pytest.fixture(scope="session") + def model_storage(self, tmp_path_factory: TempPathFactory) -> ModelStorage: + return LocalModelStorage(tmp_path_factory.mktemp(uuid.uuid4().hex)) + + async def test_llm_command_generator_prompt_init_custom( + self, model_storage: ModelStorage, resource: Resource + ) -> None: + generator = LLMCommandGenerator( + {"prompt": "data/test_prompt_templates/test_prompt.jinja2"}, + model_storage, + resource, + ) + assert generator.prompt_template.startswith("This is a test prompt.") + + async def test_llm_command_generator_prompt_init_default( + self, model_storage: ModelStorage, resource: Resource + ) -> None: + generator = LLMCommandGenerator({}, model_storage, resource) + assert generator.prompt_template.startswith( + "Your task is to analyze the current conversation" + ) + def test_predict_commands_with_no_flows( + self, command_generator: LLMCommandGenerator + ): + """Test that predict_commands returns an empty list when flows is None.""" + # Given + empty_flows = FlowsList([]) + # When + predicted_commands = command_generator.predict_commands( + Mock(), flows=empty_flows, tracker=Mock() + ) + # Then + assert not predicted_commands -@pytest.fixture(scope="session") -def resource() -> Resource: - return Resource(uuid.uuid4().hex) + def 
test_predict_commands_with_no_tracker(
+        self, command_generator: LLMCommandGenerator
+    ):
+        """Test that predict_commands returns an empty list when tracker is None."""
+        # When
+        predicted_commands = command_generator.predict_commands(
+            Mock(), flows=Mock(), tracker=None
+        )
+        # Then
+        assert not predicted_commands
+
+    def test_generate_action_list_calls_llm_factory_correctly(
+        self,
+        command_generator: LLMCommandGenerator,
+    ):
+        """Test that _generate_action_list_using_llm calls the llm factory correctly."""
+        # Given
+        llm_config = {
+            "_type": "openai",
+            "request_timeout": 7,
+            "temperature": 0.0,
+            "model_name": "gpt-4",
+        }
+        # When
+        with patch(
+            "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory",
+            Mock(),
+        ) as mock_llm_factory:
+            command_generator._generate_action_list_using_llm("some prompt")
+        # Then
+        mock_llm_factory.assert_called_once_with(None, llm_config)
 
-@pytest.fixture(scope="session")
-def model_storage(tmp_path_factory: TempPathFactory) -> ModelStorage:
-    return LocalModelStorage(tmp_path_factory.mktemp(uuid.uuid4().hex))
+    def test_generate_action_list_calls_llm_correctly(
+        self,
+        command_generator: LLMCommandGenerator,
+    ):
+        """Test that _generate_action_list_using_llm calls the llm correctly."""
+        # Given
+        with patch(
+            "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory",
+            Mock(),
+        ) as mock_llm_factory:
+            mock_llm_factory.return_value = Mock()
+            # When
+            command_generator._generate_action_list_using_llm("some prompt")
+        # Then
+        mock_llm_factory.return_value.assert_called_once_with("some prompt")
+
+    def test_generate_action_list_catches_llm_exception(
+        self,
+        command_generator: LLMCommandGenerator,
+    ):
+        """Test that _generate_action_list_using_llm catches llm exceptions."""
+        # When
+        mock_llm = Mock(side_effect=Exception("some exception"))
+        with patch(
+            "rasa.dialogue_understanding.generator.llm_command_generator.llm_factory",
+            Mock(return_value=mock_llm),
+        ):
+            with capture_logs() as logs:
+                command_generator._generate_action_list_using_llm("some prompt")
+                # Then
+                assert len(logs) == 1
+                assert logs[0]["event"] == "llm_command_generator.llm.error"
 
-async def test_llm_command_generator_prompt_init_custom(
-    model_storage: ModelStorage, resource: Resource
-) -> None:
-    generator = LLMCommandGenerator(
-        {"prompt": "data/test_prompt_templates/test_prompt.jinja2"},
-        model_storage,
-        resource,
+    def test_render_template(
+        self,
+        command_generator: LLMCommandGenerator,
+    ):
+        """Test that render_template renders the correct template string."""
+        # Given
+        test_message = Message.build(text="some message")
+        test_slot = TextSlot(
+            name="test_slot",
+            mappings=[{}],
+            initial_value=None,
+            influence_conversation=False,
+        )
+        test_tracker = DialogueStateTracker.from_events(
+            sender_id="test",
+            evts=[UserUttered("Hello"), BotUttered("Hi")],
+            slots=[test_slot],
+        )
+        test_flows = flows_from_str(
+            """
+            flows:
+              test_flow:
+                description: some description
+                steps:
+                - id: first_step
+                  collect: test_slot
+            """
+        )
+        with open(EXPECTED_PROMPT_PATH, "r", encoding="unicode_escape") as f:
+            expected_template = f.readlines()
+        # When
+        rendered_template = command_generator.render_template(
+            message=test_message, tracker=test_tracker, flows=test_flows
+        )
+        # Then
+        for rendered_line, expected_line in zip(
+            rendered_template.splitlines(True), expected_template
+        ):
+            assert rendered_line == expected_line
 
+    @pytest.mark.parametrize(
+        "input_action, expected_command",
+        [
+            (None, [ErrorCommand()]),
+            (
+                
"SetSlot(transfer_money_amount_of_money, )", + [SetSlotCommand(name="transfer_money_amount_of_money", value=None)], + ), + ("SetSlot(flow_name, some_flow)", [StartFlowCommand(flow="some_flow")]), + ("StartFlow(check_balance)", [StartFlowCommand(flow="check_balance")]), + ("CancelFlow()", [CancelFlowCommand()]), + ("ChitChat()", [ChitChatAnswerCommand()]), + ("SearchAndReply()", [KnowledgeAnswerCommand()]), + ("HumanHandoff()", [HumanHandoffCommand()]), + ("Clarify(transfer_money)", [ClarifyCommand(options=["transfer_money"])]), + ( + "Clarify(list_contacts, add_contact, remove_contact)", + [ + ClarifyCommand( + options=["list_contacts", "add_contact", "remove_contact"] + ) + ], + ), + ( + "Here is a list of commands:\nSetSlot(flow_name, some_flow)\n", + [StartFlowCommand(flow="some_flow")], + ), + ( + """SetSlot(flow_name, some_flow) + SetSlot(transfer_money_amount_of_money,)""", + [ + StartFlowCommand(flow="some_flow"), + SetSlotCommand(name="transfer_money_amount_of_money", value=None), + ], + ), + ], + ) + def test_parse_commands_identifies_correct_command( + self, + input_action: Optional[str], + expected_command: Command, + ): + """Test that parse_commands identifies the correct commands.""" + # When + with patch.object( + LLMCommandGenerator, "coerce_slot_value", Mock(return_value=None) + ): + parsed_commands = LLMCommandGenerator.parse_commands(input_action, Mock()) + # Then + assert parsed_commands == expected_command + + @pytest.mark.parametrize( + "slot_name, slot, slot_value, expected_output", + [ + ("some_other_slot", FloatSlot("some_float", []), None, None), + ("some_float", FloatSlot("some_float", []), 40, 40.0), + ("some_float", FloatSlot("some_float", []), 40.0, 40.0), + ("some_text", TextSlot("some_text", []), "fourty", "fourty"), + ("some_bool", BooleanSlot("some_bool", []), "True", True), + ("some_bool", BooleanSlot("some_bool", []), "false", False), + ], + ) + def test_coerce_slot_value( + self, + slot_name: str, + slot: Slot, + slot_value: Any, + expected_output: Any, + ): + """Test that coerce_slot_value coerces the slot value correctly.""" + # Given + tracker = DialogueStateTracker.from_events("test", evts=[], slots=[slot]) + # When + coerced_value = LLMCommandGenerator.coerce_slot_value( + slot_value, slot_name, tracker + ) + # Then + assert coerced_value == expected_output + + @pytest.mark.parametrize( + "input_value, expected_output", + [ + ("text", "text"), + (" text ", "text"), + ('"text"', "text"), + ("'text'", "text"), + ("' \"text' \" ", "text"), + ("", ""), + ], + ) + def test_clean_extracted_value(self, input_value: str, expected_output: str): + """Test that clean_extracted_value removes + the leading and trailing whitespaces. 
+ """ + # When + cleaned_value = LLMCommandGenerator.clean_extracted_value(input_value) + # Then + assert cleaned_value == expected_output + + @pytest.mark.parametrize( + "input_value, expected_truthiness", + [ + ("", False), + (" ", False), + ("none", False), + ("some text", False), + ("[missing information]", True), + ("[missing]", True), + ("None", True), + ("undefined", True), + ("null", True), + ], ) - assert generator.prompt_template.startswith("This is a test prompt.") + def test_is_none_value(self, input_value: str, expected_truthiness: bool): + """Test that is_none_value returns True when the value is None.""" + assert LLMCommandGenerator.is_none_value(input_value) == expected_truthiness + @pytest.mark.parametrize( + "slot, slot_name, expected_output", + [ + (TextSlot("test_slot", [], initial_value="hello"), "test_slot", "hello"), + (TextSlot("test_slot", []), "some_other_slot", "undefined"), + ], + ) + def test_slot_value(self, slot: Slot, slot_name: str, expected_output: str): + """Test that slot_value returns the correct string.""" + # Given + tracker = DialogueStateTracker.from_events("test", evts=[], slots=[slot]) + # When + slot_value = LLMCommandGenerator.get_slot_value(tracker, slot_name) + + assert slot_value == expected_output -async def test_llm_command_generator_prompt_init_default( - model_storage: ModelStorage, resource: Resource -) -> None: - generator = LLMCommandGenerator({}, model_storage, resource) - assert generator.prompt_template.startswith( - "Your task is to analyze the current conversation" + @pytest.mark.parametrize( + "input_slot, expected_slot_values", + [ + (FloatSlot("test_slot", []), None), + (TextSlot("test_slot", []), None), + (BooleanSlot("test_slot", []), "[True, False]"), + ( + CategoricalSlot("test_slot", [], values=["Value1", "Value2"]), + "['value1', 'value2']", + ), + ], ) + def test_allowed_values_for_slot( + self, + command_generator: LLMCommandGenerator, + input_slot: Slot, + expected_slot_values: Optional[str], + ): + """Test that allowed_values_for_slot returns the correct values.""" + # When + allowed_values = command_generator.allowed_values_for_slot(input_slot) + # Then + assert allowed_values == expected_slot_values + + @pytest.fixture + def collect_info_step(self) -> CollectInformationFlowStep: + """Create a CollectInformationFlowStep.""" + return CollectInformationFlowStep( + collect="test_slot", + idx=0, + ask_before_filling=True, + utter="hello", + rejections=[SlotRejection("test_slot", "some rejection")], + custom_id="collect", + description="test_slot", + metadata={}, + next="next_step", + ) + + def test_is_extractable_with_no_slot( + self, + command_generator: LLMCommandGenerator, + collect_info_step: CollectInformationFlowStep, + ): + """Test that is_extractable returns False + when there are no slots to be filled. + """ + # Given + tracker = DialogueStateTracker.from_events(sender_id="test", evts=[], slots=[]) + # When + is_extractable = command_generator.is_extractable(collect_info_step, tracker) + # Then + assert not is_extractable + + def test_is_extractable_when_slot_can_be_filled_without_asking( + self, + command_generator: LLMCommandGenerator, + ): + """Test that is_extractable returns True when + collect_information slot can be filled. 
+ """ + # Given + tracker = DialogueStateTracker.from_events( + sender_id="test", evts=[], slots=[TextSlot(name="test_slot", mappings=[])] + ) + collect_info_step = CollectInformationFlowStep( + collect="test_slot", + ask_before_filling=False, + utter="hello", + rejections=[SlotRejection("test_slot", "some rejection")], + custom_id="collect_information", + idx=0, + description="test_slot", + metadata={}, + next="next_step", + ) + # When + is_extractable = command_generator.is_extractable(collect_info_step, tracker) + # Then + assert is_extractable + + def test_is_extractable_when_slot_has_already_been_set( + self, + command_generator: LLMCommandGenerator, + collect_info_step: CollectInformationFlowStep, + ): + """Test that is_extractable returns True + when collect_information can be filled. + """ + # Given + slot = TextSlot(name="test_slot", mappings=[]) + tracker = DialogueStateTracker.from_events( + sender_id="test", evts=[SlotSet("test_slot", "hello")], slots=[slot] + ) + # When + is_extractable = command_generator.is_extractable(collect_info_step, tracker) + # Then + assert is_extractable + + def test_is_extractable_with_current_step( + self, + command_generator: LLMCommandGenerator, + collect_info_step: CollectInformationFlowStep, + ): + """Test that is_extractable returns True when the current step is a collect + information step and matches the information step. + """ + # Given + tracker = DialogueStateTracker.from_events( + sender_id="test", + evts=[UserUttered("Hello"), BotUttered("Hi")], + slots=[TextSlot(name="test_slot", mappings=[])], + ) + # When + is_extractable = command_generator.is_extractable( + collect_info_step, tracker, current_step=collect_info_step + ) + # Then + assert is_extractable
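
Editor's usage sketch (not part of the patch) of the renamed FlowsList API, mirroring the assertions in tests/core/flows/test_flow.py above. The `flows_from_str` import path used here is the pre-change one from rasa.shared.core.flows.yaml_flows_io; the updated tests import it from tests.utilities instead:

    from rasa.shared.core.flows.yaml_flows_io import flows_from_str

    flows = flows_from_str(
        """
        flows:
          foo:
            steps:
            - id: first
              action: action_listen
          pattern_bar:
            steps:
            - id: first
              action: action_listen
        """
    )

    # non_pattern_flows() is replaced by two properties; flows whose id starts
    # with the Rasa default pattern prefix ("pattern_" here) are filtered out
    # via the new Flow.is_rasa_default_flow property.
    assert flows.user_flow_ids == ["foo"]
    assert [f.id for f in flows.user_flows] == ["foo"]  # uses the new __iter__
    assert flows.user_flows != flows                    # uses the new __eq__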