diff --git a/state_formatters/dp_formatters.py b/state_formatters/dp_formatters.py index 4fbf6dc1eb..16017f19fc 100755 --- a/state_formatters/dp_formatters.py +++ b/state_formatters/dp_formatters.py @@ -1,6 +1,7 @@ import logging from copy import deepcopy -from typing import Dict, List, Any, Tuple, Optional, Union +from itertools import zip_longest +from typing import Dict, List, Any, Union from common.utils import get_entities, get_intents import state_formatters.utils as utils @@ -25,83 +26,68 @@ def preprocess_dialog( return dialog -def get_history(dialog): - history = [] - prev_human_utterance = None - for utt in dialog["utterances"]: - if utt["user"]["user_type"] == "human": - prev_human_utterance = utt["annotations"].get("spelling_preprocessing", utt["text"]) - elif utt["user"]["user_type"] == "bot" and utt["active_skill"] == "eliza" and prev_human_utterance is not None: - history.append(prev_human_utterance) - return history - - -def get_utterance_histories(dialog): - return [[utt["text"] for utt in dialog["utterances"]]] - - -def get_annotation_histories(dialog): - return [[deepcopy(utt.get("annotations")) for utt in dialog["utterances"]]] - - -def get_ongoing_utterances(dialog): - return [utils.count_ongoing_skill_utterances(dialog["bot_utterances"], "convert_reddit")] +def get_annotation( + dialog: Dict, + annotation_type: str, + default_result: Any = None, + last_n_utts: int = 1, + utterance_type: str = "human_utterances" +) -> List[Dict]: + return dialog[utterance_type][-last_n_utts]["annotations"].get(annotation_type, default_result) -def get_human_attributes(dialog): - return [dialog["human"]["attributes"]] +def get_history(dialog): + return [utt["annotations"].get("spelling_preprocessing", utt["text"]) for utt in dialog["utterances"] + if utt["user"]["user_type"] == "bot" and utt["active_skill"] == "eliza"] -def get_text(dialog): - return dialog["human_utterances"][-1]["annotations"].get( - "spelling_preprocessing", dialog["human_utterances"][-1]["text"]) +def get_utterances_attribute( + dialog: Dict, + utterance_type: str, + attribute: str = None, + sub_attribute: str = None, + last_n_utts: int = 0 +) -> List: + dialog_slice = dialog[utterance_type][-last_n_utts:] if last_n_utts > 0 else dialog[utterance_type] + if attribute is None: + return dialog_slice -def get_speeches(dialog: Dict) -> Dict: - return dialog["human_utterances"][-1].get("attributes", {}).get("speech", {}) + if attribute == 'attributes' and utterance_type == 'human': + return [dialog_slice[attribute]] + if sub_attribute is None: + return [utt.get(attribute, "") for utt in dialog_slice] -def get_human_utterances(dialog: Dict) -> List[Dict]: - return dialog["human_utterances"][-3:] + return [utt.get(attribute, {}).get(sub_attribute, "") for utt in dialog_slice] -def get_dialog_history(dialog: Dict) -> List[str]: - return [uttr["text"] for uttr in dialog["utterances"][-2:]] +def get_ongoing_utterances(dialog): + return [utils.count_ongoing_skill_utterances(dialog["bot_utterances"], "convert_reddit")] -def get_entities_with_labels(dialog: Dict) -> Any: # replace Any with the actual return type of get_entities +def get_entities_with_labels(dialog: Dict) -> Union[List[Dict[str, str]], List]: return get_entities(dialog["human_utterances"][-1], only_named=False, with_labels=True) -def get_entity_info(dialog: Dict) -> List[Dict]: - return dialog["human_utterances"][-1]["annotations"].get("entity_linking", [{}]) - - -def get_named_entities(dialog: Dict) -> List[Dict]: - return dialog["human_utterances"][-1]["annotations"].get("ner", [{}]) - - def get_tokenized_sentences(dialog: Dict) -> List[List[str]]: - tokens = dialog["human_utterances"][-1]["annotations"].get("spacy_annotator", []) + tokens = get_annotation(dialog, annotation_type="spacy_annotator", default_result=[], + last_n_utts=1, utterance_type="human_utterance") tokens = [token["text"] for token in tokens] return [tokens] if len(tokens) else None def get_sentences_with_history(dialog: Dict) -> List[str]: - last_human_utt = get_text(dialog)[0] - if dialog["bot_utterances"]: - # h sep b sep h sep b sep h - prev_bot_utts = [k["text"] for k in dialog["bot_utterances"][-2:]] - prev_human_utts = [ - utt["annotations"].get("spelling_preprocessing", utt["text"]) for utt in dialog["human_utterances"][-3:-1] - ] - prev_utts = [] - for human_utt, bot_utt in zip(prev_human_utts, prev_bot_utts): - prev_utts.append(human_utt) - prev_utts.append(bot_utt) - sentence_w_history = " [SEP] ".join(prev_utts + [last_human_utt]) - else: - sentence_w_history = last_human_utt + # get the two most recent bot and human utterances, and the last human utterance + last_human_utt = get_utterances_attribute(dialog, 'human_utterances', 'text', last_n_utts=1)[0] + prev_bot_utts = get_utterances_attribute(dialog, 'bot_utterances', 'text', last_n_utts=2) + prev_human_utts = get_utterances_attribute(dialog, 'human_utterances', 'annotations', + 'spelling_preprocessing', last_n_utts=3) + + # join the utterances with a separator, starting with the older utterances + utterances = [utt for pair in zip_longest(prev_human_utts, prev_bot_utts, fillvalue='') for utt in pair if utt] + sentence_w_history = ' [SEP] '.join(utterances + [last_human_utt]) + return [sentence_w_history] @@ -124,21 +110,13 @@ def get_utterances_with_histories(dialog: Dict) -> List[List[str]]: def get_active_skills(dialog: Dict): - active_skills = [utt.get("active_skill", "") for utt in dialog["utterances"]] - active_skills = [skill for skill in active_skills if skill] - return [active_skills] + active_skills = get_utterances_attribute(dialog, utterance_type="utterance", attribute="active_skill") + return [[skill for skill in active_skills if skill]] -def get_cobot_topics(dialog: Dict): - topics = [] - for utt in dialog["utterances"]: - topics += utt.get("annotations", {}).get("cobot_topics", {}).get("text", []) - return [topics] - - -def get_hypotheses(dialog: Dict): - hypots = [h["text"] for h in dialog["human_utterances"][-1]["hypotheses"]] - return hypots +def get_cobot_topics(dialog: Dict) -> List[List[str]]: + return [[topic for utt in dialog["utterances"] + for topic in utt.get("annotations", {}).get("cobot_topics", {}).get("text", [])]] def get_contexts(dialog: Dict): @@ -152,13 +130,13 @@ def get_midas_preparation(dialog: Dict): return [max(midas_dist, key=midas_dist.get)] -def unified_formatter( +def dream_formatter( dialog: Dict, result_keys: List, service_name: str = "", - last_n_utts: int = None, preprocess: bool = False, - preprocess_params: Dict = None + preprocess_params: Dict = None, + additional_params: Dict = None ) -> List: """ Parameters @@ -166,9 +144,9 @@ def unified_formatter( service_name: name of the service dialog: full dialog state result_keys: list of keys in result dialog - last_n_utts: how many last user utterances to take - preprocess: preprocess dialog, - preprocess_params: preprocess_params + preprocess: preprocess dialog + preprocess_params: parameters for preprocessing + additional_params: additional parameters for dialog processing Returns ------- @@ -179,17 +157,17 @@ def unified_formatter( keys_table = { "speeches": get_speeches, - "human_utterances": get_human_utterances, + "human_utterances": get_utterances_attribute, "last_utterance": get_text, "last_utternace_batch": get_text, "human_utterance_history_batch": get_history, - "personality": lambda dialog: dialog["bot"]["persona"] if "convert" == "convert" else get_text(dialog), + "personality": lambda dialog: dialog["bot"]["persona"] if service_name == "convert" else get_text(dialog), "states_batch": lambda dialog: dialog, - "utterances_histories": get_utterance_histories, + "utterances_histories": get_utterances_attribute, "annotation_histories": get_annotation_histories, "sentences": get_text, - "contexts": get_utterance_histories, - "utterances": get_dialog_history, + "contexts": get_utterances_attribute, + "utterances": get_utterances_attribute, "entities_with_labels": get_entities_with_labels, "named_entities": get_named_entities, "entity_info": get_entity_info, @@ -200,10 +178,11 @@ def unified_formatter( "dialog_context": get_contexts, "hypotheses": get_hypotheses, "last_midas_labels": get_midas_preparation, - "return_probas": lambda dialog: 1 + "return_probas": lambda dialog: 1, + "dialogs": lambda dialog: dialog } - formatted_dialog = {key: keys_table[key](dialog) for key in result_keys} + formatted_dialog = {key: keys_table[key](dialog, **additional_params) for key in result_keys} if formatted_dialog.get("tokenized_sentences") is None: del formatted_dialog["tokenized_sentences"] @@ -213,7 +192,7 @@ def unified_formatter( def eliza_formatter_dialog(dialog: Dict) -> List[Dict]: # Used by: eliza_formatter - return unified_formatter( + return dream_formatter( service_name="eliza", dialog=dialog, result_keys=["last_utterance_batch", "human_utterance_history_batch"], preprocess=False @@ -247,13 +226,15 @@ def cobot_asr_formatter_service(payload: List): def base_skill_selector_formatter_dialog(dialog: Dict) -> List[Dict]: - return unified_formatter(preprocess=True, preprocess_params={"bot_last_turns": 5, "mode": "punct_sent"}, - dialog=dialog, result_keys=["states_batch"]) + return dream_formatter( + dialog, result_keys=["states_batch"], + preprocess=True, preprocess_params={"bot_last_turns": 5, "mode": "punct_sent"} + ) def convert_formatter_dialog(dialog: Dict) -> List[Dict]: # Used by: convert - return unified_formatter( + return dream_formatter( dialog, service_name="convert", result_keys=["utterances_histories", "personality", "num_ongoing_utt", "human_attributes"], preprocess=True, preprocess_params={ @@ -267,19 +248,19 @@ def convert_formatter_dialog(dialog: Dict) -> List[Dict]: def personality_catcher_formatter_dialog(dialog: Dict) -> List[Dict]: # Used by: personality_catcher_formatter - return unified_formatter(dialog, result_keys=["personality"]) + return dream_formatter(dialog, result_keys=["personality"]) def sent_rewrite_formatter_dialog(dialog: Dict) -> List[Dict]: # Used by: sent_rewrite_formatter - return unified_formatter( + return dream_formatter( dialog, result_keys=["utterances_histories", "annotation_histories"], preprocess=True, preprocess_params={"bot_last_turns": utils.LAST_N_TURNS} ) def sent_rewrite_formatter_w_o_last_dialog(dialog: Dict) -> List[Dict]: - return unified_formatter( + return dream_formatter( dialog, result_keys=["utterances_histories", "annotation_histories"], preprocess=True, preprocess_params={"bot_last_turns": utils.LAST_N_TURNS + 1} ) @@ -287,7 +268,7 @@ def sent_rewrite_formatter_w_o_last_dialog(dialog: Dict) -> List[Dict]: def cobot_formatter_dialog(dialog: Dict): # Used by: cobot_dialogact_formatter, cobot_classifiers_formatter - return unified_formatter( + return dream_formatter( dialog, result_keys=["utterances_histories", "annotation_histories"], preprocess=True, preprocess_params={"bot_last_turns": utils.LAST_N_TURNS} ) @@ -318,7 +299,7 @@ def base_response_selector_formatter_service(payload: List): def asr_formatter_dialog(dialog: Dict) -> List[Dict]: # Used by: asr_formatter - return unified_formatter( + return dream_formatter( dialog=dialog, result_keys=["speeches", "human_utterances"], preprocess=False, @@ -328,26 +309,25 @@ def asr_formatter_dialog(dialog: Dict) -> List[Dict]: def last_utt_dialog(dialog: Dict) -> List[Dict]: # Used by: dp_toxic_formatter, sent_segm_formatter, tfidf_formatter, sentiment_classification - return unified_formatter(dialog, result_keys=["sentences"]) + return dream_formatter(dialog, result_keys=["sentences"]) def preproc_last_human_utt_dialog(dialog: Dict) -> List[Dict]: # Used by: sentseg over human uttrs - return unified_formatter(dialog, result_keys=["speeches"], service_name="sentseg") + return dream_formatter(dialog, result_keys=["speeches"], service_name="sentseg") def entity_detection_formatter_dialog(dialog: Dict) -> List[Dict]: - return unified_formatter( + return dream_formatter( dialog, result_keys=["sentences"], preprocess=True, preprocess_params={"mode": "punct_sent", "remove_clarification": False, "replace_utterances": False} ) def property_extraction_formatter_dialog(dialog: Dict) -> List[Dict]: - return unified_formatter( + return dream_formatter( dialog=dialog, result_keys=["utterances", "entities_with_labels", "named_entities", "entity_info"], - last_n_utts=2, preprocess=True, preprocess_params={ "mode": "punct_sent", "bot_last_turns": 1, "remove_clarification": False, "replace_utterances": True @@ -357,7 +337,7 @@ def property_extraction_formatter_dialog(dialog: Dict) -> List[Dict]: def preproc_last_human_utt_dialog_w_hist(dialog: Dict) -> List[Dict]: # Used by: sentseg over human uttrs - return unified_formatter( + return dream_formatter( dialog=dialog, result_keys=["sentences", "sentences_with_history"], service_name="preproc_last_human_utt_dialog_w_hist", @@ -366,7 +346,7 @@ def preproc_last_human_utt_dialog_w_hist(dialog: Dict) -> List[Dict]: def preproc_and_tokenized_last_human_utt_dialog(dialog: Dict) -> List[Dict]: # Used by: sentseg over human uttrs - return unified_formatter(dialog=dialog, result_keys=["sentences", "tokenized_sentences"]) + return dream_formatter(dialog=dialog, result_keys=["sentences", "tokenized_sentences"]) def last_bot_utt_dialog(dialog: Dict) -> List[Dict]: @@ -396,10 +376,9 @@ def hypotheses_list_last_uttr(dialog: Dict) -> List[Dict]: def hypothesis_histories_list(dialog: Dict): - return unified_formatter( + return dream_formatter( dialog=dialog, result_keys=["utterances_with_histories"], - last_n_utts=1, preprocess=True, preprocess_params={ "mode": "segments", @@ -411,7 +390,7 @@ def hypothesis_histories_list(dialog: Dict): def last_utt_and_history_dialog(dialog: Dict) -> List: # Used by: topicalchat retrieval skills - return unified_formatter( + return dream_formatter( dialog, result_keys=["sentences", "utterances_histories"], preprocess=True, @@ -511,8 +490,17 @@ def persona_bot_formatter(dialog: Dict): ] +def cropped_dialog(dialog: Dict): + return dream_formatter(dialog, result_keys=["dialogs"], preprocess=True, preprocess_params={ + "mode": "punct_sent", + "bot_last_turns": None, + "remove_clarification": True, + "replace_utterances": True + }) + + def full_dialog(dialog: Dict): - return [{"dialogs": [dialog]}] + return dream_formatter(dialog, result_keys=["dialogs"], preprocess=False) def fetch_active_skills(bot_utterances: List[Dict]): @@ -530,8 +518,8 @@ def sentrewrite_dialog_formatter(dialog: Dict, bot_last_turns: Any, mode: str, a "mode": mode, "bot_last_turns": bot_last_turns, "remove_clarification": True, "replace_utterances": True - } - ) + } + ) if active_skills: return [{"dialogs": [dialog], "all_prev_active_skills": [all_prev_active_skills]}] @@ -555,6 +543,27 @@ def base_skill_formatter(payload: Dict): return [{"text": payload[0], "confidence": payload[1]}] +def generate_hypothesis(payload: List) -> Dict: + """ + Helper function to generate a single hypothesis from the payload. + """ + hypothesis = { + "text": payload[0], + "confidence": payload[1] + } + + if len(payload) >= 4: + hypothesis["human_attributes"] = payload[2] + hypothesis["bot_attributes"] = payload[3] + + if len(payload) == 3 or len(payload) == 5: + attributes = payload[-1] + assert isinstance(attributes, dict), "Attribute is a dictionary" + hypothesis.update(attributes) + + return hypothesis + + def skill_with_attributes_formatter_service(payload: List): """ Formatter should use `"state_manager_method": "add_hypothesis"` in config!!! @@ -570,39 +579,16 @@ def skill_with_attributes_formatter_service(payload: List): **attributes}, by ^ marked optional elements """ - # Used by: book_skill_formatter, skill_with_attributes_formatter, news_skill, meta_script_skill, dummy_skill - # deal with text & confidences + result = [] if isinstance(payload[0], list) and isinstance(payload[1], list): # several hypotheses from this skill - result = [] for hyp in zip(*payload): - if len(hyp[0]) > 0 and hyp[1] > 0.0: - full_hyp = {"text": hyp[0], "confidence": hyp[1]} - if len(payload) >= 4: - # have human and bot attributes in hyps - full_hyp["human_attributes"] = hyp[2] - full_hyp["bot_attributes"] = hyp[3] - if len(payload) == 3 or len(payload) == 5: - # have also attributes in hyps - assert isinstance(hyp[-1], dict), "Attribute is a dictionary" - for key in hyp[-1]: - full_hyp[key] = hyp[-1][key] - result += [full_hyp] + if hyp[0] and hyp[1] > 0.0: + result.append(generate_hypothesis(hyp)) else: - # only one hypotheses from this skill - if len(payload[0]) > 0 and payload[1] > 0.0: - result = [{"text": payload[0], "confidence": payload[1]}] - if len(payload) >= 4: - # have human and bot attributes in hyps - result[0]["human_attributes"] = payload[2] - result[0]["bot_attributes"] = payload[3] - if len(payload) == 3 or len(payload) == 5: - # have also attributes in hyps - assert isinstance(payload[-1], dict), "Attribute is a dictionary" - for key in payload[-1]: - result[0][key] = payload[-1][key] - else: - result = [] + # only one hypothesis from this skill + if payload[0] and payload[1] > 0.0: + result.append(generate_hypothesis(payload)) return result @@ -954,7 +940,7 @@ def dff_prompted_skill_formatter(dialog, skill_name=None): ) -def dff_universal_prompted_skill_formatter(dialog, skill_name=None): +def dff_universal_prompted_skill_formatter(dialog): return utils.dff_formatter( dialog, "dff_universal_prompted_skill", @@ -1058,7 +1044,7 @@ def hypothesis_scorer_formatter(dialog: Dict) -> List[Dict]: def topic_recommendation_formatter(dialog: Dict): - return unified_formatter( + return dream_formatter( dialog, result_keys=["active_skills", "cobot_topics"], preprocess=True, @@ -1067,11 +1053,11 @@ def topic_recommendation_formatter(dialog: Dict): def hypotheses_with_context_list(dialog: Dict) -> List[Dict]: - return unified_formatter(dialog, result_keys=["dialog_contexts", "hypotheses"]) + return dream_formatter(dialog, result_keys=["dialog_contexts", "hypotheses"]) def context_formatter_dialog(dialog: Dict) -> List[Dict]: - return unified_formatter( + return dream_formatter( dialog, result_keys=["contexts"], preprocess=True, preprocess_params={ "mode": "punct_sent", @@ -1083,7 +1069,7 @@ def context_formatter_dialog(dialog: Dict) -> List[Dict]: def midas_predictor_formatter(dialog: Dict): - return unified_formatter(dialog, result_keys=["last_midas_labels", "return_probas"]) + return dream_formatter(dialog, result_keys=["last_midas_labels", "return_probas"]) def image_captioning_formatter(dialog: Dict) -> List[Dict]: