Skip to content

Commit

Permalink
Use sklearn's logistic regression for linear probing (#169)
Browse files Browse the repository at this point in the history
* use binary logistic regression to initialize the linear layer

* plot integrated gradients from a binary classifier

* add cmap to 'visual' requirements

* move model assembling to lca

* rename init argument

* disable feature scaling

* update test and evaluation scripts to use new API

* add docstrings to LCA
  • Loading branch information
ziw-liu authored Sep 19, 2024
1 parent 5cc0dc0 commit aed53e4
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 282 deletions.
97 changes: 37 additions & 60 deletions applications/contrastive_phenotyping/evaluation/grad_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
import torch
from captum.attr import IntegratedGradients
from cmap import Colormap
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch import seed_everything
from skimage.exposure import rescale_intensity

from viscy.data.triplet import TripletDataModule
from viscy.representation.embedding_writer import read_embedding_dataset
from viscy.representation.engine import ContrastiveEncoder, ContrastiveModule
from viscy.representation.evaluation import load_annotation
from viscy.representation.lca import LinearClassifier, LinearProbingDataModule
from viscy.representation.lca import (
AssembledClassifier,
fit_logistic_regression,
linear_from_binary_logistic_regression,
)
from viscy.transforms import NormalizeSampled, ScaleIntensityRangePercentilesd

# %%
Expand Down Expand Up @@ -83,55 +86,21 @@
)

# %%
linear_data = LinearProbingDataModule(
embeddings=torch.from_numpy(features.values).float(),
labels=torch.from_numpy(infection.cat.codes.values).long(),
split_ratio=(0.4, 0.2, 0.4),
batch_size=2**14,
)
linear_data.setup("fit")

linear_classifier = LinearClassifier(
in_features=features.shape[1], out_features=3, lr=0.001
)

log_path = Path(
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/test"
)

trainer = Trainer(
max_epochs=60,
logger=CSVLogger(log_path),
log_every_n_steps=1,
)

trainer.fit(linear_classifier, linear_data)
losses = pd.read_csv(
log_path / "lightning_logs" / "version_0" / "metrics.csv", index_col="epoch"
)
losses = pd.merge(
losses["loss/train"].dropna(), losses["loss/val"].dropna(), on="epoch"
)
losses.plot()
train_fovs = ["/A/3/7", "/A/3/8", "/A/3/9", "/B/4/7", "/B/4/8"]

# %%
linear_classifier = linear_classifier.eval()

logistic_regression, data_split = fit_logistic_regression(
features.copy(),
infection.copy(),
train_fovs,
remove_background_class=True,
scale_features=False,
class_weight="balanced",
solver="liblinear",
)

# %%
class AssembledClassifier(torch.nn.Module):
def __init__(self, model, classifier):
super().__init__()
self.model = model
self.classifier = classifier

def forward(self, x):
x = self.model.stem(x)
x = self.model.encoder(x)
x = self.classifier(x)
return x


linear_classifier = linear_from_binary_logistic_regression(logistic_regression)
assembled_classifier = AssembledClassifier(model.model, linear_classifier).eval().cpu()

# %%
Expand All @@ -141,19 +110,18 @@ def forward(self, x):
)
track_classes = infection[infection["fov_name"] == fov[1:]]
track_classes = track_classes[track_classes["track_id"] == track]["infection_state"]
track_classes


# %%
def attribute_sample(img, target, assembled_classifier):
def attribute_sample(img, assembled_classifier):
ig = IntegratedGradients(assembled_classifier, multiply_by_inputs=True)
assembled_classifier.zero_grad()
attribution = ig.attribute(torch.from_numpy(img), target=target).numpy()
attribution = ig.attribute(torch.from_numpy(img)).numpy()
return img, attribution


def color_and_clim(heatmap, cmap):
lo, hi = np.percentile(heatmap, (1, 99))
def color_and_clim(heatmap, cmap, low=1, high=99):
lo, hi = np.percentile(heatmap, (low, high))
rescaled = rescale_intensity(heatmap.clip(lo, hi), out_range=(0, 1))
return Colormap(cmap)(rescaled)

Expand All @@ -163,10 +131,9 @@ def color_and_clim(heatmap, cmap):
img = sample["anchor"].numpy()

# %%
target = torch.from_numpy(track_classes.values).long()
with torch.inference_mode():
probs = assembled_classifier(torch.from_numpy(img)).softmax(dim=1)
img, attribution = attribute_sample(img, target, assembled_classifier)
probs = assembled_classifier(torch.from_numpy(img)).sigmoid()
img, attribution = attribute_sample(img, assembled_classifier)

