Skip to content

Commit

Permalink
skip neutral value tests, refs #19, #16
Browse files Browse the repository at this point in the history
  • Loading branch information
cwmeijer committed Nov 5, 2024
1 parent 4bc6b9a commit d7dd41d
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions tests/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def set_all_the_seeds(seed_value=0):
os.environ['PYTHONHASHSEED'] = str(seed_value)
np.random.seed(seed_value)


@pytest.fixture()
def dummy_model() -> Callable:
"""Get a dummy model that returns a random embedding for every input in a batch."""
Expand Down Expand Up @@ -46,41 +47,66 @@ def get_explainer(config: Config, axis_labels={2: 'channels'}, preprocess_functi
return explainer


def test_distance_explainer(dummy_data: tuple[ArrayLike, ArrayLike],
dummy_model: Callable):
"""Code output should be identical to recorded output."""
def test_distance_explainer_saliency(dummy_data: tuple[ArrayLike, ArrayLike],
dummy_model: Callable):
"""Code output should be identical to recorded saliency."""
embedded_reference, input_arr = dummy_data
explainer = get_explainer(get_default_config())
expected_saliency, expected_value = np.load('./tests/test_data/test_dummy_data_exact_expected_output.npz').values()

saliency, value = explainer.explain_image_distance(dummy_model, input_arr, embedded_reference)

assert saliency.shape == (1,) + input_arr.shape[:2] + (1,) # Has correct shape
assert np.allclose(expected_saliency, saliency) # Has correct saliency


@pytest.mark.skip("See 'neutral value not correct #19', https://github.com/dianna-ai/distance_explainer/issues/19")
def test_distance_explainer_value(dummy_data: tuple[ArrayLike, ArrayLike],
dummy_model: Callable):
"""Code output should be identical to recorded value."""
embedded_reference, input_arr = dummy_data
explainer = get_explainer(get_default_config())
expected_saliency, expected_value = np.load('./tests/test_data/test_dummy_data_exact_expected_output.npz').values()
assert np.allclose(expected_saliency, saliency) # Has correct value

saliency, value = explainer.explain_image_distance(dummy_model, input_arr, embedded_reference)

assert np.allclose(expected_value, value) # Has correct value


@pytest.mark.parametrize("empty_side,expected_tag",
[({"mask_selection_range_max": 0.}, "pos_empty"),
({"mask_selection_negative_range_min": 1.}, "neg_empty")])
def test_distance_explainer_one_sided(dummy_data: tuple[ArrayLike, ArrayLike],
dummy_model: Callable,
empty_side: dict[str, float],
expected_tag: str):
"""Code output should be identical to recorded output."""
def test_distance_explainer_one_sided_saliency(dummy_data: tuple[ArrayLike, ArrayLike],
dummy_model: Callable,
empty_side: dict[str, float],
expected_tag: str):
"""Code output should be identical to recorded saliency."""
embedded_reference, input_arr = dummy_data

expected_saliency, expected_value = np.load(
f'./tests/test_data/test_dummy_data_exact_expected_output_{expected_tag}.npz').values()
config = dataclasses.replace(get_default_config(), **empty_side)
explainer = get_explainer(config)

saliency, value = explainer.explain_image_distance(dummy_model, input_arr, embedded_reference)

assert saliency.shape == (1,) + input_arr.shape[:2] + (1,) # Has correct shape
assert np.allclose(expected_saliency, saliency) # Has correct saliency


# np.savez(f'./tests/test_data/test_dummy_data_exact_expected_output_{expected_tag}.npz',
# expected_saliency=saliency, expected_value=value)
@pytest.mark.skip("See 'neutral value not correct #19', https://github.com/dianna-ai/distance_explainer/issues/19")
@pytest.mark.parametrize("empty_side,expected_tag",
[({"mask_selection_range_max": 0.}, "pos_empty"),
({"mask_selection_negative_range_min": 1.}, "neg_empty")])
def test_distance_explainer_one_sided_value(dummy_data: tuple[ArrayLike, ArrayLike],
dummy_model: Callable,
empty_side: dict[str, float],
expected_tag: str):
"""Code output should be identical to recorded saliency."""
embedded_reference, input_arr = dummy_data
expected_saliency, expected_value = np.load(
f'./tests/test_data/test_dummy_data_exact_expected_output_{expected_tag}.npz').values()
assert np.allclose(expected_saliency, saliency) # Has correct value
config = dataclasses.replace(get_default_config(), **empty_side)
explainer = get_explainer(config)

saliency, value = explainer.explain_image_distance(dummy_model, input_arr, embedded_reference)

assert np.allclose(expected_value, value) # Has correct value

0 comments on commit d7dd41d

Please sign in to comment.