From f7eec67315fa476a00199769d8fca8192b051db1 Mon Sep 17 00:00:00 2001 From: Sidharth Sundar Date: Tue, 14 Jun 2022 11:50:01 -0400 Subject: [PATCH] Update variable names With pattern matching and hypothesis updates, variables are now altered to indicate that there are multiple hypotheses being matched against/proposed for updates. Also requires updating function parameter in abstract hypothesis update function. --- adam/learner/__init__.py | 6 +- adam/learner/contrastive_learner.py | 91 ++++++++++++++--------------- adam/learner/subset.py | 12 ++-- 3 files changed, 55 insertions(+), 54 deletions(-) diff --git a/adam/learner/__init__.py b/adam/learner/__init__.py index 8919c45ab..4d3f681ad 100644 --- a/adam/learner/__init__.py +++ b/adam/learner/__init__.py @@ -325,13 +325,17 @@ def concepts_to_patterns(self) -> Dict[Concept, PerceptionGraphPattern]: @abstractmethod def propose_updated_hypotheses( self, - concept_to_updated_patterns: Dict[ + concept_to_hypothesis_updates: Dict[ Concept, Dict[PerceptionGraphTemplate, PerceptionGraphPattern] ], ) -> None: """ Propose new/updated hypotheses to the learner. + This expects hypothesis updates to be given as mappings of the form `old_hypothesis -> + new_hypothesis` where `old_hypothesis` is a `PerceptionGraphTemplate` and `new_hypothesis` + is a `PerceptionGraphPattern`. + The learner may do with these as it will. """ diff --git a/adam/learner/contrastive_learner.py b/adam/learner/contrastive_learner.py index 32292e5af..58d10928b 100644 --- a/adam/learner/contrastive_learner.py +++ b/adam/learner/contrastive_learner.py @@ -192,9 +192,7 @@ def learn_from(self, matching: LanguagePerceptionSemanticContrast) -> None: def _match_concept_pattern_to_multiple_graphs( self, concept: Concept, graph: PerceptionGraph - ) -> Optional[ - Mapping[PerceptionGraphTemplate, Optional[PerceptionGraphPatternMatch]] - ]: + ) -> Mapping[PerceptionGraphTemplate, Optional[PerceptionGraphPatternMatch]]: return { apprentice_concept: apprentice_concept.graph_pattern.matcher( graph, @@ -255,57 +253,54 @@ def _update_counts( difference_nodes: ImmutableSet[PerceptionGraphNode], ) -> None: for pattern_to_graph_match in pattern_to_graph_matches: - if pattern_to_graph_match: - for pattern_node in pattern_to_graph_match.matched_pattern: - # If it's a stroke GNN recognition node, count the "recognized object" string observed - if isinstance(pattern_node, StrokeGNNRecognitionPredicate): - self._stroke_gnn_recognition_nodes_present[ + for pattern_node in pattern_to_graph_match.matched_pattern: + # If it's a stroke GNN recognition node, count the "recognized object" string observed + if isinstance(pattern_node, StrokeGNNRecognitionPredicate): + self._stroke_gnn_recognition_nodes_present[ + concept, pattern_node.recognized_object + ] += 1 + + if ( + pattern_to_graph_match.pattern_node_to_matched_graph_node[ + pattern_node + ] + in difference_nodes + ): + self._stroke_gnn_recognition_nodes_present_in_difference[ concept, pattern_node.recognized_object ] += 1 - - if ( - pattern_to_graph_match.pattern_node_to_matched_graph_node[ - pattern_node - ] - in difference_nodes - ): - self._stroke_gnn_recognition_nodes_present_in_difference[ - concept, pattern_node.recognized_object - ] += 1 - # If it has an ontology node, do X - elif isinstance(pattern_node, IsOntologyNodePredicate): - self._ontology_node_present[ + # If it has an ontology node, count the observed property value + elif isinstance(pattern_node, IsOntologyNodePredicate): + self._ontology_node_present[concept, pattern_node.property_value] += 1 + + if ( + pattern_to_graph_match.pattern_node_to_matched_graph_node[ + pattern_node + ] + in difference_nodes + ): + self._ontology_node_present_in_difference[ concept, pattern_node.property_value ] += 1 - - if ( - pattern_to_graph_match.pattern_node_to_matched_graph_node[ - pattern_node - ] - in difference_nodes - ): - self._ontology_node_present_in_difference[ - concept, pattern_node.property_value - ] += 1 - # Otherwise, if it's categorical, count the value observed - elif isinstance(pattern_node, CategoricalPredicate): - self._categorical_values_present[concept, pattern_node.value] += 1 - - if ( - pattern_to_graph_match.pattern_node_to_matched_graph_node[ - pattern_node - ] - in difference_nodes - ): - self._categorical_values_present_in_difference[ - concept, pattern_node.value - ] += 1 + # Otherwise, if it's categorical, count the value observed + elif isinstance(pattern_node, CategoricalPredicate): + self._categorical_values_present[concept, pattern_node.value] += 1 + + if ( + pattern_to_graph_match.pattern_node_to_matched_graph_node[ + pattern_node + ] + in difference_nodes + ): + self._categorical_values_present_in_difference[ + concept, pattern_node.value + ] += 1 def _propose_updated_hypothesis_to_apprentice(self, concept: Concept) -> None: - patterns = self.apprentice.concept_to_hypotheses(concept, self._top_n) + hypotheses = self.apprentice.concept_to_hypotheses(concept, self._top_n) updated_hypotheses = dict() - for pattern_template in patterns: - pattern = pattern_template.graph_pattern + for hypothesis in hypotheses: + pattern = hypothesis.graph_pattern old_to_new_node = {} old_pattern_digraph = pattern.copy_as_digraph() new_pattern_digraph = DiGraph() @@ -319,7 +314,7 @@ def _propose_updated_hypothesis_to_apprentice(self, concept: Concept) -> None: old_to_new_node[old_u], old_to_new_node[old_v], **data ) - updated_hypotheses[pattern_template] = PerceptionGraphPattern( + updated_hypotheses[hypothesis] = PerceptionGraphPattern( new_pattern_digraph, dynamic=pattern.dynamic ) self.apprentice.propose_updated_hypotheses({concept: updated_hypotheses}) diff --git a/adam/learner/subset.py b/adam/learner/subset.py index bcd095f8d..7d6922756 100644 --- a/adam/learner/subset.py +++ b/adam/learner/subset.py @@ -274,15 +274,15 @@ def get_concepts(self) -> ImmutableSet[Concept]: def propose_updated_hypotheses( self, - concept_to_updated_patterns: Dict[ + concept_to_hypothesis_updates: Dict[ Concept, Dict[PerceptionGraphTemplate, PerceptionGraphPattern], ], ) -> None: - for concept, pattern_updates in concept_to_updated_patterns.items(): + for concept, hypothesis_update in concept_to_hypothesis_updates.items(): if concept in self.concept_to_surface_template: if len(self._concept_to_hypotheses[concept]) >= 1: - for hypothesis, pattern_update in pattern_updates.items(): + for hypothesis, pattern_update in hypothesis_update.items(): if hypothesis not in self._concept_to_hypotheses[concept]: raise ValueError( f"Target hypothesis to update {hypothesis} not present among preexisting hypotheses" @@ -312,8 +312,10 @@ def propose_updated_hypotheses( self._concept_to_hypotheses[concept] = immutableset( [ - evolve(hypothesis, graph_pattern=pattern_updates[hypothesis]) - if hypothesis in pattern_updates + evolve( + hypothesis, graph_pattern=hypothesis_update[hypothesis] + ) + if hypothesis in hypothesis_update else hypothesis for hypothesis in self._concept_to_hypotheses[concept] ]