
wip llm steering #12501

Merged
merged 58 commits on Jun 19, 2023
Commits
e330f08
implemented flows using a proper stack and state machine model
tmbo Jun 7, 2023
07d4a01
Apply suggestions from code review
tmbo Jun 8, 2023
9cfd048
added interrupt and return
tmbo Jun 8, 2023
27f7b32
Merge branch 'flows-with-proper-stack' of github.com:RasaHQ/rasa into…
tmbo Jun 8, 2023
8fa0916
review improvements
tmbo Jun 9, 2023
ccd2ab3
review comments
tmbo Jun 9, 2023
765df4c
chat: improved chat
tmbo Jun 9, 2023
8622588
fixed style
tmbo Jun 11, 2023
b818ea2
properly access slot value when checking for scope
tmbo Jun 13, 2023
ec9c64c
fixed style
tmbo Jun 13, 2023
4fecf71
added default flow continue interrupted utterance
tmbo Jun 13, 2023
660384f
wip llm steering
twerkmeister Jun 12, 2023
673cc02
some docs
tmbo Jun 13, 2023
4762a76
run rule within flow
tmbo Jun 13, 2023
55cf52e
Update rasa/core/actions/flows.py
tmbo Jun 14, 2023
cf35201
merged dm2
tmbo Jun 14, 2023
cf7f942
Merge branch 'flows-with-proper-stack' of github.com:RasaHQ/rasa into…
tmbo Jun 14, 2023
3cde5b0
merged updated dm2
tmbo Jun 14, 2023
e73ed79
added some docs
tmbo Jun 14, 2023
25ee2e0
Merge branch 'dm2' into flows-with-proper-stack
tmbo Jun 14, 2023
3f80558
updated docs
tmbo Jun 14, 2023
3ef8f8f
some improvements
tmbo Jun 15, 2023
9395c7d
made llm flow classifier an nlu component
twerkmeister Jun 15, 2023
77d0d84
merged flow stack
tmbo Jun 15, 2023
85d1c97
Llms docs updates (#12511)
m-vdb Jun 15, 2023
2619a4e
readded to separate sidebar
tmbo Jun 15, 2023
2edd973
remove hacky ToC from LLM docs pages
m-vdb Jun 15, 2023
688c227
added finding earlier question
tmbo Jun 15, 2023
4e845e3
added intentless docs
tmbo Jun 15, 2023
1eef053
added intentless docs
tmbo Jun 15, 2023
08fdd17
docs improvements
tmbo Jun 15, 2023
b703980
further intentless improvements
tmbo Jun 15, 2023
3d492b6
made some edits to LLM rephrasing docs
amn41 Jun 16, 2023
308fd0b
updates to LLM intent classifier docs
amn41 Jun 16, 2023
1807815
updates to intentless policy docs
amn41 Jun 16, 2023
74c0cf6
added a few extra heuristics to catch some common mistakes/hallucinat…
twerkmeister Jun 16, 2023
95dcab1
Merge branch 'flows-with-proper-stack' into ENG-339-llm-flow
tmbo Jun 16, 2023
d52a6ab
Merge branch 'ENG-339-llm-flow' of github.com:RasaHQ/rasa into ENG-33…
tmbo Jun 16, 2023
31543a5
made tracker optional in the component interface
twerkmeister Jun 16, 2023
58c6aad
added going back to previous question on refill
tmbo Jun 16, 2023
5275d3d
fixed circular import
twerkmeister Jun 16, 2023
c76a319
refined classifier and prompt
twerkmeister Jun 17, 2023
b224ea0
added default flows
tmbo Jun 17, 2023
2c4892a
Merge branch 'ENG-339-llm-flow' of github.com:RasaHQ/rasa into ENG-33…
tmbo Jun 17, 2023
5884139
fixed form handling in flow when filling slots
tmbo Jun 18, 2023
e646d3b
unified component style
tmbo Jun 18, 2023
2d56ce5
fixed linting issues
tmbo Jun 18, 2023
8871d76
trying to fix tests
tmbo Jun 18, 2023
4b52a83
merged dm2
tmbo Jun 18, 2023
cc7cbad
added jinja template to project include
tmbo Jun 18, 2023
32f6726
fxied lint
tmbo Jun 18, 2023
47aec3b
fixed types
tmbo Jun 18, 2023
9f3c099
remove unused disambiguation action
tmbo Jun 18, 2023
6152960
removed unused constant
tmbo Jun 18, 2023
e57a0f3
removed unused import
tmbo Jun 18, 2023
d456ee0
Merge branch 'dm2' into ENG-339-llm-flow
tmbo Jun 19, 2023
087508f
added todo for correcting prior questions
tmbo Jun 19, 2023
3ab1539
Merge branch 'dm2' into ENG-339-llm-flow
tmbo Jun 19, 2023
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -18,7 +18,7 @@ repository = "https://github.com/rasahq/rasa"
documentation = "https://rasa.com/docs"
classifiers = [ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Topic :: Software Development :: Libraries",]
keywords = [ "nlp", "machine-learning", "machine-learning-library", "bot", "bots", "botkit", "rasa conversational-agents", "conversational-ai", "chatbot", "chatbot-framework", "bot-framework",]
include = [ "LICENSE.txt", "README.md", "rasa/shared/core/training_data/visualization.html", "rasa/cli/default_config.yml", "rasa/shared/importers/*", "rasa/utils/schemas/*", "rasa/keys", "rasa/core/channels/chat.html", "rasa/core/policies/detectors/prompt_sensitive_topic.jinja2"]
include = [ "LICENSE.txt", "README.md", "rasa/shared/core/training_data/visualization.html", "rasa/cli/default_config.yml", "rasa/shared/importers/*", "rasa/utils/schemas/*", "rasa/keys", "rasa/core/channels/chat.html", "rasa/core/policies/detectors/prompt_sensitive_topic.jinja2", "rasa/nlu/classifiers/flow_prompt_template.jinja2"]
readme = "README.md"
license = "Apache-2.0"
[[tool.poetry.source]]
26 changes: 24 additions & 2 deletions rasa/core/actions/action.py
@@ -98,7 +98,6 @@
def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["Action"]:
"""List default actions."""
from rasa.core.actions.two_stage_fallback import TwoStageFallbackAction
from rasa.core.actions.flows import ActionFlowContinueInterupted

return [
ActionListen(),
@@ -114,7 +113,6 @@ def default_actions(action_endpoint: Optional[EndpointConfig] = None) -> List["A
ActionSendText(),
ActionBack(),
ActionExtractSlots(action_endpoint),
ActionFlowContinueInterupted(),
]


@@ -314,6 +312,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Simple run implementation uttering a (hopefully defined) response."""
response_ids_for_response = domain.response_ids_per_response.get(
@@ -372,6 +371,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action (see parent class for full docstring)."""
message = {"text": self.action_text}
@@ -460,6 +460,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Query the appropriate response and create a bot utterance with that."""
latest_message = tracker.latest_message
@@ -520,6 +521,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
# only utter the response if it is available
@@ -545,6 +547,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
return []
@@ -570,6 +573,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
# only utter the response if it is available
@@ -606,6 +610,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
_events: List[Event] = [SessionStarted()]
@@ -635,6 +640,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
# only utter the response if it is available
@@ -655,6 +661,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
return [ActiveLoop(None), SlotSet(REQUESTED_SLOT, None)]
@@ -769,6 +776,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
json_body = self._action_call_format(tracker, domain)
@@ -895,6 +903,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
from rasa.core.policies.two_stage_fallback import has_user_rephrased
@@ -925,6 +934,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
return []
@@ -1005,6 +1015,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
latest_message = tracker.latest_message
@@ -1273,6 +1284,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Runs action. Please see parent class for the full docstring."""
slot_events: List[Event] = []
@@ -1354,6 +1366,16 @@ def extract_slot_value_from_predefined_mapping(
tracker: "DialogueStateTracker",
) -> List[Any]:
"""Extracts slot value if slot has an applicable predefined mapping."""

if tracker.has_bot_message_after_latest_user_message():
# TODO: this needs further validation - not sure if this breaks something!!!

# If the bot sent a message after the user sent a message, we can't
# extract any slots from the user message. We assume that the user
# message was already processed by the bot and the slot value was
# already extracted (e.g. for a prior form slot).
return []

should_fill_entity_slot = (
mapping_type == SlotMappingType.FROM_ENTITY
and SlotMapping.entity_is_desired(mapping, tracker)
54 changes: 13 additions & 41 deletions rasa/core/actions/flows.py
@@ -5,10 +5,7 @@
from rasa.core.policies.flow_policy import FlowStack, FlowStackFrame, StackFrameType
from rasa.shared.constants import FLOW_PREFIX

from rasa.shared.core.constants import (
ACTION_FLOW_CONTINUE_INERRUPTED_NAME,
FLOW_STACK_SLOT,
)
from rasa.shared.core.constants import FLOW_STACK_SLOT
from rasa.shared.core.domain import Domain
from rasa.shared.core.events import (
ActiveLoop,
@@ -47,10 +44,13 @@ async def run(
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Trigger the flow."""

stack = FlowStack.from_tracker(tracker)
if tracker.active_loop_name and not stack.is_empty():
frame_type = StackFrameType.INTERRUPT
elif self._flow_name == "pattern_continue_interrupted":
frame_type = StackFrameType.RESUME
elif self._flow_name == "pattern_correction":
frame_type = StackFrameType.CORRECTION
else:
frame_type = StackFrameType.REGULAR

@@ -61,43 +61,15 @@
)
)

events: List[Event] = [SlotSet(FLOW_STACK_SLOT, stack.as_dict())]
slots_to_be_set = metadata.get("slots", {}) if metadata else {}
slot_set_events: List[Event] = [
SlotSet(key, value) for key, value in slots_to_be_set.items()
]

events: List[Event] = [
SlotSet(FLOW_STACK_SLOT, stack.as_dict())
] + slot_set_events
if tracker.active_loop_name:
events.append(ActiveLoop(None))

return events


UTTER_FLOW_CONTINUE_INTERRUPTED = "utter_flow_continue_interrupted"


class ActionFlowContinueInterupted(action.Action):
"""Action triggered when an interrupted flow is continued."""

def name(self) -> Text:
"""Return the flow name."""
return ACTION_FLOW_CONTINUE_INERRUPTED_NAME

async def run(
self,
output_channel: "OutputChannel",
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
"""Trigger the flow."""

fallback = {"text": "Let's return to the previous topic."}
flow_name = metadata.get("flow_name") if metadata else None

generated = await nlg.generate(
UTTER_FLOW_CONTINUE_INTERRUPTED,
tracker,
output_channel.name(),
flow_name=flow_name,
)

utterance: Event = action.create_bot_utterance(generated or fallback)

return [utterance]
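The updated `ActionTriggerFlow.run` in `flows.py` does two things: it picks a stack-frame type from the tracker state and the flow name, and it turns any slots passed via action metadata into `SlotSet` events. The branching can be sketched like this (a simplified stand-in with assumed plain types, not the actual Rasa classes or events):

```python
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple


class StackFrameType(Enum):
    INTERRUPT = "interrupt"
    RESUME = "resume"
    CORRECTION = "correction"
    REGULAR = "regular"


def select_frame_type(
    active_loop_name: Optional[str], stack_is_empty: bool, flow_name: str
) -> StackFrameType:
    # Mirrors the if/elif chain in ActionTriggerFlow.run: an active loop
    # on a non-empty stack means the new flow interrupts the current one;
    # the two meta patterns map to their dedicated frame types.
    if active_loop_name and not stack_is_empty:
        return StackFrameType.INTERRUPT
    if flow_name == "pattern_continue_interrupted":
        return StackFrameType.RESUME
    if flow_name == "pattern_correction":
        return StackFrameType.CORRECTION
    return StackFrameType.REGULAR


def slot_set_events(metadata: Optional[Dict[str, Any]]) -> List[Tuple[str, Any]]:
    # Slots passed under metadata["slots"] become SlotSet events
    # (represented here as plain tuples instead of Rasa Event objects).
    slots = metadata.get("slots", {}) if metadata else {}
    return [(key, value) for key, value in slots.items()]
```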
3 changes: 2 additions & 1 deletion rasa/core/actions/loops.py
@@ -1,5 +1,5 @@
from abc import ABC
from typing import List, TYPE_CHECKING
from typing import Any, Dict, List, TYPE_CHECKING, Optional, Text

from rasa.core.actions.action import Action
from rasa.shared.core.events import Event, ActiveLoop
Expand All @@ -18,6 +18,7 @@ async def run(
nlg: "NaturalLanguageGenerator",
tracker: "DialogueStateTracker",
domain: "Domain",
metadata: Optional[Dict[Text, Any]] = None,
) -> List[Event]:
events: List[Event] = []

68 changes: 68 additions & 0 deletions rasa/core/policies/default_flows.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
responses:
utter_flow_continue_interrupted:
- text: Let's continue with the topic {rasa_previous_flow}.
metadata:
rephrase: True

utter_ask_confirm_correction:
- text: "Do you want to update your information?"
buttons:
- payload: /affirm
title: Yes
- payload: /deny
title: No, please keep the previous information
metadata:
rephrase: True

utter_corrected_previous_input:
- text: "Ok, I corrected the previous input."
metadata:
rephrase: True

utter_not_corrected_previous_input:
- text: "Ok, I did not correct the previous input."
metadata:
rephrase: True

slots:
confirm_correction:
type: text
mappings:
- intent: affirm
type: from_intent
value: "True"
conditions:
- active_loop: question_confirm_correction
- intent: deny
type: from_intent
value: "False"
conditions:
- active_loop: question_confirm_correction

flows:
pattern_continue_interrupted:
description: A meta flow that should be started to continue an interrupted flow.

steps:
- id: "0"
action: utter_flow_continue_interrupted

pattern_correction:
description: A meta flow that should be started to correct a previous user input.

steps:
- id: "0"
question: confirm_correction
next:
- if: confirm_correction
then: "1"
- else: "2"
- id: "1"
action: utter_corrected_previous_input
- id: "2"
set_slots:
- rasa_corrected_slots: None
next: "3"
- id: "3"
action: utter_not_corrected_previous_input
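The `pattern_correction` flow above branches on the `confirm_correction` slot via a `next` field that is either a plain step id or a list of `if`/`else` entries. A minimal resolver for those semantics might look like this (an assumed reading of the YAML schema, not Rasa's actual flow engine):

```python
from typing import Any, Dict, List, Optional, Union


def resolve_next(
    step_next: Union[str, List[Dict[str, str]]], slots: Dict[str, Any]
) -> Optional[str]:
    """Resolve a flow step's `next` field (sketch of the assumed semantics).

    `step_next` is either a plain step id, or a list of
    {"if": <slot name>, "then": <step id>} entries terminated by
    {"else": <step id>}. Conditions are taken as slot truthiness here.
    """
    if isinstance(step_next, str):
        return step_next
    for branch in step_next:
        if "if" in branch and slots.get(branch["if"]):
            return branch["then"]
        if "else" in branch:
            return branch["else"]
    return None
```

With the YAML above, an affirming user (slot set to `"True"`) lands on step `"1"` (`utter_corrected_previous_input`); otherwise the flow falls through to step `"2"`, which clears `rasa_corrected_slots`.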

8 changes: 5 additions & 3 deletions rasa/core/policies/detectors/sensitive_topic.py
@@ -43,7 +43,8 @@ def check(self, user_msg: Text) -> bool:
user_msg: user message to check

Returns:
True if the message contains sensitive topic, False otherwise"""
True if the message contains sensitive topic, False otherwise
"""
...

def action(self) -> Text:
@@ -104,7 +105,7 @@ def check(self, user_msg: Optional[Text]) -> bool:
if self._use_stub:
return self._stub.check(user_msg)
try:
resp = openai.Completion.create(
resp = openai.Completion.create( # type: ignore[no-untyped-call]
model=self._model_name,
prompt=self._make_prompt(user_msg),
temperature=0.0,
@@ -125,7 +126,8 @@ def _parse_response(text: Text) -> bool:
def _parse_response(text: Text) -> bool:
"""Parse response from OpenAI ChatGPT model.

Expected responses are "YES" and "NO" (case-insensitive)."""
Expected responses are "YES" and "NO" (case-insensitive).
"""
return "YES" in text.upper()


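The sensitive-topic detector prompts the completion model to answer "YES" or "NO" and matches the reply case-insensitively, as `_parse_response` shows. A standalone sketch of that parsing rule (function name chosen here for illustration):

```python
def parse_sensitive_topic_response(text: str) -> bool:
    # Mirrors the detector's _parse_response: any reply containing
    # "YES" (case-insensitive) counts as a sensitive-topic match.
    return "YES" in text.upper()
```

The substring match is deliberately lenient, so verbose model replies such as "YES, this is sensitive" still register as positives.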