Skip to content

Commit

Permalink
Update categorical node counter to use label
Browse files Browse the repository at this point in the history
  • Loading branch information
Sidharth Sundar committed Jun 15, 2022
1 parent f7eec67 commit 897a4e4
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions adam/learner/contrastive_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,10 @@ class TeachingContrastiveObjectLearner(ContrastiveLearner):
_ontology_node_present_in_difference: Counter[Tuple[Concept, OntologyNode]] = attrib(
validator=instance_of(Counter), factory=Counter, init=False
)
_categorical_values_present: Counter[Tuple[Concept, str]] = attrib(
_categorical_values_present: Counter[Tuple[Concept, str, str]] = attrib(
validator=instance_of(Counter), factory=Counter, init=False
)
_categorical_values_present_in_difference: Counter[Tuple[Concept, str]] = attrib(
_categorical_values_present_in_difference: Counter[Tuple[Concept, str, str]] = attrib(
validator=instance_of(Counter), factory=Counter, init=False
)

Expand Down Expand Up @@ -284,7 +284,9 @@ def _update_counts(
] += 1
# Otherwise, if it's categorical, count the value observed
elif isinstance(pattern_node, CategoricalPredicate):
self._categorical_values_present[concept, pattern_node.value] += 1
self._categorical_values_present[
concept, pattern_node.label, pattern_node.value
] += 1

if (
pattern_to_graph_match.pattern_node_to_matched_graph_node[
Expand All @@ -293,7 +295,7 @@ def _update_counts(
in difference_nodes
):
self._categorical_values_present_in_difference[
concept, pattern_node.value
concept, pattern_node.label, pattern_node.value
] += 1

def _propose_updated_hypothesis_to_apprentice(self, concept: Concept) -> None:
Expand Down Expand Up @@ -360,10 +362,10 @@ def _calculate_weight_for(self, concept: Concept, node: NodePredicate) -> float:
# Otherwise, if it's categorical, count the value observed
elif isinstance(node, CategoricalPredicate):
present_count = self._categorical_values_present.get(
(concept, node.value), 1
(concept, node.label, node.value), 1
)
in_difference_count = self._categorical_values_present_in_difference.get(
(concept, node.value), 1
(concept, node.label, node.value), 1
)

return 0.5 + in_difference_count / present_count
Expand Down

0 comments on commit 897a4e4

Please sign in to comment.