From 0089ded0e37075d52035c5638979c250e46d904b Mon Sep 17 00:00:00 2001 From: "E. G. Patrick Bos" Date: Tue, 30 Jul 2024 16:34:43 +0200 Subject: [PATCH] add some minimal typing Makes user code (in explainable_embedding repo specifically) complain less in VSCode. --- src/distance_explainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/distance_explainer.py b/src/distance_explainer.py index 463be17..13a248b 100644 --- a/src/distance_explainer.py +++ b/src/distance_explainer.py @@ -6,6 +6,7 @@ from dianna.utils.maskers import generate_masks_for_images from sklearn.metrics import pairwise_distances from tqdm import tqdm +import numpy.typing logging.getLogger(__name__).addHandler(logging.NullHandler()) @@ -50,7 +51,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=.5, # pylint: disable=to self.mask_selection_negative_range_min = mask_selection_negative_range_min self.batch_size = batch_size - def explain_image_distance(self, model_or_function, input_data, embedded_reference, masks=None): + def explain_image_distance(self, model_or_function, input_data, embedded_reference, masks=None) -> tuple[numpy.typing.NDArray, float]: """Explain an image with respect to a reference point in an embedded space. Args: @@ -104,6 +105,7 @@ def describe(x, name): statistics.append(describe(highest_mask_weights, 'highest_mask_weights')) unnormalized_sal_highest = np.mean(highest_distances_masks, axis=0) else: + print("HIGHEST IS NUL") unnormalized_sal_highest = 0 lowest_distances_masks, lowest_mask_weights = self._get_lowest_distance_masks_and_weights( @@ -116,6 +118,7 @@ def describe(x, name): statistics.append(describe(lowest_mask_weights, 'lowest_mask_weights')) unnormalized_sal_lowest = np.mean(lowest_distances_masks, axis=0) else: + print("LOWEST IS NUL") unnormalized_sal_lowest = 0 self.statistics = '\n'.join(statistics)