Skip to content

Commit

Permalink
Bypass NLU pipeline when message is /intent or /intent + entities - […
Browse files Browse the repository at this point in the history
…ENG 286] (#12480)

* Bypass NLU pipeline when a message is /intent or /intent + entities
  • Loading branch information
varunshankar committed Jun 22, 2023
1 parent f9682d4 commit 59a46a6
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 17 deletions.
1 change: 1 addition & 0 deletions changelog/12480.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Skip executing the pipeline when the user message is of the form /intent or /intent + entities.
26 changes: 22 additions & 4 deletions rasa/core/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,17 @@
import rasa.core.actions.action
import rasa.shared.core.trackers
from rasa.shared.core.trackers import DialogueStateTracker, EventVerbosity
from rasa.shared.core.training_data.story_reader.yaml_story_reader import (
YAMLStoryReader,
)
from rasa.shared.nlu.constants import (
ENTITIES,
INTENT,
INTENT_NAME_KEY,
PREDICTED_CONFIDENCE_KEY,
TEXT,
)
from rasa.shared.nlu.training_data.message import Message
from rasa.utils.endpoints import EndpointConfig

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -712,11 +716,25 @@ async def parse_message(
if self.http_interpreter:
parse_data = await self.http_interpreter.parse(message)
else:
if tracker is None:
tracker = DialogueStateTracker.from_events(message.sender_id, [])
parse_data = self._parse_message_with_graph(
message, tracker, only_output_properties
msg = YAMLStoryReader.unpack_regex_message(
message=Message({TEXT: message.text})
)
# Intent is not explicitly present. Pass message to graph.
if msg.data.get(INTENT) is None:
if tracker is None:
tracker = DialogueStateTracker.from_events(message.sender_id, [])
parse_data = self._parse_message_with_graph(
message, tracker, only_output_properties
)
else:
parse_data = {
TEXT: "",
INTENT: {INTENT_NAME_KEY: None, PREDICTED_CONFIDENCE_KEY: 0.0},
ENTITIES: [],
}
parse_data.update(
msg.as_dict(only_output_properties=only_output_properties)
)

structlogger.debug(
"processor.message.parse",
Expand Down
12 changes: 10 additions & 2 deletions tests/core/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,19 @@ async def test_agent_train(default_agent: Agent):
"start": 6,
"end": 21,
"value": "Rasa",
"extractor": "RegexMessageHandler",
}
],
},
)
),
(
"hi hello",
{
"text": "hi hello",
"intent": {"name": "greet", "confidence": 1.0},
"text_tokens": [(0, 2), (3, 8)],
"entities": [],
},
),
],
)
async def test_agent_parse_message(
Expand Down
44 changes: 33 additions & 11 deletions tests/core/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from _pytest.logging import LogCaptureFixture
from aioresponses import aioresponses
from typing import Optional, Text, List, Callable, Type, Any
from unittest import mock

from rasa.core.lock_store import InMemoryLockStore
from rasa.core.policies.ensemble import DefaultPolicyPredictionEnsemble
Expand Down Expand Up @@ -113,10 +114,26 @@ async def test_message_id_logging(default_processor: MessageProcessor):


async def test_parsing(default_processor: MessageProcessor):
message = UserMessage('/greet{"name": "boy"}')
parsed = await default_processor.parse_message(message)
assert parsed["intent"][INTENT_NAME_KEY] == "greet"
assert parsed["entities"][0]["entity"] == "name"
with mock.patch(
"rasa.core.processor.MessageProcessor._parse_message_with_graph"
) as mocked_function:
# Case1: message has intent and entities explicitly set.
message = UserMessage('/greet{"name": "boy"}')
parsed = await default_processor.parse_message(message)
assert parsed["intent"][INTENT_NAME_KEY] == "greet"
assert parsed["entities"][0]["entity"] == "name"
mocked_function.assert_not_called()

# Case2: Normal user message.
parse_data = {
"text": "mocked",
"intent": {"name": None, "confidence": 0.0},
"entities": [],
}
mocked_function.return_value = parse_data
message = UserMessage("hi hello how are you?")
parsed = await default_processor.parse_message(message)
mocked_function.assert_called()


async def test_check_for_unseen_feature(default_processor: MessageProcessor):
Expand Down Expand Up @@ -874,7 +891,7 @@ async def test_handle_message_with_session_start(
# make sure the sequence of events is as expected
with_model_ids_expected = with_model_ids(
[
ActionExecuted(ACTION_SESSION_START_NAME),
ActionExecuted(ACTION_SESSION_START_NAME, confidence=1.0),
SessionStarted(),
ActionExecuted(ACTION_LISTEN_NAME),
UserUttered(
Expand All @@ -886,15 +903,18 @@ async def test_handle_message_with_session_start(
"start": 6,
"end": 22,
"value": "Core",
"extractor": "RegexMessageHandler",
}
],
),
SlotSet(entity, slot_1[entity]),
DefinePrevUserUtteredFeaturization(False),
ActionExecuted("utter_greet"),
BotUttered("hey there Core!", metadata={"utter_action": "utter_greet"}),
ActionExecuted(ACTION_LISTEN_NAME),
ActionExecuted(
"utter_greet", policy="AugmentedMemoizationPolicy", confidence=1.0
),
BotUttered(
"hey there Core!", data={}, metadata={"utter_action": "utter_greet"}
),
ActionExecuted(ACTION_LISTEN_NAME, confidence=1.0),
ActionExecuted(ACTION_SESSION_START_NAME),
SessionStarted(),
# the initial SlotSet is reapplied after the SessionStarted sequence
Expand All @@ -909,15 +929,17 @@ async def test_handle_message_with_session_start(
"start": 6,
"end": 42,
"value": "post-session start hello",
"extractor": "RegexMessageHandler",
}
],
),
SlotSet(entity, slot_2[entity]),
DefinePrevUserUtteredFeaturization(False),
ActionExecuted("utter_greet"),
ActionExecuted(
"utter_greet", policy="AugmentedMemoizationPolicy", confidence=1.0
),
BotUttered(
"hey there post-session start hello!",
data={},
metadata={"utter_action": "utter_greet"},
),
ActionExecuted(ACTION_LISTEN_NAME),
Expand Down

0 comments on commit 59a46a6

Please sign in to comment.