diff --git a/docs/docs/sources/rasa_interactive___help.txt b/docs/docs/sources/rasa_interactive___help.txt index 3ea3e8a28beb..4ab5ff98aca7 100644 --- a/docs/docs/sources/rasa_interactive___help.txt +++ b/docs/docs/sources/rasa_interactive___help.txt @@ -39,7 +39,7 @@ options: --conversation-id CONVERSATION_ID Specify the id of the conversation the messages are in. Defaults to a UUID that will be randomly - generated. (default: de04d0f298734aeabe093213937197d3) + generated. (default: d1e32dd965814ad3be8eb506bf4e281c) --endpoints ENDPOINTS Configuration file for the model server and the connectors as a yml file. (default: endpoints.yml) diff --git a/docs/docs/sources/rasa_shell___help.txt b/docs/docs/sources/rasa_shell___help.txt index e2104badea2d..622459875832 100644 --- a/docs/docs/sources/rasa_shell___help.txt +++ b/docs/docs/sources/rasa_shell___help.txt @@ -30,7 +30,7 @@ options: -h, --help show this help message and exit --conversation-id CONVERSATION_ID Set the conversation ID. (default: - 1fdbcdaf73d348a7aea93a95065c1262) + 2e227e84f7224337975b25bbe79bd32a) -m MODEL, --model MODEL Path to a trained Rasa model. If a directory is specified, it will use the latest model in this diff --git a/rasa/shared/core/training_data/visualization.py b/rasa/shared/core/training_data/visualization.py index 21176c67b4b7..3e7049cbc9fb 100644 --- a/rasa/shared/core/training_data/visualization.py +++ b/rasa/shared/core/training_data/visualization.py @@ -533,17 +533,17 @@ def _remove_auxiliary_nodes( graph.remove_node(TMP_NODE_ID) - if not len(list(graph.predecessors(END_NODE_ID))): + if not graph.predecessors(END_NODE_ID): graph.remove_node(END_NODE_ID) # remove duplicated "..." nodes after merging - ps = set() + predecessors_seen = set() for i in range(special_node_idx + 1, TMP_NODE_ID): - for pred in list(graph.predecessors(i)): - if pred in ps: + predecessors = graph.predecessors(i) + for pred in predecessors: + if pred in predecessors_seen: graph.remove_node(i) - else: - ps.add(pred) + predecessors_seen.update(predecessors) def visualize_stories( diff --git a/tests/shared/core/training_data/test_visualization.py b/tests/shared/core/training_data/test_visualization.py index 3bb4f4f4ad27..63604b5ef354 100644 --- a/tests/shared/core/training_data/test_visualization.py +++ b/tests/shared/core/training_data/test_visualization.py @@ -10,6 +10,8 @@ from rasa.shared.nlu.training_data.message import Message from rasa.shared.nlu.training_data.training_data import TrainingData +import pytest + def test_style_transfer(): r = visualization._transfer_style({"class": "dashed great"}, {"class": "myclass"}) @@ -188,3 +190,42 @@ def test_story_visualization_with_merging(domain: Domain): assert 15 < len(generated_graph.nodes()) < 33 assert 20 < len(generated_graph.edges()) < 33 + + +@pytest.mark.parametrize( + "input_nodes, input_edges, remove_count, expected_nodes, expected_edges", + [ + ( + [-2, -1, 0, 1, 2, 3, 4, 5], + [(-2, 0), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], + 3, + set([0, 1, 2, 3, 4, 5, -1]), + [(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], + ), + ( + [-3, -2, -1, 0, 1, 2, 3, 4, 5], + [(-3, -2), (-2, -1), (-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], + 4, + set([-3, -1, 0, 1, 2, 3, 4, 5]), + [(-1, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], + ), + ], +) +def test_remove_auxiliary_nodes( + input_nodes, input_edges, remove_count, expected_nodes, expected_edges +): + import networkx as nx + + # Create a sample graph + graph = nx.MultiDiGraph() + graph.add_nodes_from(input_nodes) + graph.add_edges_from(input_edges) + + # Call the method to remove auxiliary nodes + visualization._remove_auxiliary_nodes(graph, remove_count) + + # Check if the expected nodes are removed + assert set(graph.nodes()) == expected_nodes, "Nodes mismatch" + + # Check if the edges are updated correctly + assert list(graph.edges()) == expected_edges, "Edges mismatch"