From aed53e4626e6e45aa258cc3f4043bc34b25d9f92 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Thu, 19 Sep 2024 12:33:01 -0400 Subject: [PATCH] Use sklearn's logistic regression for linear probing (#169) * use binary logistic regression to initialize the linear layer * plot integrated gradients from a binary classifier * add cmap to 'visual' requirements * move model assembling to lca * rename init argument * disable feature scaling * update test and evaluation scripts to use new API * add docstrings to LCA --- .../evaluation/grad_attr.py | 97 +++--- .../evaluation/linear_probing.py | 38 +-- pyproject.toml | 10 +- tests/representation/test_lca.py | 32 +- viscy/representation/lca.py | 307 ++++++++---------- 5 files changed, 202 insertions(+), 282 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/grad_attr.py b/applications/contrastive_phenotyping/evaluation/grad_attr.py index 594b0294..169345dd 100644 --- a/applications/contrastive_phenotyping/evaluation/grad_attr.py +++ b/applications/contrastive_phenotyping/evaluation/grad_attr.py @@ -7,15 +7,18 @@ import torch from captum.attr import IntegratedGradients from cmap import Colormap -from lightning.pytorch import Trainer, seed_everything -from lightning.pytorch.loggers import CSVLogger +from lightning.pytorch import seed_everything from skimage.exposure import rescale_intensity from viscy.data.triplet import TripletDataModule from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.engine import ContrastiveEncoder, ContrastiveModule from viscy.representation.evaluation import load_annotation -from viscy.representation.lca import LinearClassifier, LinearProbingDataModule +from viscy.representation.lca import ( + AssembledClassifier, + fit_logistic_regression, + linear_from_binary_logistic_regression, +) from viscy.transforms import NormalizeSampled, ScaleIntensityRangePercentilesd # %% @@ -83,55 +86,21 @@ ) # %% 
-linear_data = LinearProbingDataModule( - embeddings=torch.from_numpy(features.values).float(), - labels=torch.from_numpy(infection.cat.codes.values).long(), - split_ratio=(0.4, 0.2, 0.4), - batch_size=2**14, -) -linear_data.setup("fit") - -linear_classifier = LinearClassifier( - in_features=features.shape[1], out_features=3, lr=0.001 -) - -log_path = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/test" -) - -trainer = Trainer( - max_epochs=60, - logger=CSVLogger(log_path), - log_every_n_steps=1, -) - -trainer.fit(linear_classifier, linear_data) -losses = pd.read_csv( - log_path / "lightning_logs" / "version_0" / "metrics.csv", index_col="epoch" -) -losses = pd.merge( - losses["loss/train"].dropna(), losses["loss/val"].dropna(), on="epoch" -) -losses.plot() +train_fovs = ["/A/3/7", "/A/3/8", "/A/3/9", "/B/4/7", "/B/4/8"] # %% -linear_classifier = linear_classifier.eval() - +logistic_regression, data_split = fit_logistic_regression( + features.copy(), + infection.copy(), + train_fovs, + remove_background_class=True, + scale_features=False, + class_weight="balanced", + solver="liblinear", +) # %% -class AssembledClassifier(torch.nn.Module): - def __init__(self, model, classifier): - super().__init__() - self.model = model - self.classifier = classifier - - def forward(self, x): - x = self.model.stem(x) - x = self.model.encoder(x) - x = self.classifier(x) - return x - - +linear_classifier = linear_from_binary_logistic_regression(logistic_regression) assembled_classifier = AssembledClassifier(model.model, linear_classifier).eval().cpu() # %% @@ -141,19 +110,18 @@ def forward(self, x): ) track_classes = infection[infection["fov_name"] == fov[1:]] track_classes = track_classes[track_classes["track_id"] == track]["infection_state"] -track_classes # %% -def attribute_sample(img, target, assembled_classifier): +def attribute_sample(img, assembled_classifier): ig = 
IntegratedGradients(assembled_classifier, multiply_by_inputs=True) assembled_classifier.zero_grad() - attribution = ig.attribute(torch.from_numpy(img), target=target).numpy() + attribution = ig.attribute(torch.from_numpy(img)).numpy() return img, attribution -def color_and_clim(heatmap, cmap): - lo, hi = np.percentile(heatmap, (1, 99)) +def color_and_clim(heatmap, cmap, low=1, high=99): + lo, hi = np.percentile(heatmap, (low, high)) rescaled = rescale_intensity(heatmap.clip(lo, hi), out_range=(0, 1)) return Colormap(cmap)(rescaled) @@ -163,10 +131,9 @@ def color_and_clim(heatmap, cmap): img = sample["anchor"].numpy() # %% -target = torch.from_numpy(track_classes.values).long() with torch.inference_mode(): - probs = assembled_classifier(torch.from_numpy(img)).softmax(dim=1) -img, attribution = attribute_sample(img, target, assembled_classifier) + probs = assembled_classifier(torch.from_numpy(img)).sigmoid() +img, attribution = attribute_sample(img, assembled_classifier) # %% z_slice = 5 @@ -184,10 +151,20 @@ def color_and_clim(heatmap, cmap): print(grid.shape) # %% -f, ax = plt.subplots(6, 8, figsize=(16, 12)) -for i, (z_slice, a) in enumerate(zip(grid, ax.flatten())): - a.imshow(z_slice) - a.set_title(f"t={i}") +selected_time_points = [0, 4, 8, 34] +class_text = {0: "none", 1: "uninfected", 2: "infected"} + +sps = len(selected_time_points) +f, ax = plt.subplots(1, sps, figsize=(4 * sps, 4)) +for time, a in zip(selected_time_points, ax.flatten()): + rendered = grid[time] + prob = probs[time].item() + a.imshow(rendered) + hpi = 3 + 0.5 * time + text_label = class_text[track_classes.iloc[time]] + a.set_title( + f"{hpi} HPI,\npredicted infection probability: {prob:.2f},\nannotation: {text_label}" + ) a.axis("off") f.tight_layout() diff --git a/applications/contrastive_phenotyping/evaluation/linear_probing.py b/applications/contrastive_phenotyping/evaluation/linear_probing.py index 6148336d..aa73f278 100644 --- 
a/applications/contrastive_phenotyping/evaluation/linear_probing.py +++ b/applications/contrastive_phenotyping/evaluation/linear_probing.py @@ -1,16 +1,13 @@ # %% Imports from pathlib import Path -from tempfile import TemporaryDirectory - -import pandas as pd -from lightning.pytorch import Trainer -from lightning.pytorch.loggers import CSVLogger from viscy.representation.embedding_writer import read_embedding_dataset from viscy.representation.evaluation import load_annotation -from viscy.representation.lca import train_and_test_linear_classifier +from viscy.representation.lca import fit_logistic_regression # %% +TRAIN_FOVS = ["/A/3/7", "/A/3/8", "/A/3/9", "/B/4/7", "/B/4/8"] + path_embedding = Path( "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/feb_test_time_interval_1_epoch_178.zarr" ) @@ -33,27 +30,14 @@ infection # %% -temp_dir = TemporaryDirectory() -log_path = Path(temp_dir.name) - -train_and_test_linear_classifier( - features.to_numpy(), - infection.cat.codes.values, - num_classes=3, - trainer=Trainer(max_epochs=60, logger=CSVLogger(log_path), log_every_n_steps=1), - split_ratio=(0.4, 0.2, 0.4), - batch_size=2**14, - lr=0.001, +log_reg = fit_logistic_regression( + features, + infection, + train_fovs=TRAIN_FOVS, + remove_background_class=True, + scale_features=False, + class_weight="balanced", + solver="liblinear", ) -# plot loss curves to check if training converged/overfitted -# adjust number of epochs if necessary -losses = pd.read_csv( - log_path / "lightning_logs" / "version_0" / "metrics.csv", index_col="epoch" -) -losses = pd.merge( - losses["loss/train"].dropna(), losses["loss/val"].dropna(), on="epoch" -) -losses.plot() -temp_dir.cleanup() # %% diff --git a/pyproject.toml b/pyproject.toml index 36c44ee6..039bdc41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,15 @@ metrics = [ ] examples = ["napari", "jupyter", "jupytext"] -visual = ["ipykernel", 
"graphviz", "torchview", "seaborn", "plotly", "nbformat"] +visual = [ + "ipykernel", + "graphviz", + "torchview", + "seaborn", + "plotly", + "nbformat", + "cmap", +] dev = [ "pytest", diff --git a/tests/representation/test_lca.py b/tests/representation/test_lca.py index 34ba3515..f64b5771 100644 --- a/tests/representation/test_lca.py +++ b/tests/representation/test_lca.py @@ -1,19 +1,23 @@ import numpy as np +import torch +from sklearn.linear_model import LogisticRegression -from viscy.representation.lca import Trainer, train_and_test_linear_classifier +from viscy.representation.lca import linear_from_binary_logistic_regression -def test_train_and_test_linear_classifier(caplog): - """Test ``train_and_test_linear_classifier``.""" - embeddings = np.random.rand(10, 8) - labels = np.random.randint(0, 2, 10) - with caplog.at_level("INFO"): - train_and_test_linear_classifier( - embeddings, - labels, - num_classes=3, - trainer=Trainer(fast_dev_run=True), - batch_size=4, +def test_linear_from_logistic_regression(): + """ + Test ``linear_from_binary_logistic_regression``. + Check that the logits from the logistic regression + and the linear model are almost equal. 
+ """ + rand_data = np.random.rand(100, 8) + rand_labels = np.random.randint(0, 2, size=(100)) + logistic_regression = LogisticRegression().fit(rand_data, rand_labels) + linear_model = linear_from_binary_logistic_regression(logistic_regression) + logistic_logits = logistic_regression.decision_function(rand_data) + with torch.inference_mode(): + torch_logits = ( + linear_model(torch.from_numpy(rand_data).float()).squeeze().numpy() ) - assert "accuracy_macro" in caplog.text - assert "f1_weighted" in caplog.text + np.testing.assert_allclose(logistic_logits, torch_logits, rtol=1e-3) diff --git a/viscy/representation/lca.py b/viscy/representation/lca.py index 130c50f9..bcd10816 100644 --- a/viscy/representation/lca.py +++ b/viscy/representation/lca.py @@ -1,204 +1,151 @@ """Linear probing of trained encoder based on cell state labels.""" -import logging -from pprint import pformat -from typing import Literal +from typing import Mapping -import numpy as np +import pandas as pd import torch import torch.nn as nn -from lightning.pytorch import LightningDataModule, LightningModule, Trainer from numpy.typing import NDArray -from torch import Tensor, optim -from torch.utils.data import DataLoader, TensorDataset -from torchmetrics.functional.classification import ( - multiclass_accuracy, - multiclass_f1_score, -) - -_logger = logging.getLogger("lightning.pytorch") - - -def _test_metrics(preds: Tensor, target: Tensor, num_classes: int) -> dict[str, float]: - """Test metrics for the linear classifier. 
+from sklearn.linear_model import LogisticRegression +from sklearn.metrics import classification_report +from sklearn.preprocessing import StandardScaler +from torch import Tensor +from xarray import DataArray + +from viscy.representation.contrastive import ContrastiveEncoder + + +def fit_logistic_regression( + features: DataArray, + annotations: pd.Series, + train_fovs: list[str], + remove_background_class: bool = True, + scale_features: bool = False, + class_weight: Mapping | str | None = "balanced", + random_state: int | None = None, + solver="liblinear", +) -> tuple[ + LogisticRegression, + tuple[tuple[NDArray, NDArray], tuple[NDArray, NDArray]], +]: + """Fit a binary logistic regression classifier. Parameters ---------- - preds : Tensor - Predicted logits, shape (n_samples, n_classes) - target : Tensor - Labels, shape (n_samples,) - num_classes : int - Number of classes + features : DataArray + Xarray of features. + annotations : pd.Series + Categorical class annotations with label values starting from 0. + Must have 3 classes (when remove background is True) or 2 classes. + train_fovs : list[str] + List of FOVs to use for training. The rest will be used for testing. + remove_background_class : bool, optional + Remove background class (0), by default True + scale_features : bool, optional + Scale features, by default False + class_weight : Mapping | str | None, optional + Class weight for balancing, by default "balanced" + random_state : int | None, optional + Random state or seed, by default None + solver : str, optional + Solver for the regression problem, by default "liblinear" Returns ------- - dict[str, float] - Test metrics + tuple[LogisticRegression, tuple[tuple[NDArray, NDArray], tuple[NDArray, NDArray]]] + Trained classifier and data split [[X_train, y_train], [X_test, y_test]]. 
""" - # TODO: add more metrics - metrics = {} - for average in ["macro", "weighted"]: - metrics[f"accuracy_{average}"] = multiclass_accuracy( - preds, target, num_classes, average=average - ).item() - metrics[f"f1_{average}"] = multiclass_f1_score( - preds, target, num_classes, average=average - ).item() - return metrics - - -class LinearProbingDataModule(LightningDataModule): - def __init__( - self, - embeddings: Tensor, - labels: Tensor, - split_ratio: tuple[int, int, int], - batch_size: int, - ) -> None: - """Data module for linear probing. - - Parameters - ---------- - embeddings : Tensor - Input embeddings - labels : Tensor - Annotation labels - split_ratio : tuple[int, int, int] - Train/validate/test split ratio, must sum to 1. - batch_size : int - Batch sizes - """ - super().__init__() - if not embeddings.shape[0] == labels.shape[0]: - raise ValueError("Number of samples in embeddings and labels must match.") - if sum(split_ratio) != 1.0: - raise ValueError("Split ratio must sum to 1.") - self.dataset = TensorDataset(embeddings.float(), labels.long()) - self.split_ratio = split_ratio - self.batch_size = batch_size - - def setup(self, stage: Literal["fit", "validate", "test", "predict"]) -> None: - n = len(self.dataset) - train_size = int(n * self.split_ratio[0]) - val_size = int(n * self.split_ratio[1]) - test_size = n - train_size - val_size - self.train_dataset, self.val_dataset, self.test_dataset = ( - torch.utils.data.random_split( - self.dataset, [train_size, val_size, test_size] - ) + fov_selection = features["fov_name"].isin(train_fovs) + train_selection = fov_selection + test_selection = ~fov_selection + annotations = annotations.cat.codes.values.copy() + if remove_background_class: + label_selection = annotations != 0 + train_selection &= label_selection + test_selection &= label_selection + annotations -= 1 + train_features = features.values[train_selection] + test_features = features.values[test_selection] + if scale_features: + scaler = 
StandardScaler() + train_features = scaler.fit_transform(train_features) + test_features = scaler.fit_transform(test_features) + train_annotations = annotations[train_selection] + test_annotations = annotations[test_selection] + logistic_regression = LogisticRegression( + class_weight=class_weight, + random_state=random_state, + solver=solver, + ) + logistic_regression.fit(train_features, train_annotations) + prediction = logistic_regression.predict(test_features) + print("Trained logistic regression classifier.") + print( + "Training set accuracy:\n" + + classification_report( + logistic_regression.predict(train_features), train_annotations, digits=3 ) + ) + print( + "Test set accuracy:\n" + + classification_report(prediction, test_annotations, digits=3) + ) + return logistic_regression, ( + (train_features, train_annotations), + (test_features, test_annotations), + ) - def train_dataloader(self) -> DataLoader: - return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True) - def val_dataloader(self) -> DataLoader: - return DataLoader(self.val_dataset, batch_size=self.batch_size) +def linear_from_binary_logistic_regression( + logistic_regression: LogisticRegression, +) -> nn.Linear: + """Convert a binary logistic regression model to a ``torch.nn.Linear`` layer. - def test_dataloader(self) -> DataLoader: - return DataLoader(self.test_dataset, batch_size=self.batch_size) + Parameters + ---------- + logistic_regression : LogisticRegression + Trained logistic regression model. + Returns + ------- + nn.Linear + Converted linear model. 
+ """ + weights = torch.from_numpy(logistic_regression.coef_).float() + bias = torch.from_numpy(logistic_regression.intercept_).float() + model = nn.Linear(in_features=weights.shape[1], out_features=1) + model.weight.data = weights + model.bias.data = bias + model.eval() + return model -class LinearClassifier(LightningModule): - def __init__(self, in_features: int, out_features: int, lr: float) -> None: - """Linear classifier. - Parameters - ---------- - in_features : int - Number of input feature channels - out_features : int - Number of output feature channels (number of classes) - lr : float - Learning rate - """ - super().__init__() - self.fc = nn.Linear(in_features, out_features) - self.lr = lr - self.loss = nn.BCEWithLogitsLoss() - - def forward(self, x: Tensor) -> Tensor: - return self.fc(x) - - def _fit_step(self, batch, stage: str) -> Tensor: - x, y = batch - preds = self(x) - target = nn.functional.one_hot(y, num_classes=preds.shape[1]).float() - loss = self.loss(preds, target) - self.log(f"loss/{stage}", loss, on_epoch=True, on_step=False) - return loss - - def training_step(self, batch, batch_idx: int) -> Tensor: - return self._fit_step(batch, stage="train") - - def validation_step(self, batch, batch_idx: int) -> None: - _ = self._fit_step(batch, stage="val") - - def configure_optimizers( - self, - ) -> tuple[list[optim.Optimizer], list[optim.lr_scheduler.LRScheduler]]: - return optim.AdamW(self.parameters()) - - def on_test_start(self) -> None: - self.test_labels: list[Tensor] = [] - self.test_predictions: list[Tensor] = [] - - def test_step(self, batch, batch_idx: int) -> None: - x, y = batch - preds = self(x) - self.test_labels.append(y) - self.test_predictions.append(preds) - - def on_test_epoch_end(self) -> None: - y = torch.cat(self.test_labels) - preds = torch.cat(self.test_predictions) - num_classes = self.fc.out_features - _logger.info("Test metrics:\n" + pformat(_test_metrics(preds, y, num_classes))) - - def predict_step(self, x: Tensor) -> 
Tensor: - logits = self(x) - return torch.argmax(logits, dim=1) - - -def train_and_test_linear_classifier( - embeddings: NDArray, - labels: NDArray, - num_classes: int, - trainer: Trainer, - split_ratio: tuple[int, int, int] = (0.4, 0.2, 0.4), - batch_size: int = 1024, - lr: float = 1e-3, -) -> None: - """Train and test a linear classifier. +class AssembledClassifier(torch.nn.Module): + """Assemble a contrastive encoder with a linear classifier. Parameters ---------- - embeddings : NDArray - Input embeddings, shape (n_samples, n_features). - labels : NDArray - Annotation labels, shape (n_samples,). - num_classes : int - Number of classes. - trainer : Trainer - Lightning Trainer object for training and testing. - Define the number of epochs, logging, etc. - split_ratio : tuple[int, int, int], optional - Train/validate/test split ratio, by default (0.4, 0.2, 0.4) - batch_size : int, optional - Batch size, by default 1024 - lr : float, optional - Learning rate, by default 1e-3 + backbone : ContrastiveEncoder + Encoder backbone. + classifier : nn.Linear + Classifier head. 
""" - if not isinstance(embeddings, np.ndarray) or not isinstance(labels, np.ndarray): - raise TypeError("Input embeddings and labels must be NumPy arrays.") - if not embeddings.ndim == 2: - raise ValueError("Input embeddings must have 2 dimensions.") - if not labels.ndim == 1: - raise ValueError("Labels must have 1 dimension.") - embeddings = torch.from_numpy(embeddings) - data = LinearProbingDataModule( - embeddings, torch.from_numpy(labels), split_ratio, batch_size - ) - model = LinearClassifier(embeddings.shape[1], num_classes, lr) - trainer.fit(model, data) - trainer.test(model, data) + + def __init__(self, backbone: ContrastiveEncoder, classifier: nn.Linear) -> None: + super().__init__() + self.backbone = backbone + self.classifier = classifier + + @staticmethod + def scale_features(x: Tensor) -> Tensor: + m = x.mean(-2, keepdim=True) + s = x.std(-2, unbiased=False, keepdim=True) + return (x - m) / s + + def forward(self, x: Tensor, scale_features: bool = False) -> Tensor: + x = self.backbone.stem(x) + x = self.backbone.encoder(x) + if scale_features: + x = self.scale_features(x) + x = self.classifier(x) + return x