Skip to content

Commit

Permalink
move attribution function to a method of the classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu committed Sep 21, 2024
1 parent cfb7c6b commit 734b087
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 26 deletions.
33 changes: 14 additions & 19 deletions applications/contrastive_phenotyping/evaluation/grad_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from viscy.representation.evaluation import load_annotation
from viscy.representation.lca import (
AssembledClassifier,
attribute_sample_binary,
fit_logistic_regression,
linear_from_binary_logistic_regression,
)
Expand Down Expand Up @@ -167,27 +166,22 @@
img = sample["anchor"].numpy()

# %%
img_tensor = torch.from_numpy(img)

with torch.inference_mode():
infection_probs = assembled_classifier_infection(torch.from_numpy(img)).sigmoid()
division_probs = assembled_classifier_division(torch.from_numpy(img)).sigmoid()
infection_probs = assembled_classifier_infection(img_tensor).sigmoid()
division_probs = assembled_classifier_division(img_tensor).sigmoid()

# %%
img, infection_attribution = attribute_sample_binary(
img,
assembled_classifier_infection,
# multiply_by_inputs=False,
)
_, division_attribution = attribute_sample_binary(
img,
assembled_classifier_division,
# multiply_by_inputs=False,
)
infection_attribution = assembled_classifier_infection.attribute_sample_binary(
img_tensor # , multiply_by_inputs=False
).numpy()
division_attribution = assembled_classifier_division.attribute_sample_binary(
img_tensor # , multiply_by_inputs=False
).numpy()


# %%
g_lim = 2e-3


def clip_rescale(img, low, high):
return rescale_intensity(img.clip(low, high), out_range=(0, 1))

Expand All @@ -197,6 +191,7 @@ def clim_percentile(heatmap, low=1, high=99):
return clip_rescale(heatmap, lo, hi)


g_lim = 4e-3
z_slice = 5
phase = clim_percentile(img[:, 0, z_slice])
rfp = clim_percentile(img[:, 1, z_slice])
Expand Down Expand Up @@ -231,11 +226,11 @@ def clim_percentile(heatmap, low=1, high=99):
div_binary = str(selected_div_states[i]).lower()
ax[0, i].imshow(img_render[time], cmap="gray")
ax[0, i].set_title(f"{hpi} HPI")
ax[1, i].imshow(inf_render[time], cmap=icefire)
ax[1, i].imshow(inf_render[time], cmap=icefire, vmin=0, vmax=1)
ax[1, i].set_title(
f"infected: {prob:.3f}\n" f"label: {inf_binary}",
)
ax[2, i].imshow(div_render[time], cmap=icefire)
ax[2, i].imshow(div_render[time], cmap=icefire, vmin=0, vmax=1)
ax[2, i].set_title(
f"dividing: {division_probs[time].item():.3f}\n" f"label: {div_binary}",
)
Expand All @@ -246,7 +241,7 @@ def clim_percentile(heatmap, low=1, high=99):
mpl.cm.ScalarMappable(norm=norm, cmap=icefire),
orientation="vertical",
ax=ax[1:].ravel().tolist(),
format=mpl.ticker.StrMethodFormatter("{x}")
format=mpl.ticker.StrMethodFormatter("{x:.3f}"),
)
cbar.set_label("integrated gradients")

Expand Down
27 changes: 20 additions & 7 deletions viscy/representation/lca.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,6 @@ def linear_from_binary_logistic_regression(
return model


def attribute_sample_binary(img, assembled_classifier, **kwargs):
ig = IntegratedGradients(assembled_classifier, **kwargs)
assembled_classifier.zero_grad()
attribution = ig.attribute(torch.from_numpy(img)).numpy()
return img, attribution


class AssembledClassifier(torch.nn.Module):
"""Assemble a contrastive encoder with a linear classifier.
Expand Down Expand Up @@ -157,3 +150,23 @@ def forward(self, x: Tensor, scale_features: bool = False) -> Tensor:
x = self.scale_features(x)
x = self.classifier(x)
return x

def attribute_sample_binary(self, img: Tensor, **kwargs) -> Tensor:
"""Compute integrated gradients for a binary classification task.
Parameters
----------
img : Tensor
input image
**kwargs : Any
Keyword arguments for the integrated gradients algorithm.
Returns
-------
attribution : Tensor
Integrated gradients attribution map.
"""
self.zero_grad()
ig = IntegratedGradients(self, **kwargs)
attribution = ig.attribute(img)
return attribution

0 comments on commit 734b087

Please sign in to comment.