Skip to content

Commit

Permalink
Added support for going back to the previous question on slot refill
Browse files Browse the repository at this point in the history
  • Loading branch information
tmbo committed Jun 16, 2023
1 parent 31543a5 commit 58c6aad
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 42 deletions.
100 changes: 71 additions & 29 deletions rasa/core/policies/flow_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
)
from pypred import Predicate
from rasa.core.policies.rule_policy import RulePolicy
from rasa.nlu.classifiers.llm_flow_classifier import CORRECTION_INTENT
from rasa.shared.constants import FLOW_PREFIX
from rasa.shared.nlu.constants import ENTITY_ATTRIBUTE_TYPE, INTENT_NAME_KEY
from rasa.shared.core.constants import (
ACTION_FLOW_CONTINUE_INERRUPTED_NAME,
ACTION_LISTEN_NAME,
FLOW_STACK_SLOT,
)
from rasa.shared.core.events import ActiveLoop, Event, SlotSet
from rasa.shared.core.events import ActiveLoop, Event, SlotSet, UserUttered
from rasa.shared.core.flows.flow import (
END_STEP,
START_STEP,
Expand Down Expand Up @@ -50,8 +51,9 @@
DialogueStateTracker,
)
from rasa.core.policies.detectors import SensitiveTopicDetector
import structlog

logger = logging.getLogger(__name__)
structlogger = structlog.get_logger()

SENSITIVE_TOPIC_DETECTOR_CONFIG_KEY = "sensitive_topic_detector"

