Skip to content

Commit

Permalink
add some minimal typing
Browse files Browse the repository at this point in the history
Makes user code (in explainable_embedding repo specifically) complain less in VSCode.
  • Loading branch information
egpbos committed Jul 30, 2024
1 parent 12a0d03 commit 0089ded
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/distance_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dianna.utils.maskers import generate_masks_for_images
from sklearn.metrics import pairwise_distances
from tqdm import tqdm
import numpy.typing

logging.getLogger(__name__).addHandler(logging.NullHandler())

Expand Down Expand Up @@ -50,7 +51,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=.5, # pylint: disable=to
self.mask_selection_negative_range_min = mask_selection_negative_range_min
self.batch_size = batch_size

def explain_image_distance(self, model_or_function, input_data, embedded_reference, masks=None):
def explain_image_distance(self, model_or_function, input_data, embedded_reference, masks=None) -> tuple[numpy.typing.NDArray, float]:
"""Explain an image with respect to a reference point in an embedded space.
Args:
Expand Down Expand Up @@ -104,6 +105,7 @@ def describe(x, name):
statistics.append(describe(highest_mask_weights, 'highest_mask_weights'))
unnormalized_sal_highest = np.mean(highest_distances_masks, axis=0)
else:
print("HIGHEST IS NUL")
unnormalized_sal_highest = 0

lowest_distances_masks, lowest_mask_weights = self._get_lowest_distance_masks_and_weights(
Expand All @@ -116,6 +118,7 @@ def describe(x, name):
statistics.append(describe(lowest_mask_weights, 'lowest_mask_weights'))
unnormalized_sal_lowest = np.mean(lowest_distances_masks, axis=0)
else:
print("LOWEST IS NUL")
unnormalized_sal_lowest = 0

self.statistics = '\n'.join(statistics)
Expand Down

0 comments on commit 0089ded

Please sign in to comment.