Skip to content

Commit ca3f93d

Browse files
mkoistinen and jackx111 authored
22911: Renames Feature Influence detail keys and removes feature influence matrices, MAJOR (#361)
- Renames the supported keys
- Adds deprecation support for old key usage (both in and out)
- Adds test to ensure deprecated keys, if used, return the correct details
- Removes the methods trainee.get_mda_matrix and trainee.get_contribution_matrix
- Removes the utility method get_matrix_diff

NOTE: This won't pass tests until howsoai/howso-engine#435 is merged.

---------

Co-authored-by: jack-xia-dp <[email protected]>
1 parent d3767bd commit ca3f93d

File tree

11 files changed

+735
-763
lines changed

11 files changed

+735
-763
lines changed

howso/client/base.py

+226-153
Large diffs are not rendered by default.

howso/client/schemas/reaction.py

+51-15
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import numpy as np
1010
import pandas as pd
1111

12+
from howso.utilities.constants import _RENAMED_DETAIL_KEYS # type: ignore reportPrivateUsage
13+
14+
1215
__all__ = [
1316
"Reaction"
1417
]
@@ -44,24 +47,57 @@ class Reaction(abc.MutableMapping):
4447

4548
SPECIAL_KEYS = {"action_features", }
4649
KNOWN_KEYS = {
47-
"case_directional_feature_contributions_full", "case_directional_feature_contributions_robust",
48-
"directional_feature_contributions_full", "directional_feature_contributions_robust",
49-
"boundary_cases_familiarity_convictions", "boundary_cases",
50-
"boundary_values", "feature_case_contributions_full",
51-
"feature_case_contributions_robust", "case_feature_contributions_full",
52-
"case_feature_contributions_robust", "case_feature_residuals_robust", "case_feature_residuals_full",
53-
"case_mda_full", "case_mda_robust", "categorical_action_probabilities", "context_values",
54-
"derivation_parameters", "distance_contribution", "distance_ratio_parts", "distance_ratio",
55-
"feature_contributions_full", "feature_contributions_robust", "feature_deviations",
56-
"feature_mda_ex_post_full", "feature_mda_ex_post_robust", "feature_mda_robust", "feature_mda_full",
57-
"feature_residuals_full", "feature_residuals_robust", "generate_attempts", "series_generate_attempts",
58-
"hypothetical_values", "influential_cases_familiarity_convictions", "influential_cases_raw_weights",
59-
"influential_cases", "case_feature_residual_convictions_full",
60-
"case_feature_residual_convictions_robust", "most_similar_case_indices", "most_similar_cases",
61-
"observational_errors", "prediction_stats", "outlying_feature_values", "robust_influences",
50+
"boundary_cases",
51+
"boundary_cases_familiarity_convictions",
52+
"boundary_values",
53+
"case_full_accuracy_contributions",
54+
"case_full_prediction_contributions",
55+
"case_robust_accuracy_contributions",
56+
"case_robust_prediction_contributions",
57+
"categorical_action_probabilities",
58+
"context_values",
59+
"derivation_parameters",
60+
"distance_contribution",
61+
"distance_ratio_parts",
62+
"distance_ratio",
63+
"feature_deviations",
64+
"feature_full_accuracy_contributions_ex_post",
65+
"feature_full_accuracy_contributions",
66+
"feature_full_directional_prediction_contributions",
67+
"feature_full_directional_prediction_contributions_for_case",
68+
"feature_full_prediction_contributions_for_case",
69+
"feature_full_prediction_contributions",
70+
"feature_full_residual_convictions_for_case",
71+
"feature_full_residuals_for_case",
72+
"feature_full_residuals",
73+
"feature_robust_accuracy_contributions_ex_post",
74+
"feature_robust_accuracy_contributions",
75+
"feature_robust_directional_prediction_contributions",
76+
"feature_robust_directional_prediction_contributions_for_case",
77+
"feature_robust_prediction_contributions_for_case",
78+
"feature_robust_prediction_contributions",
79+
"feature_robust_residual_convictions_for_case",
80+
"feature_robust_residuals_for_case",
81+
"feature_robust_residuals",
82+
"generate_attempts",
83+
"hypothetical_values",
84+
"influential_cases_familiarity_convictions",
85+
"influential_cases_raw_weights",
86+
"influential_cases",
87+
"most_similar_case_indices",
88+
"most_similar_cases",
89+
"observational_errors",
90+
"outlying_feature_values",
91+
"prediction_stats",
92+
"robust_influences",
93+
"series_generate_attempts",
6294
"similarity_conviction",
6395
}
6496

97+
# These detail keys are deprecated, but should be treated as KNOWN_KEYs
98+
# during the deprecation period.
99+
KNOWN_KEYS |= set(_RENAMED_DETAIL_KEYS.keys())
100+
65101
def __init__(self,
66102
action: t.Optional[pd.DataFrame | list | dict] = None,
67103
details: t.Optional[abc.MutableMapping[str, t.Any]] = None

howso/client/tests/test_client.py

+116-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Mapping
12
import importlib.metadata
23
import json
34
import os
@@ -13,12 +14,18 @@
1314

1415
import howso
1516
from howso.client import HowsoClient
16-
from howso.client.client import _check_isfile, get_configuration_path, get_howso_client_class, LEGACY_CONFIG_FILENAMES
17+
from howso.client.client import (
18+
_check_isfile, # type: ignore reportPrivateUsage
19+
get_configuration_path,
20+
get_howso_client_class,
21+
LEGACY_CONFIG_FILENAMES
22+
)
1723
from howso.client.exceptions import HowsoApiError, HowsoConfigurationError, HowsoError
1824
from howso.client.protocols import ProjectClient
1925
from howso.client.schemas.reaction import Reaction
2026
from howso.direct import HowsoDirectClient
2127
from howso.utilities.testing import get_configurationless_test_client, get_test_options
28+
from howso.utilities.constants import _RENAMED_DETAIL_KEYS, _RENAMED_DETAIL_KEYS_EXTRA # type: ignore reportPrivateUsage
2229

2330
TEST_OPTIONS = get_test_options()
2431

@@ -577,8 +584,8 @@ def test_a_la_cart_data(self, trainee):
577584
['similarity_conviction', ]
578585
),
579586
(
580-
{'feature_residuals_robust': True, },
581-
['feature_residuals_robust', ]
587+
{'feature_robust_residuals': True, },
588+
['feature_robust_residuals', ]
582589
),
583590
]
584591
for audit_detail_set, keys_to_expect in details_sets:
@@ -590,6 +597,112 @@ def test_a_la_cart_data(self, trainee):
590597
details = response['details']
591598
assert (all(details[key] is not None for key in keys_to_expect))
592599