Expand Down Expand Up @@ -112,7 +114,9 @@ def __init__(
# if the detector is configured, we need to load it
full_config = SensitiveTopicDetector.get_default_config()
full_config.update(detector_config)
self._sensitive_topic_detector = SensitiveTopicDetector(full_config)
self._sensitive_topic_detector: Optional[
SensitiveTopicDetector
] = SensitiveTopicDetector(full_config)
else:
self._sensitive_topic_detector = None

Expand Down Expand Up @@ -188,29 +192,19 @@ def predict_action_probabilities(
# sure that the input isn't used in any following flow
# steps. At the same time, we can't completely skip flows
# as we want to guide the user to the next step of the flow.
logger.info(
"Sensitive topic detected, predicting action %s", predicted_action
structlogger.info(
"sensitive.topic.detected", predicted_action=predicted_action
)
else:
logger.info("No sensitive topic detected: %s", latest_message.text)
structlogger.debug(
"sensitive.topic.notdetected", message=latest_message.text
)

# if detector predicted an action, we don't want to predict a flow
if predicted_action is not None:
return self._create_prediction_result(predicted_action, domain, 1.0, [])

executor = FlowExecutor.from_tracker(tracker, flows or FlowsList([]))
if tracker.active_loop:
# we are in a loop - likely answering a question - we need to check
# if the user responded with a trigger intent for another flow rather
# than answering the question
prediction = executor.consider_flow_switch(tracker)
return self._create_prediction_result(
action_name=prediction.action_name,
domain=domain,
score=prediction.score,
events=[],
action_metadata=prediction.metadata,
)

# create executor and predict next action
prediction = executor.advance_flows(tracker, domain)
Expand Down Expand Up @@ -341,7 +335,7 @@ def top_flow_step(self, flows: FlowsList) -> Optional[FlowStep]:
if not (top := self.top()) or not (top_flow := self.top_flow(flows)):
return None

return top_flow.step_for_id(top.step_id)
return top_flow.step_by_id(top.step_id)

def is_empty(self) -> bool:
"""Checks if the stack is empty.
Expand Down Expand Up @@ -620,14 +614,14 @@ def _is_step_completed(
else:
return True

def _find_updated_question(
self, current_step: FlowStep, flow: Flow, updated_slot_name: Text
def _find_earliest_updated_question(
self, current_step: FlowStep, flow: Flow, updated_slots: List[SlotSet]
) -> Optional[FlowStep]:
"""Find the question that was updated."""
asked_question_steps = flow.previously_asked_questions(current_step.id)

for question_step in asked_question_steps:
if question_step.question == updated_slot_name:
for question_step in reversed(asked_question_steps):
if question_step.question in {s.key for s in updated_slots}:
return question_step
return None

Expand All @@ -642,10 +636,10 @@ def consider_flow_switch(self, tracker: DialogueStateTracker) -> ActionPredictio
if new_flow := self.find_startable_flow(tracker):
# there are flows available, but we are not in a flow
# it looks like we can start a flow, so we'll predict the trigger action
logger.debug(f"Found startable flow: {new_flow.id}")
structlogger.debug("flow.startable", flow_id=new_flow.id)
return ActionPrediction(FLOW_PREFIX + new_flow.id, 1.0)
else:
logger.debug("No startable flow found.")
structlogger.debug("flow.nostartable")
return ActionPrediction(None, 0.0)

def advance_flows(
Expand All @@ -667,10 +661,11 @@ def advance_flows(
if prediction.action_name:
# if a flow can be started, we'll start it
return prediction
if self.flow_stack.is_empty():
if not (top_flow := self.flow_stack.top()):
# if there are no flows, there is nothing to do
return ActionPrediction(None, 0.0)
else:
self._correct_flow_position(top_flow, tracker)
prediction = self._select_next_action(tracker, domain)
if FlowStack.from_tracker(tracker).as_dict() != self.flow_stack.as_dict():
# we need to update the flow stack to persist the state of the executor
Expand All @@ -684,6 +679,52 @@ def advance_flows(
)
return prediction

def _slot_sets_after_latest_message(
    self, tracker: DialogueStateTracker
) -> List[SlotSet]:
    """Collect the `SlotSet` events applied after the most recent user message.

    Walks the applied events newest-first and gathers slot sets until the
    latest `UserUttered` event is reached.
    """
    if not tracker.latest_message:
        # no user message yet — there is nothing "after" it
        return []

    collected: List[SlotSet] = []
    for event in reversed(tracker.applied_events()):
        if isinstance(event, UserUttered):
            # reached the latest user message; older events are irrelevant
            break
        if isinstance(event, SlotSet):
            collected.append(event)
    return collected

def _correct_flow_position(
    self, flow_stack_frame: FlowStackFrame, tracker: DialogueStateTracker
) -> None:
    """Rewind the top flow if the user corrected an already-answered question.

    If the latest user message was classified with the correction intent and
    it updated slots belonging to questions the flow already asked, advance
    the top flow back to the earliest such question so it is re-processed.

    Args:
        flow_stack_frame: the frame of the flow currently on top of the stack.
        tracker: the dialogue state tracker.
    """
    # only consider correcting right after listening for user input
    if tracker.latest_action_name != ACTION_LISTEN_NAME:
        return

    # guard clause: only messages classified as corrections trigger a rewind
    if not (
        tracker.latest_message
        and tracker.latest_message.intent.get("name") == CORRECTION_INTENT
    ):
        return

    newly_set_slots = self._slot_sets_after_latest_message(tracker)

    if not (flow := self.all_flows.flow_by_id(flow_stack_frame.flow_id)):
        # the flow on the stack no longer exists — nothing to correct
        return

    if not (step := flow.step_by_id(flow_stack_frame.step_id)):
        return

    reset_point = self._find_earliest_updated_question(step, flow, newly_set_slots)

    if reset_point:
        structlogger.info(
            "flow.reset.slotupdate",
            stack_frame=flow_stack_frame,
            reset_point=reset_point.id,
        )
        # move the top flow back so the corrected question is handled again
        self.flow_stack.advance_top_flow(reset_point.id)

def _select_next_action(
self,
tracker: DialogueStateTracker,
Expand Down Expand Up @@ -722,7 +763,7 @@ def _select_next_action(
"to __start__ if it ended it should be popped from the stack."
)

logger.info(previous_step)
structlogger.debug("flow.nextAction.step", previous_step)
predicted_action = self._wrap_up_previous_step(
current_flow, previous_step, tracker
)
Expand Down Expand Up @@ -821,11 +862,12 @@ def _run_step(
if isinstance(step, QuestionFlowStep):
slot = tracker.slots.get(step.question, None)
initial_value = slot.initial_value if slot else None
if step.skip_if_filled and slot.value != initial_value:
slot_value = slot.value if slot else None
if step.skip_if_filled and slot_value != initial_value:
return ActionPrediction(None, 0.0)

question_action = ActionPrediction("question_" + step.question, 1.0)
if slot.value != initial_value:
if slot_value != initial_value:
question_action.events = [SlotSet(step.question, initial_value)]
return question_action

Expand Down
9 changes: 6 additions & 3 deletions rasa/nlu/classifiers/llm_flow_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
importlib.resources.read_text("rasa.nlu.classifiers", "flow_prompt_template.jinja2")
)

CORRECTION_INTENT = "correction"

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -212,7 +214,7 @@ def parse_action_list(
and top_flow_step.question != slot_sets[0][0]
and slot_sets[0][0] in slots_so_far
):
return "correction", slot_sets
return CORRECTION_INTENT, slot_sets
elif (
len(slot_sets) == 1
and top_flow_step is not None
Expand Down Expand Up @@ -248,8 +250,9 @@ def create_template_inputs(cls, flows: FlowsList) -> List[Dict[str, str]]:
{
"name": flow.id,
"description": flow.description,
"slots": flow.slots()
})
"slots": flow.slots(),
}
)
return result

@classmethod
Expand Down
18 changes: 11 additions & 7 deletions rasa/shared/core/flows/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def step_by_id(self, step_id: Text, flow_id: Text) -> FlowStep:
if not flow:
raise UnresolvedFlowException(flow_id)

step = flow.step_for_id(step_id)
step = flow.step_by_id(step_id)
if not step:
raise UnresolvedFlowStepIdException(step_id, flow, referenced_from=None)

Expand Down Expand Up @@ -240,7 +240,7 @@ def _reachable_steps(
reached_steps.add(step.id)
for link in step.next.links:
reached_steps = _reachable_steps(
self.step_for_id(link.target), reached_steps
self.step_by_id(link.target), reached_steps
)
return reached_steps

Expand All @@ -250,7 +250,7 @@ def _reachable_steps(
if step.id not in reached_steps:
raise UnreachableFlowStepException(step, self)

def step_for_id(self, step_id: Optional[Text]) -> Optional[FlowStep]:
def step_by_id(self, step_id: Optional[Text]) -> Optional[FlowStep]:
"""Returns the step with the given id."""
if not step_id:
return None
Expand Down Expand Up @@ -310,20 +310,24 @@ def _previously_asked_questions(
"""Returns the questions asked before the given step.
Keeps track of the steps that have been visited to avoid circles."""
current_step = self.step_for_id(current_step_id)
current_step = self.step_by_id(current_step_id)

questions = []

if isinstance(current_step, QuestionFlowStep):
questions.append(current_step.question)
questions.append(current_step)

visited_steps.add(current_step)
visited_steps.add(current_step.id)

for previous_step in self.steps:
for next_link in previous_step.next.links:
if next_link.target != current_step_id:
continue
questions.extend(self._previously_asked_questions(previous_step.id))
if previous_step.id in visited_steps:
continue
questions.extend(
_previously_asked_questions(previous_step.id, visited_steps)
)
return questions

return _previously_asked_questions(step_id, set())
Expand Down
8 changes: 5 additions & 3 deletions rasa/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
logger = logging.getLogger(__name__)


def generate_text_openai_chat(
    prompt: str, model: str = "gpt-3.5-turbo"
) -> Optional[str]:
    """Generate a completion for a single-turn chat prompt via the OpenAI API.

    Args:
        prompt: the user message to send to the model.
        model: the OpenAI chat model to use.

    Returns:
        The generated completion text.
    """
    # temperature=0.0 requests maximally deterministic completions
    chat_completion = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
    )
    return chat_completion.choices[0].message.content


Expand Down

0 comments on commit 58c6aad

Please sign in to comment.