Skip to content

Commit

Permalink
add some minimal typing
Browse files Browse the repository at this point in the history
Makes user code (in explainable_embedding repo specifically) complain less in VSCode.
  • Loading branch information
egpbos committed Jul 30, 2024
1 parent 12a0d03 commit e3291c0
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/distance_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dianna.utils.maskers import generate_masks_for_images
from sklearn.metrics import pairwise_distances
from tqdm import tqdm
import numpy.typing

logging.getLogger(__name__).addHandler(logging.NullHandler())

Expand Down Expand Up @@ -50,7 +51,7 @@ def __init__(self, n_masks=1000, feature_res=8, p_keep=.5, # pylint: disable=to
self.mask_selection_negative_range_min = mask_selection_negative_range_min
self.batch_size = batch_size

def explain_image_distance(self, model_or_function, input_data, embedded_reference, masks=None):
def explain_image_distance(self, model_or_function, input_data, embedded_reference, masks=None) -> tuple[numpy.typing.NDArray, float]:
"""Explain an image with respect to a reference point in an embedded space.
Args:
Expand Down

0 comments on commit e3291c0

Please sign in to comment.