600+
@pytest.mark.parametrize('old_key,new_key', _RENAMED_DETAIL_KEYS.items())
601+
def test_deprecated_detail_keys_react(self, trainee, old_key, new_key):
602+
"""Ensure using any of the deprecated keys raises a warning, but continues to work."""
603+
# These keys shouldn't be tested like this:
604+
if new_key in [
605+
"feature_full_directional_prediction_contributions",
606+
"feature_robust_directional_prediction_contributions",
607+
"feature_full_accuracy_contributions_permutation",
608+
"feature_robust_accuracy_contributions_permutation",
609+
]:
610+
return True
611+
612+
with pytest.warns(DeprecationWarning) as record:
613+
self.client.train(
614+
trainee.id, [[1, 2], [1, 2], [1, 2]],
615+
features=['penguin', 'play']
616+
)
617+
reaction = self.client.react(
618+
trainee.id,
619+
contexts=[['1']],
620+
context_features=['penguin'],
621+
action_features=['play'],
622+
details={old_key: True}
623+
)
624+
625+
# Check that the correct warning was raised.
626+
assert len(record)
627+
# There may be multiple warnings. Ensure at least one of them contains
628+
# the deprecation message.
629+
assert any([
630+
f"'{old_key}' is deprecated" in str(r.message)
631+
for r in record
632+
])
633+
634+
# We DO want the old_key to be present during the deprecation period.
635+
assert old_key in reaction.get('details', {}).keys()
636+
637+
# We do NOT want the new_key present during the deprecation period.
638+
assert new_key not in reaction.get('details', {}).keys()
639+
640+
# Some keys request multiple keys to be returned, these too should be
641+
# converted to the old names if the old name was originally used.
642+
if old_key in _RENAMED_DETAIL_KEYS_EXTRA.keys():
643+
for old_extra_key, new_extra_key in _RENAMED_DETAIL_KEYS_EXTRA[old_key]["additional_keys"].items():
644+
assert new_extra_key not in reaction.get('details', {}).keys()
645+
assert old_extra_key in reaction.get('details', {}).keys()
646+
647+
@pytest.mark.parametrize('old_key,new_key', _RENAMED_DETAIL_KEYS.items())
648+
def test_deprecated_detail_keys_react_aggregate(self, trainee, old_key, new_key):
649+
"""Ensure using any of the deprecated keys raises a warning, but continues to work."""
650+
# These keys shouldn't be tested like this:
651+
if new_key in {
652+
"case_full_prediction_contributions",
653+
"case_robust_prediction_contributions",
654+
"feature_full_prediction_contributions_for_case",
655+
"feature_robust_prediction_contributions_for_case",
656+
"feature_full_residual_convictions_for_case",
657+
"feature_robust_residual_convictions_for_case",
658+
"feature_full_residuals_for_case",
659+
"feature_robust_residuals_for_case",
660+
"case_full_accuracy_contributions",
661+
"case_robust_accuracy_contributions",
662+
"feature_full_directional_prediction_contributions",
663+
"feature_robust_directional_prediction_contributions",
664+
"feature_full_accuracy_contributions_ex_post",
665+
"feature_robust_accuracy_contributions_ex_post",
666+
}:
667+
return
668+
669+
with pytest.warns(DeprecationWarning) as record:
670+
self.client.train(
671+
trainee.id, [[1, 2], [1, 2], [1, 2]],
672+
features=['penguin', 'play']
673+
)
674+
response = self.client.react_aggregate(
675+
trainee.id,
676+
action_feature='penguin',
677+
num_samples=1,
678+
details={old_key: True}
679+
)
680+
681+
# Check that the correct warning was raised.
682+
assert len(record)
683+
# There may be multiple warnings. Ensure at least one of them contains
684+
# the deprecation message.
685+
assert any([
686+
f"'{old_key}' is deprecated" in str(r.message)
687+
for r in record
688+
])
689+
690+
# No point in testing further if we didn't get back a Mapping instance.
691+
assert isinstance(response, Mapping), "react_aggregate did not return a Mapping."
692+
693+
# We DO want the old_key to be present during the deprecation period.
694+
assert old_key in response.keys()
695+
696+
# We do NOT want the new_key present during the deprecation period.
697+
assert new_key not in response.keys()
698+
699+
# Some keys request multiple keys to be returned, these too should be
700+
# converted to the old names if the old name was originally used.
701+
if old_key in _RENAMED_DETAIL_KEYS_EXTRA.keys():
702+
for old_extra_key, new_extra_key in _RENAMED_DETAIL_KEYS_EXTRA[old_key]["additional_keys"].items():
703+
assert new_extra_key not in response.keys()
704+
assert old_extra_key in response.keys()
705+
593706
def test_get_version(self):
594707
"""Test get_version()."""
595708
version = self.client.get_version()

