From 59a46a6c0cd17844d4092c0fceec3a80efe8c5a9 Mon Sep 17 00:00:00 2001 From: Varun Shankar S Date: Thu, 22 Jun 2023 14:47:31 +0200 Subject: [PATCH] Bypass NLU pipeline when message is /intent or /intent + entities - [ENG 286] (#12480) * Bypass NLU pipeline when a message is /intent or /intent + entities --- changelog/12480.improvement.md | 1 + rasa/core/processor.py | 26 ++++++++++++++++---- tests/core/test_agent.py | 12 ++++++++-- tests/core/test_processor.py | 44 +++++++++++++++++++++++++--------- 4 files changed, 66 insertions(+), 17 deletions(-) create mode 100644 changelog/12480.improvement.md diff --git a/changelog/12480.improvement.md b/changelog/12480.improvement.md new file mode 100644 index 000000000000..8a55039f89cb --- /dev/null +++ b/changelog/12480.improvement.md @@ -0,0 +1 @@ +Skip executing the pipeline when the user message is of the form /intent or /intent + entities. \ No newline at end of file diff --git a/rasa/core/processor.py b/rasa/core/processor.py index 67ddd97816f9..7b688cfd6adf 100644 --- a/rasa/core/processor.py +++ b/rasa/core/processor.py @@ -61,6 +61,9 @@ import rasa.core.actions.action import rasa.shared.core.trackers from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity +from rasa.shared.core.training_data.story_reader.yaml_story_reader import ( + YAMLStoryReader, +) from rasa.shared.nlu.constants import ( ENTITIES, INTENT, @@ -68,6 +71,7 @@ PREDICTED_CONFIDENCE_KEY, TEXT, ) +from rasa.shared.nlu.training_data.message import Message from rasa.utils.endpoints import EndpointConfig logger = logging.getLogger(__name__) @@ -712,11 +716,25 @@ async def parse_message( if self.http_interpreter: parse_data = await self.http_interpreter.parse(message) else: - if tracker is None: - tracker = DialogueStateTracker.from_events(message.sender_id, []) - parse_data = self._parse_message_with_graph( - message, tracker, only_output_properties + msg = YAMLStoryReader.unpack_regex_message( + message=Message({TEXT: message.text}) ) + # Intent is not explicitly present. Pass message to graph. + if msg.data.get(INTENT) is None: + if tracker is None: + tracker = DialogueStateTracker.from_events(message.sender_id, []) + parse_data = self._parse_message_with_graph( + message, tracker, only_output_properties + ) + else: + parse_data = { + TEXT: "", + INTENT: {INTENT_NAME_KEY: None, PREDICTED_CONFIDENCE_KEY: 0.0}, + ENTITIES: [], + } + parse_data.update( + msg.as_dict(only_output_properties=only_output_properties) + ) structlogger.debug( "processor.message.parse", diff --git a/tests/core/test_agent.py b/tests/core/test_agent.py index 21b42d690a96..9241f8f950b3 100644 --- a/tests/core/test_agent.py +++ b/tests/core/test_agent.py @@ -98,11 +98,19 @@ async def test_agent_train(default_agent: Agent): "start": 6, "end": 21, "value": "Rasa", - "extractor": "RegexMessageHandler", } ], }, - ) + ), + ( + "hi hello", + { + "text": "hi hello", + "intent": {"name": "greet", "confidence": 1.0}, + "text_tokens": [(0, 2), (3, 8)], + "entities": [], + }, + ), ], ) async def test_agent_parse_message( diff --git a/tests/core/test_processor.py b/tests/core/test_processor.py index 392d85c29745..640cbc28d973 100644 --- a/tests/core/test_processor.py +++ b/tests/core/test_processor.py @@ -18,6 +18,7 @@ from _pytest.logging import LogCaptureFixture from aioresponses import aioresponses from typing import Optional, Text, List, Callable, Type, Any +from unittest import mock from rasa.core.lock_store import InMemoryLockStore from rasa.core.policies.ensemble import DefaultPolicyPredictionEnsemble @@ -113,10 +114,26 @@ async def test_message_id_logging(default_processor: MessageProcessor): async def test_parsing(default_processor: MessageProcessor): - message = UserMessage('/greet{"name": "boy"}') - parsed = await default_processor.parse_message(message) - assert parsed["intent"][INTENT_NAME_KEY] == "greet" - assert parsed["entities"][0]["entity"] == "name" + with mock.patch( + "rasa.core.processor.MessageProcessor._parse_message_with_graph" + ) as mocked_function: + # Case1: message has intent and entities explicitly set. + message = UserMessage('/greet{"name": "boy"}') + parsed = await default_processor.parse_message(message) + assert parsed["intent"][INTENT_NAME_KEY] == "greet" + assert parsed["entities"][0]["entity"] == "name" + mocked_function.assert_not_called() + + # Case2: Normal user message. + parse_data = { + "text": "mocked", + "intent": {"name": None, "confidence": 0.0}, + "entities": [], + } + mocked_function.return_value = parse_data + message = UserMessage("hi hello how are you?") + parsed = await default_processor.parse_message(message) + mocked_function.assert_called() async def test_check_for_unseen_feature(default_processor: MessageProcessor): @@ -874,7 +891,7 @@ async def test_handle_message_with_session_start( # make sure the sequence of events is as expected with_model_ids_expected = with_model_ids( [ - ActionExecuted(ACTION_SESSION_START_NAME), + ActionExecuted(ACTION_SESSION_START_NAME, confidence=1.0), SessionStarted(), ActionExecuted(ACTION_LISTEN_NAME), UserUttered( @@ -886,15 +903,18 @@ async def test_handle_message_with_session_start( "start": 6, "end": 22, "value": "Core", - "extractor": "RegexMessageHandler", } ], ), SlotSet(entity, slot_1[entity]), DefinePrevUserUtteredFeaturization(False), - ActionExecuted("utter_greet"), - BotUttered("hey there Core!", metadata={"utter_action": "utter_greet"}), - ActionExecuted(ACTION_LISTEN_NAME), + ActionExecuted( + "utter_greet", policy="AugmentedMemoizationPolicy", confidence=1.0 + ), + BotUttered( + "hey there Core!", data={}, metadata={"utter_action": "utter_greet"} + ), + ActionExecuted(ACTION_LISTEN_NAME, confidence=1.0), ActionExecuted(ACTION_SESSION_START_NAME), SessionStarted(), # the initial SlotSet is reapplied after the SessionStarted sequence @@ -909,15 +929,17 @@ async def test_handle_message_with_session_start( "start": 6, "end": 42, "value": "post-session start hello", - "extractor": "RegexMessageHandler", } ], ), SlotSet(entity, slot_2[entity]), DefinePrevUserUtteredFeaturization(False), - ActionExecuted("utter_greet"), + ActionExecuted( + "utter_greet", policy="AugmentedMemoizationPolicy", confidence=1.0 + ), BotUttered( "hey there post-session start hello!", + data={}, metadata={"utter_action": "utter_greet"}, ), ActionExecuted(ACTION_LISTEN_NAME),