# %%
z_slice = 5
Expand All @@ -184,10 +151,20 @@ def color_and_clim(heatmap, cmap):
print(grid.shape)

# %%
f, ax = plt.subplots(6, 8, figsize=(16, 12))
for i, (z_slice, a) in enumerate(zip(grid, ax.flatten())):
a.imshow(z_slice)
a.set_title(f"t={i}")
selected_time_points = [0, 4, 8, 34]
class_text = {0: "none", 1: "uninfected", 2: "infected"}

sps = len(selected_time_points)
f, ax = plt.subplots(1, sps, figsize=(4 * sps, 4))
for time, a in zip(selected_time_points, ax.flatten()):
rendered = grid[time]
prob = probs[time].item()
a.imshow(rendered)
hpi = 3 + 0.5 * time
text_label = class_text[track_classes.iloc[time]]
a.set_title(
f"{hpi} HPI,\npredicted infection probability: {prob:.2f},\nannotation: {text_label}"
)
a.axis("off")
f.tight_layout()

Expand Down
38 changes: 11 additions & 27 deletions applications/contrastive_phenotyping/evaluation/linear_probing.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
# %% Imports
from pathlib import Path
from tempfile import TemporaryDirectory

import pandas as pd
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

from viscy.representation.embedding_writer import read_embedding_dataset
from viscy.representation.evaluation import load_annotation
from viscy.representation.lca import train_and_test_linear_classifier
from viscy.representation.lca import fit_logistic_regression

# %%
TRAIN_FOVS = ["/A/3/7", "/A/3/8", "/A/3/9", "/B/4/7", "/B/4/8"]

path_embedding = Path(
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr"
)
Expand All @@ -33,27 +30,14 @@
infection

# %%
temp_dir = TemporaryDirectory()
log_path = Path(temp_dir.name)

train_and_test_linear_classifier(
features.to_numpy(),
infection.cat.codes.values,
num_classes=3,
trainer=Trainer(max_epochs=60, logger=CSVLogger(log_path), log_every_n_steps=1),
split_ratio=(0.4, 0.2, 0.4),
batch_size=2**14,
lr=0.001,
log_reg = fit_logistic_regression(
features,
infection,
train_fovs=TRAIN_FOVS,
remove_background_class=True,
scale_features=False,
class_weight="balanced",
solver="liblinear",
)

# plot loss curves to check if training converged/overfitted
# adjust number of epochs if necessary
losses = pd.read_csv(
log_path / "lightning_logs" / "version_0" / "metrics.csv", index_col="epoch"
)
losses = pd.merge(
losses["loss/train"].dropna(), losses["loss/val"].dropna(), on="epoch"
)
losses.plot()
temp_dir.cleanup()
# %%
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,15 @@ metrics = [
]
examples = ["napari", "jupyter", "jupytext"]

visual = ["ipykernel", "graphviz", "torchview", "seaborn", "plotly", "nbformat"]
visual = [
"ipykernel",
"graphviz",
"torchview",
"seaborn",
"plotly",
"nbformat",
"cmap",
]

dev = [
"pytest",
Expand Down
32 changes: 18 additions & 14 deletions tests/representation/test_lca.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

from viscy.representation.lca import Trainer, train_and_test_linear_classifier
from viscy.representation.lca import linear_from_binary_logistic_regression


def test_train_and_test_linear_classifier(caplog):
"""Test ``train_and_test_linear_classifier``."""
embeddings = np.random.rand(10, 8)
labels = np.random.randint(0, 2, 10)
with caplog.at_level("INFO"):
train_and_test_linear_classifier(
embeddings,
labels,
num_classes=3,
trainer=Trainer(fast_dev_run=True),
batch_size=4,
def test_linear_from_logistic_regression():
"""
Test ``linear_from_logistic_regression``.
Check that the logits from the logistic regression
and the linear model are almost equal.
"""
rand_data = np.random.rand(100, 8)
rand_labels = np.random.randint(0, 2, size=(100))
logistic_regression = LogisticRegression().fit(rand_data, rand_labels)
linear_model = linear_from_binary_logistic_regression(logistic_regression)
logistic_logits = logistic_regression.decision_function(rand_data)
with torch.inference_mode():
torch_logits = (
linear_model(torch.from_numpy(rand_data).float()).squeeze().numpy()
)
assert "accuracy_macro" in caplog.text
assert "f1_weighted" in caplog.text
np.testing.assert_allclose(logistic_logits, torch_logits, rtol=1e-3)
Loading

0 comments on commit aed53e4

Please sign in to comment.