howso/engine/tests/test_engine.py

-48
Original file line numberDiff line numberDiff line change
@@ -319,54 +319,6 @@ def test_delete_method_standalone_bad(self):
319319
):
320320
delete_trainee(file_path=file_path)
321321

322-
def test_get_contribution_matrix(self, trainee):
323-
"""Test `get_contribution_matrix`."""
324-
matrix = trainee.get_contribution_matrix(
325-
normalize=True,
326-
fill_diagonal=True
327-
)
328-
assert len(matrix) == 5
329-
assert len(matrix.columns) == 5
330-
331-
# The raw matrix is saved in the trainee. This section
332-
# tests to make sure the matrix processing parameters are
333-
# passed through correctly.
334-
saved_matrix = trainee.calculated_matrices
335-
assert len(saved_matrix['contribution']) == 5
336-
assert len(saved_matrix['contribution'].columns) == 5
337-
338-
saved_matrix = matrix_processing(
339-
saved_matrix['contribution'],
340-
normalize=True,
341-
fill_diagonal=True
342-
)
343-
344-
assert_frame_equal(matrix, saved_matrix)
345-
346-
def test_get_mda_matrix(self, trainee):
347-
"""Test `get_mda_matrix`."""
348-
matrix = trainee.get_mda_matrix(
349-
absolute=True,
350-
fill_diagonal=True
351-
)
352-
assert len(matrix) == 5
353-
assert len(matrix.columns) == 5
354-
355-
# The raw matrix is saved in the trainee. This section
356-
# tests to make sure the matrix processing parameters are
357-
# passed through correctly.
358-
saved_matrix = trainee.calculated_matrices
359-
assert len(saved_matrix['mda']) == 5
360-
assert len(saved_matrix['mda'].columns) == 5
361-
362-
saved_matrix = matrix_processing(
363-
saved_matrix['mda'],
364-
absolute=True,
365-
fill_diagonal=True
366-
)
367-
368-
assert_frame_equal(matrix, saved_matrix)
369-
370322
def test_reduce_data(self, trainee):
371323
"""Test `reduce_data`."""
372324
pre_reduction_cases = trainee.get_cases()

0 commit comments

Comments
 (0)