Skip to content

Commit

Permalink
Update variable names
Browse files Browse the repository at this point in the history
With pattern matching and hypothesis updates, variables are now altered to indicate that there are multiple hypotheses being matched against/proposed for updates. Also requires updating function parameter in abstract hypothesis update function.
  • Loading branch information
Sidharth Sundar committed Jun 15, 2022
1 parent f8feceb commit f7eec67
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 54 deletions.
6 changes: 5 additions & 1 deletion adam/learner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,17 @@ def concepts_to_patterns(self) -> Dict[Concept, PerceptionGraphPattern]:
@abstractmethod
def propose_updated_hypotheses(
self,
concept_to_updated_patterns: Dict[
concept_to_hypothesis_updates: Dict[
Concept, Dict[PerceptionGraphTemplate, PerceptionGraphPattern]
],
) -> None:
"""
Propose new/updated hypotheses to the learner.
This expects hypothesis updates to be given as mappings of the form `old_hypothesis ->
new_hypothesis` where `old_hypothesis` is a `PerceptionGraphTemplate` and `new_hypothesis`
is a `PerceptionGraphPattern`.
The learner may do with these as it will.
"""

Expand Down
91 changes: 43 additions & 48 deletions adam/learner/contrastive_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,7 @@ def learn_from(self, matching: LanguagePerceptionSemanticContrast) -> None:

def _match_concept_pattern_to_multiple_graphs(
self, concept: Concept, graph: PerceptionGraph
) -> Optional[
Mapping[PerceptionGraphTemplate, Optional[PerceptionGraphPatternMatch]]
]:
) -> Mapping[PerceptionGraphTemplate, Optional[PerceptionGraphPatternMatch]]:
return {
apprentice_concept: apprentice_concept.graph_pattern.matcher(
graph,
Expand Down Expand Up @@ -255,57 +253,54 @@ def _update_counts(
difference_nodes: ImmutableSet[PerceptionGraphNode],
) -> None:
for pattern_to_graph_match in pattern_to_graph_matches:
if pattern_to_graph_match:
for pattern_node in pattern_to_graph_match.matched_pattern:
# If it's a stroke GNN recognition node, count the "recognized object" string observed
if isinstance(pattern_node, StrokeGNNRecognitionPredicate):
self._stroke_gnn_recognition_nodes_present[
for pattern_node in pattern_to_graph_match.matched_pattern:
# If it's a stroke GNN recognition node, count the "recognized object" string observed
if isinstance(pattern_node, StrokeGNNRecognitionPredicate):
self._stroke_gnn_recognition_nodes_present[
concept, pattern_node.recognized_object
] += 1

if (
pattern_to_graph_match.pattern_node_to_matched_graph_node[
pattern_node
]
in difference_nodes
):
self._stroke_gnn_recognition_nodes_present_in_difference[
concept, pattern_node.recognized_object
] += 1

if (
pattern_to_graph_match.pattern_node_to_matched_graph_node[
pattern_node
]
in difference_nodes
):
self._stroke_gnn_recognition_nodes_present_in_difference[
concept, pattern_node.recognized_object
] += 1
# If it has an ontology node, do X
elif isinstance(pattern_node, IsOntologyNodePredicate):
self._ontology_node_present[
# If it has an ontology node, count the observed property value
elif isinstance(pattern_node, IsOntologyNodePredicate):
self._ontology_node_present[concept, pattern_node.property_value] += 1

if (
pattern_to_graph_match.pattern_node_to_matched_graph_node[
pattern_node
]
in difference_nodes
):
self._ontology_node_present_in_difference[
concept, pattern_node.property_value
] += 1

if (
pattern_to_graph_match.pattern_node_to_matched_graph_node[
pattern_node
]
in difference_nodes
):
self._ontology_node_present_in_difference[
concept, pattern_node.property_value
] += 1
# Otherwise, if it's categorical, count the value observed
elif isinstance(pattern_node, CategoricalPredicate):
self._categorical_values_present[concept, pattern_node.value] += 1

if (
pattern_to_graph_match.pattern_node_to_matched_graph_node[
pattern_node
]
in difference_nodes
):
self._categorical_values_present_in_difference[
concept, pattern_node.value
] += 1
# Otherwise, if it's categorical, count the value observed
elif isinstance(pattern_node, CategoricalPredicate):
self._categorical_values_present[concept, pattern_node.value] += 1

if (
pattern_to_graph_match.pattern_node_to_matched_graph_node[
pattern_node
]
in difference_nodes
):
self._categorical_values_present_in_difference[
concept, pattern_node.value
] += 1

def _propose_updated_hypothesis_to_apprentice(self, concept: Concept) -> None:
patterns = self.apprentice.concept_to_hypotheses(concept, self._top_n)
hypotheses = self.apprentice.concept_to_hypotheses(concept, self._top_n)
updated_hypotheses = dict()
for pattern_template in patterns:
pattern = pattern_template.graph_pattern
for hypothesis in hypotheses:
pattern = hypothesis.graph_pattern
old_to_new_node = {}
old_pattern_digraph = pattern.copy_as_digraph()
new_pattern_digraph = DiGraph()
Expand All @@ -319,7 +314,7 @@ def _propose_updated_hypothesis_to_apprentice(self, concept: Concept) -> None:
old_to_new_node[old_u], old_to_new_node[old_v], **data
)

updated_hypotheses[pattern_template] = PerceptionGraphPattern(
updated_hypotheses[hypothesis] = PerceptionGraphPattern(
new_pattern_digraph, dynamic=pattern.dynamic
)
self.apprentice.propose_updated_hypotheses({concept: updated_hypotheses})
Expand Down
12 changes: 7 additions & 5 deletions adam/learner/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,15 +274,15 @@ def get_concepts(self) -> ImmutableSet[Concept]:

def propose_updated_hypotheses(
self,
concept_to_updated_patterns: Dict[
concept_to_hypothesis_updates: Dict[
Concept,
Dict[PerceptionGraphTemplate, PerceptionGraphPattern],
],
) -> None:
for concept, pattern_updates in concept_to_updated_patterns.items():
for concept, hypothesis_update in concept_to_hypothesis_updates.items():
if concept in self.concept_to_surface_template:
if len(self._concept_to_hypotheses[concept]) >= 1:
for hypothesis, pattern_update in pattern_updates.items():
for hypothesis, pattern_update in hypothesis_update.items():
if hypothesis not in self._concept_to_hypotheses[concept]:
raise ValueError(
f"Target hypothesis to update {hypothesis} not present among preexisting hypotheses"
Expand Down Expand Up @@ -312,8 +312,10 @@ def propose_updated_hypotheses(

self._concept_to_hypotheses[concept] = immutableset(
[
evolve(hypothesis, graph_pattern=pattern_updates[hypothesis])
if hypothesis in pattern_updates
evolve(
hypothesis, graph_pattern=hypothesis_update[hypothesis]
)
if hypothesis in hypothesis_update
else hypothesis
for hypothesis in self._concept_to_hypotheses[concept]
]
Expand Down

0 comments on commit f7eec67

Please sign in to comment.