From a545f38be89c376fbd87ddf192a5008791fcb4ac Mon Sep 17 00:00:00 2001
From: Sergiy Popovych
Date: Fri, 2 Feb 2024 03:01:51 +0000
Subject: [PATCH] refactor: internal submodule

---
 .coveragerc | 6 +
 .github/workflows/testing.yaml | 2 +-
 .gitmodules | 6 +-
 scripts/benchmark_rigidity_map.py | 2 +-
 .../em_encoder/train/m3_m6_encoder_dict_c4.py | 4 +-
 .../em_encoder/train/m3_m8_encoder_dict.py | 4 +-
 tmp.py | 0
 zetta_utils/__init__.py | 3 +-
 zetta_utils/alignment | 1 -
 zetta_utils/api/v0.py | 66 +--
 zetta_utils/internal | 1 +
 .../mazepa_layer_processing/__init__.py | 4 +-
 .../alignment/aced_relaxation_flow.py | 3 +-
 .../alignment/common.py | 2 +-
 zetta_utils/segmentation/__init__.py | 7 -
 zetta_utils/segmentation/affinity.py | 284 ----------
 zetta_utils/segmentation/affs_inferencer.py | 70 ---
 zetta_utils/segmentation/balance.py | 90 ---
 zetta_utils/segmentation/common.py | 23 -
 .../segmentation/embedding/__init__.py | 4 -
 zetta_utils/segmentation/embedding/common.py | 39 --
 zetta_utils/segmentation/embedding/edge.py | 141 -----
 zetta_utils/segmentation/embedding/mean.py | 179 ------
 zetta_utils/segmentation/embedding/utils.py | 123 ----
 zetta_utils/segmentation/inference.py | 43 --
 zetta_utils/segmentation/loss.py | 120 ----
 zetta_utils/training/lightning/__init__.py | 2 +-
 .../training/lightning/regimes/__init__.py | 7 -
 .../lightning/regimes/alignment/__init__.py | 2 -
 .../regimes/alignment/base_encoder.py | 303 ----------
 .../alignment/deprecated/base_encoder.py | 524 ------------------
 .../deprecated/encoding_coarsener.py | 273 ---------
 .../deprecated/encoding_coarsener_gen_x1.py | 179 ------
 .../deprecated/encoding_coarsener_highres.py | 356 ------------
 .../alignment/deprecated/minima_encoder.py | 254 ---------
 .../deprecated/misalignment_detector.py | 185 -------
 .../alignment/misalignment_detector_aced.py | 256 ---------
 .../training/lightning/regimes/common.py | 79 ---
 .../lightning/regimes/naive_supervised.py | 76 ---
 .../training/lightning/regimes/noop.py | 36 --
 .../regimes/segmentation/__init__.py | 1 -
 .../regimes/segmentation/base_affinity.py | 147 -----
 .../regimes/segmentation/base_embedding.py | 131 -----
 43 files changed, 54 insertions(+), 3984 deletions(-)
 create mode 100644 tmp.py
 delete mode 160000 zetta_utils/alignment
 create mode 160000 zetta_utils/internal
 delete mode 100644 zetta_utils/segmentation/__init__.py
 delete mode 100644 zetta_utils/segmentation/affinity.py
 delete mode 100644 zetta_utils/segmentation/affs_inferencer.py
 delete mode 100644 zetta_utils/segmentation/balance.py
 delete mode 100644 zetta_utils/segmentation/common.py
 delete mode 100644 zetta_utils/segmentation/embedding/__init__.py
 delete mode 100644 zetta_utils/segmentation/embedding/common.py
 delete mode 100644 zetta_utils/segmentation/embedding/edge.py
 delete mode 100644 zetta_utils/segmentation/embedding/mean.py
 delete mode 100644 zetta_utils/segmentation/embedding/utils.py
 delete mode 100644 zetta_utils/segmentation/inference.py
 delete mode 100644 zetta_utils/segmentation/loss.py
 delete mode 100644 zetta_utils/training/lightning/regimes/__init__.py
 delete mode 100644 zetta_utils/training/lightning/regimes/alignment/__init__.py
 delete mode 100644 zetta_utils/training/lightning/regimes/alignment/base_encoder.py
 delete mode 100644 zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder.py
 delete mode 100644 zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py
 delete mode 100644 zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py
 delete mode 100644 zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py
 delete mode 100644 zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py
 delete mode 100644 zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py
 delete mode 100644 zetta_utils/training/lightning/regimes/alignment/misalignment_detector_aced.py
 delete mode 100644 zetta_utils/training/lightning/regimes/common.py
 delete mode 100644 zetta_utils/training/lightning/regimes/naive_supervised.py
 delete mode 100644 zetta_utils/training/lightning/regimes/noop.py
 delete mode 100644 zetta_utils/training/lightning/regimes/segmentation/__init__.py
 delete mode 100644 zetta_utils/training/lightning/regimes/segmentation/base_affinity.py
 delete mode 100644 zetta_utils/training/lightning/regimes/segmentation/base_embedding.py

diff --git a/.coveragerc b/.coveragerc
index 0e7de3e75..c6ba6ddc7 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -3,9 +3,15 @@ omit =
     *_tmp.py
     zetta_utils/log.py
     zetta_utils/mazepa_layer_processing/*.py
+    zetta_utils/mazepa_layer_processing/**/*.py
     zetta_utils/tensor_mapping/*.py
+    zetta_utils/tensor_mapping/**/*.py
     zetta_utils/convnet/architecture/deprecated/*.py
+    zetta_utils/convnet/architecture/deprecated/**/*.py
     zetta_utils/viz/*.py
+    zetta_utils/viz/**/*.py
+    zetta_utils/internal/*.py
+    zetta_utils/internal/**/*.py
     zetta_utils/mazepa_addons/*.py
     zetta_utils/mazepa_addons/**/*.py
     zetta_utils/alignment/*.py
diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml
index a021139a5..555aa0c2c 100644
--- a/.github/workflows/testing.yaml
+++ b/.github/workflows/testing.yaml
@@ -122,7 +122,7 @@ jobs:
         run: pylint ${{ steps.changed-py-files.outputs.all_changed_files }}
       - name: Run isort
         if: ${{ steps.changed-py-files.outputs.any_changed == 'true' }}
-        run: isort --check-only --df --verbose --profile black .
+        run: isort --check-only --df --verbose --om --profile black .
   mypy:
     strategy:
       matrix:
diff --git a/.gitmodules b/.gitmodules
index 151433cc5..44216a04b 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,3 @@
-[submodule "zetta_utils/alignment"]
-	path = zetta_utils/alignment
-	url = git@github.com:ZettaAI/alignment.git
+[submodule "zetta_utils/internal"]
+	path = zetta_utils/internal
+	url = git@github.com:ZettaAI/internal.git
diff --git a/scripts/benchmark_rigidity_map.py b/scripts/benchmark_rigidity_map.py
index 10dc9bcc4..b7f1dc5e7 100644
--- a/scripts/benchmark_rigidity_map.py
+++ b/scripts/benchmark_rigidity_map.py
@@ -4,7 +4,7 @@
 import torch
 import torchfields
 
-from zetta_utils.alignment import field
+from zetta_utils.internal.alignment import field
 
 
 def rotation_tensor(degrees):
diff --git a/specs/nico/training/em_encoder/train/m3_m6_encoder_dict_c4.py b/specs/nico/training/em_encoder/train/m3_m6_encoder_dict_c4.py
index 95b50fe82..6a414fbf9 100644
--- a/specs/nico/training/em_encoder/train/m3_m6_encoder_dict_c4.py
+++ b/specs/nico/training/em_encoder/train/m3_m6_encoder_dict_c4.py
@@ -22,9 +22,7 @@
 EXP_VERSION = f"4.0.0_M3_M6_3px_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4"
 
-START_EXP_VERSION = (
-    f"4.4.0_M3_M6_C1_lr0.0002_locality1.0_similarity0.0_l10.05-0.12_N1x4"
-)
+START_EXP_VERSION = f"4.4.0_M3_M6_C1_lr0.0002_locality1.0_similarity0.0_l10.05-0.12_N1x4"
 
 MODEL_CKPT = None  # f"gs://zetta-research-nico/training_artifacts/general_coarsener_loss/{START_EXP_VERSION}/last.ckpt"
 
 BASE_PATH = "gs://zetta-research-nico/encoder/"
diff --git a/specs/nico/training/em_encoder/train/m3_m8_encoder_dict.py b/specs/nico/training/em_encoder/train/m3_m8_encoder_dict.py
index 5a11d757f..92feafcc9 100644
--- a/specs/nico/training/em_encoder/train/m3_m8_encoder_dict.py
+++ b/specs/nico/training/em_encoder/train/m3_m8_encoder_dict.py
@@ -22,9 +22,7 @@
 EXP_VERSION = f"1.0.0_M3_M8_C{CHANNELS}_lr{LR}_locality{LOCALITY_WEIGHT}_similarity{SIMILARITY_WEIGHT}_l1{L1_WEIGHT_START_VAL}-{L1_WEIGHT_END_VAL}_N1x4"
 
-START_EXP_VERSION = (
-    f"1.2.0_M3_M7_C1_lr2e-05_locality1.0_similarity0.0_l10.12-0.12_N1x4"
-)
+START_EXP_VERSION = f"1.2.0_M3_M7_C1_lr2e-05_locality1.0_similarity0.0_l10.12-0.12_N1x4"
 
 MODEL_CKPT = None  # f"gs://zetta-research-nico/training_artifacts/general_coarsener_loss/{START_EXP_VERSION}/last.ckpt"
 
 BASE_PATH = "gs://zetta-research-nico/encoder/"
diff --git a/tmp.py b/tmp.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/zetta_utils/__init__.py b/zetta_utils/__init__.py
index 3dca33a58..1513ae044 100644
--- a/zetta_utils/__init__.py
+++ b/zetta_utils/__init__.py
@@ -35,7 +35,7 @@ def try_load_train_inference():  # pragma: no cover
 
 def try_load_submodules():  # pragma: no cover
     try:
-        from . import alignment
+        from . import internal
     except ImportError:
         ...
@@ -46,7 +46,6 @@ def load_inference_modules():
         convnet,
         mazepa,
         mazepa_layer_processing,
-        segmentation,
         tensor_ops,
         tensor_typing,
         tensor_mapping,
diff --git a/zetta_utils/alignment b/zetta_utils/alignment
deleted file mode 160000
index 8f1d75b14..000000000
--- a/zetta_utils/alignment
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 8f1d75b14f09f6520a0e23b8d02204133f7032bd
diff --git a/zetta_utils/api/v0.py b/zetta_utils/api/v0.py
index 4f9fa907f..47b26d2d1 100644
--- a/zetta_utils/api/v0.py
+++ b/zetta_utils/api/v0.py
@@ -1,26 +1,6 @@
 # pylint: disable=unused-import
-from zetta_utils.alignment.aced_relaxation import (
-    compute_aced_loss_new,
-    get_aced_match_offsets,
-    get_aced_match_offsets_naive,
-    perform_aced_relaxation,
-)
-from zetta_utils.alignment.base_coarsener import BaseCoarsener
-from zetta_utils.alignment.base_encoder import BaseEncoder
-from zetta_utils.alignment.encoding_coarsener import EncodingCoarsener
-from zetta_utils.alignment.field import (
-    gen_biased_perlin_noise_field,
-    get_rigidity_map,
-    get_rigidity_map_zcxy,
-    invert_field,
-    invert_field_opti,
-    percentile,
-    profile_field2d_percentile,
-)
-from zetta_utils.alignment.misalignment_detector import MisalignmentDetector, naive_misd
-from zetta_utils.alignment.online_finetuner import align_with_online_finetuner
 from zetta_utils.augmentations.common import prob_aug
 from zetta_utils.augmentations.imgaug import imgaug_augment, imgaug_readproc
 from zetta_utils.augmentations.tensor import (
@@ -134,6 +114,37 @@
 from zetta_utils.geometry.bbox import BBox3D
 from zetta_utils.geometry.bbox_strider import BBoxStrider
 from zetta_utils.geometry.vec import Vec3D, allclose, is_int_vec, is_raw_vec3d, isclose
+from zetta_utils.internal.alignment.aced_relaxation import (
+    compute_aced_loss_new,
+    get_aced_match_offsets,
+    get_aced_match_offsets_naive,
+    perform_aced_relaxation,
+)
+from zetta_utils.internal.alignment.base_coarsener import BaseCoarsener
+from zetta_utils.internal.alignment.base_encoder import BaseEncoder
+from zetta_utils.internal.alignment.encoding_coarsener import EncodingCoarsener
+from zetta_utils.internal.alignment.field import (
+    gen_biased_perlin_noise_field,
+    get_rigidity_map,
+    get_rigidity_map_zcxy,
+    invert_field,
+    invert_field_opti,
+    percentile,
+    profile_field2d_percentile,
+)
+from zetta_utils.internal.alignment.misalignment_detector import (
+    MisalignmentDetector,
+    naive_misd,
+)
+from zetta_utils.internal.alignment.online_finetuner import align_with_online_finetuner
+from zetta_utils.internal.regimes.alignment.base_encoder import BaseEncoderRegime
+from zetta_utils.internal.regimes.alignment.misalignment_detector_aced import (
+    MisalignmentDetectorAcedRegime,
+)
+from zetta_utils.internal.regimes.common import is_2d_image, log_results
+from zetta_utils.internal.regimes.naive_supervised import NaiveSupervisedRegime
+from zetta_utils.internal.regimes.noop import NoOpRegime
+from zetta_utils.internal.segmentation.inference import run_affinities_inference_onnx
 from zetta_utils.layer.backend_base import Backend
 from zetta_utils.layer.db_layer.backend import DBBackend
 from zetta_utils.layer.db_layer.build import build_db_layer
@@ -352,7 +363,6 @@
     read_remote_annotations,
     write_remote_annotations,
 )
-from zetta_utils.segmentation.inference import run_affinities_inference_onnx
 from zetta_utils.tensor_mapping.tensor_mapping import TensorMapping
 from zetta_utils.tensor_ops.common import (
     add,
@@ -382,7 +392,8 @@
     kornia_opening,
     skip_on_empty_data,
 )
-from zetta_utils.tensor_ops.transform import get_affine_field
+
+# from zetta_utils.tensor_ops.generators import get_affine_field
 from zetta_utils.training.datasets.joint_dataset import JointDataset
 from zetta_utils.training.datasets.layer_dataset import LayerDataset
 from zetta_utils.training.datasets.sample_indexers.base import SampleIndexer
@@ -395,17 +406,6 @@
 from zetta_utils.training.datasets.sample_indexers.volumetric_strided_indexer import (
     VolumetricStridedIndexer,
 )
-from zetta_utils.training.lightning.regimes.alignment.base_encoder import (
-    BaseEncoderRegime,
-)
-from zetta_utils.training.lightning.regimes.alignment.misalignment_detector_aced import (
-    MisalignmentDetectorAcedRegime,
-)
-from zetta_utils.training.lightning.regimes.common import is_2d_image, log_results
-from zetta_utils.training.lightning.regimes.naive_supervised import (
-    NaiveSupervisedRegime,
-)
-from zetta_utils.training.lightning.regimes.noop import NoOpRegime
 from zetta_utils.training.lightning.train import lightning_train
 from zetta_utils.training.lightning.trainers.default import (
     ConfigureTraceCallback,
diff --git a/zetta_utils/internal b/zetta_utils/internal
new file mode 160000
index 000000000..beffa7446
--- /dev/null
+++ b/zetta_utils/internal
@@ -0,0 +1 @@
+Subproject commit beffa7446d8cdb0bfa67e9c7080573dfe966dcd7
diff --git a/zetta_utils/mazepa_layer_processing/__init__.py b/zetta_utils/mazepa_layer_processing/__init__.py
index fa28f23f0..88e5a99b1 100644
--- a/zetta_utils/mazepa_layer_processing/__init__.py
+++ b/zetta_utils/mazepa_layer_processing/__init__.py
@@ -6,8 +6,8 @@
     ChunkableOpProtocol,
     VolumetricOpProtocol,
 )
-from . import alignment
-from . import segmentation
+from . import segmentation, alignment
+
 from .common import (
     ChunkedApplyFlowSchema,
     CallableOperation,
diff --git a/zetta_utils/mazepa_layer_processing/alignment/aced_relaxation_flow.py b/zetta_utils/mazepa_layer_processing/alignment/aced_relaxation_flow.py
index fb207a316..26181eb44 100644
--- a/zetta_utils/mazepa_layer_processing/alignment/aced_relaxation_flow.py
+++ b/zetta_utils/mazepa_layer_processing/alignment/aced_relaxation_flow.py
@@ -5,8 +5,9 @@
 import attrs
 import torch
 
-from zetta_utils import alignment, builder, mazepa, tensor_ops
+from zetta_utils import builder, mazepa, tensor_ops
 from zetta_utils.geometry import BBox3D, Vec3D
+from zetta_utils.internal import alignment
 from zetta_utils.layer.volumetric import (
     VolumetricIndex,
     VolumetricIndexChunker,
diff --git a/zetta_utils/mazepa_layer_processing/alignment/common.py b/zetta_utils/mazepa_layer_processing/alignment/common.py
index 82ee69a40..342c00fe5 100644
--- a/zetta_utils/mazepa_layer_processing/alignment/common.py
+++ b/zetta_utils/mazepa_layer_processing/alignment/common.py
@@ -4,8 +4,8 @@
 
 import torch
 
-from zetta_utils import alignment
 from zetta_utils.geometry import Vec3D
+from zetta_utils.internal import alignment
 from zetta_utils.layer.volumetric import VolumetricIndex, VolumetricLayer
diff --git a/zetta_utils/segmentation/__init__.py b/zetta_utils/segmentation/__init__.py
deleted file mode 100644
index 7386cfc41..000000000
--- a/zetta_utils/segmentation/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .affinity import AffinityLoss, AffinityProcessor
-from .balance import BinaryClassBalancer
-from .common import MultiHeadedProcessor
-from .embedding import *
-from .inference import run_affinities_inference_onnx
-from .loss import *
-from .affs_inferencer import AffinitiesInferencer
diff --git a/zetta_utils/segmentation/affinity.py b/zetta_utils/segmentation/affinity.py
deleted file mode 100644
index 6a6cf7bc6..000000000
--- a/zetta_utils/segmentation/affinity.py
+++ /dev/null
@@ -1,284 +0,0 @@
-# pylint: disable=unused-argument
-from __future__ import annotations
-
-from functools import partial
-from typing import Callable, Literal, Sequence, cast
-
-import attrs
-import numpy as np
-import torch
-from torch import nn
-from typeguard import typechecked
-
-from zetta_utils import builder, tensor_ops
-from zetta_utils.geometry import Vec3D
-from zetta_utils.geometry.bbox import Slices3D
-from zetta_utils.layer import JointIndexDataProcessor
-from zetta_utils.layer.volumetric import VolumetricIndex
-
-from .loss import LossWithMask
-
-NDIM = 3
-
-
-@typechecked
-class EdgeSampler:
-    def __init__(self, edges: Sequence[Sequence[int]]) -> None:
-        assert len(edges) > 0
-        assert all(len(edge) == NDIM for edge in edges)
-        self.edges = list(edges)
-
-    def generate_edges(self) -> Sequence[Sequence[int]]:
-        return list(self.edges)
-
-
-@typechecked
-class EdgeDecoder(nn.Module):
-    def __init__(self, edges: Sequence[Sequence[int]], pad_crop: bool) -> None:
-        super().__init__()
-        assert len(edges) > 0
-        assert all(len(edge) == NDIM for edge in edges)
-        self.edges = list(edges)
-        self.pad_crop = pad_crop
-
-    def forward(self, x: torch.Tensor, idx: int) -> torch.Tensor:
-        assert x.ndim >= 4
-        num_channels = x.shape[-(NDIM + 1)]  # CZYX
-        assert num_channels == len(self.edges)
-        assert 0 <= idx < num_channels
-        data = x[..., [idx], :, :, :]
-        if not self.pad_crop:
-            edge = self.edges[idx]
-            data = tensor_ops.get_disp_pair(data, edge)[1]
-        return data
-
-
-@typechecked
-class EdgeCRF(nn.Module):
-    def __init__(
-        self,
-        criterion: Callable[..., nn.Module],
-        reduction: Literal["mean", "sum", "none"],
-        balancer: nn.Module | None,
-    ) -> None:
-        super().__init__()
-        try:
-            criterion_ = criterion(reduction="none")
-        except (KeyError, TypeError):
-            criterion_ = criterion()
-            assert criterion_.reduction == "none"
-        if isinstance(criterion_, LossWithMask):
-            self.criterion = criterion_
-        else:
-            self.criterion = LossWithMask(criterion, reduction="none")
-        self.reduction = reduction
-        self.balancer = balancer
-
-    def forward(
-        self,
-        preds: Sequence[torch.Tensor],
-        trgts: Sequence[torch.Tensor],
-        masks: Sequence[torch.Tensor],
-    ) -> torch.Tensor | list[torch.Tensor] | None:
-        assert len(preds) == len(trgts) == len(masks) > 0
-
-        dtype = preds[0].dtype
-        device = preds[0].device
-
-        losses = []
-        nmsk = torch.tensor(0, dtype=dtype, device=device)
-
-        for pred, trgt, mask in zip(preds, trgts, masks):
-            if self.balancer is not None:
-                mask = self.balancer(trgt, mask)
-            loss_ = self.criterion(pred, trgt, mask)
-            if loss_ is not None:
-                losses.append(loss_)
-                nmsk += torch.count_nonzero(mask)
-
-        if nmsk.item() == 0:
-            assert len(losses) == 0
-            return None
-
-        if self.reduction == "none":
-            return losses
-
-        # Sum up losses
-        losses = list(map(torch.sum, losses))
-        loss = torch.sum(torch.stack(losses))
-
-        if self.reduction == "mean":
-            assert nmsk.item() > 0
-            loss /= nmsk.to(loss.dtype).item()
-
-        return loss
-
-
-@typechecked
-def _compute_slices(
-    edges: Sequence[Sequence[int]],
-    pad_crop: bool,
-    symmetric: bool = False,
-) -> list[Slices3D]:
-    assert len(edges) > 0
-    assert all(len(edge) == NDIM for edge in edges)
-
-    if not pad_crop:
-        slices = cast(Slices3D, tuple([slice(0, None)] * NDIM))
-        return [slices] * len(edges)
-
-    # Padding in the negative & positive directions
-    if symmetric:
-        pad_max = max(abs(np.concatenate(edges)))
-        pad_neg = pad_pos = np.array([pad_max] * 3)
-    else:
-        pad_neg = -np.amin(np.array(edges), axis=0, initial=0)
-        pad_pos = np.amax(np.array(edges), axis=0, initial=0)
-
-    # Compute slices for each edge
-    result = []
-    for edge in edges:
-        slices_ = []
-        for lpad, rpad, disp in zip(pad_neg, pad_pos, edge):
-            start, end = lpad, -rpad
-            if disp > 0:
-                end += disp
-            else:
-                start += disp
-            slices_.append(slice(start, None) if end == 0 else slice(start, end))
-        slices = cast(Slices3D, tuple(slices_))
-        result.append(slices)
-
-    return result
-
-
-@typechecked
-def _expand_slices_dim(slices: Slices3D, ndim: int) -> tuple[slice, ...]:
-    extra_dims = ndim - NDIM
-    assert extra_dims >= 0
-    extra_slc = [slice(0, None)] * extra_dims
-    result = tuple(extra_slc + list(slices))
-    return result
-
-
-@builder.register("AffinityLoss")
-@typechecked
-class AffinityLoss(nn.Module):
-    def __init__(
-        self,
-        edges: Sequence[Sequence[int]],
-        criterion: Callable[..., nn.Module],
-        reduction: Literal["mean", "sum", "none"] = "none",
-        balancer: nn.Module | None = None,
-        pad_crop: bool = False,
-    ) -> None:
-        super().__init__()
-        self.slices = _compute_slices(edges, pad_crop)
-        self.sampler = EdgeSampler(edges)
-        self.decoder = EdgeDecoder(edges, pad_crop)
-        self.criterion = EdgeCRF(criterion, reduction, balancer)
-
-    def forward(
-        self,
-        pred: torch.Tensor,
-        trgt: torch.Tensor,
-        mask: torch.Tensor,
-    ) -> torch.Tensor | list[torch.Tensor] | None:
-        preds = []  # type: list[torch.Tensor]
-        trgts = []
-        masks = []
-        edges = self.sampler.generate_edges()
-        for idx, (edge, slices) in enumerate(zip(edges, self.slices)):
-            aff, msk = tensor_ops.seg_to_aff(trgt[slices], edge, mask=mask)
-            preds.append(self.decoder(pred, idx))
-            trgts.append(aff)
-            masks.append(msk)
-        return self.criterion(preds, trgts, masks)
-
-
-@builder.register("AffinityProcessor")
-@typechecked
-@attrs.mutable
-class AffinityProcessor(JointIndexDataProcessor):  # pragma: no cover
-    source: str
-    spec: dict[str, Sequence[Sequence[int]]]
-    symmetric: bool = False
-
-    pad_neg: Vec3D[int] = attrs.field(init=False)
-    pad_pos: Vec3D[int] = attrs.field(init=False)
-    slices: list[Slices3D] = attrs.field(init=False)
-    prepared_resolution: Vec3D | None = attrs.field(init=False, default=None)
-
-    def __attrs_post_init__(self):
-        edges_union = []
-        for edges in self.spec.values():
-            assert len(edges) > 0
-            assert all(len(edge) == NDIM for edge in edges)
-            edges_union += list(edges)
-
-        # Padding in the negative & positive directions
-        if self.symmetric:
-            pad_max = max(abs(np.concatenate(edges_union)))
-            self.pad_neg = -Vec3D[int](pad_max, pad_max, pad_max)
-            self.pad_pos = Vec3D[int](pad_max, pad_max, pad_max)
-        else:
-            self.pad_neg = Vec3D[int](*np.amin(np.array(edges_union), axis=0, initial=0))
-            self.pad_pos = Vec3D[int](*np.amax(np.array(edges_union), axis=0, initial=0))
-
-        # Slices for cropping
-        self.slices = _compute_slices(edges_union, pad_crop=True, symmetric=self.symmetric)
-
-    def _crop_data(self, data: torch.Tensor) -> torch.Tensor:
-        assert self.prepared_resolution is not None
-        idx = (
-            VolumetricIndex.from_coords(
-                start_coord=Vec3D(0, 0, 0),
-                end_coord=Vec3D(*data.shape[-3:]),
-                resolution=Vec3D(*self.prepared_resolution),
-            )
-            .translated_start(-self.pad_neg)
-            .translated_end(-self.pad_pos)
-        )
-        slc = _expand_slices_dim(idx.to_slices(), ndim=data.ndim)
-        return data[slc]
-
-    def process_index(
-        self, idx: VolumetricIndex, mode: Literal["read", "write"]
-    ) -> VolumetricIndex:
-        """Pad index to have sufficient context for computing affinities"""
-        self.prepared_resolution = idx.resolution
-        result = idx.translated_start(self.pad_neg).translated_end(self.pad_pos)
-        return result
-
-    def process_data(
-        self, data: dict[str, torch.Tensor], mode: Literal["read", "write"]
-    ) -> dict[str, torch.Tensor]:
-        # Process other data first
-        for key, value in data.items():
-            if not key.startswith(self.source):
-                data[key] = self._crop_data(value)
-
-        # Segmentation and mask
-        seg = data[self.source]
-        mask = data[self.source + "_mask"]
-
-        # Expand slices dim
-        slices = list(map(partial(_expand_slices_dim, ndim=seg.ndim), self.slices))
-
-        # Process affinity
-        for target, edges in self.spec.items():
-            affs, msks = [], []
-            for edge, slc in zip(edges, slices[: len(edges)]):
-                aff, msk = tensor_ops.seg_to_aff(seg[slc], edge, mask=mask[slc])
-                affs.append(aff)
-                msks.append(msk)
-            data[target] = torch.cat(affs, dim=-4)
-            data[target + "_mask"] = torch.cat(msks, dim=-4)
-            slices = slices[len(edges) :]
-
-        # Process segmentation and mask
-        data[self.source] = self._crop_data(seg)
-        data[self.source + "_mask"] = self._crop_data(mask)
-
-        self.prepared_resolution = None
-        return data
diff --git a/zetta_utils/segmentation/affs_inferencer.py b/zetta_utils/segmentation/affs_inferencer.py
deleted file mode 100644
index de12d0116..000000000
--- a/zetta_utils/segmentation/affs_inferencer.py
+++ /dev/null
@@ -1,70 +0,0 @@
-from __future__ import annotations
-
-from typing import Sequence
-
-import attrs
-import einops
-import numpy as np
-import torch
-from typeguard import typechecked
-
-from zetta_utils import builder, convnet
-
-
-@builder.register("AffinitiesInferencer")
-@typechecked
-@attrs.frozen
-class AffinitiesInferencer:
-    # Input uint8 [ 0 .. 255]
-    # Output float [ 0 .. 255]
-
-    # Don't create the model during initialization for efficient serialization
-    model_path: str
-    output_channels: Sequence[int]
-
-    bg_mask_channel: int | None = None
-    bg_mask_threshold: float = 0.0
-    bg_mask_invert_threshold: bool = False
-
-    def __call__(
-        self,
-        image: torch.Tensor,
-        image_mask: torch.Tensor,
-        output_mask: torch.Tensor,
-    ) -> torch.Tensor:
-
-        if image.dtype == torch.uint8:
-            data_in = image.float() / 255.0  # [0.0 .. 1.0]
-        else:
-            raise ValueError(f"Unsupported image dtype: {image.dtype}")
-
-        # mask input
-        data_in = data_in * image_mask
-        data_in = einops.rearrange(data_in, "C X Y Z -> C Z Y X")
-        data_in = data_in.unsqueeze(0).float()
-
-        data_out = convnet.utils.load_and_run_model(path=self.model_path, data_in=data_in)
-
-        # Extract requested channels
-        arrays = []
-        for channel in self.output_channels:
-            arrays.append(data_out[:, channel, ...])
-        if self.bg_mask_channel is not None:
-            arrays.append(data_out[:, self.bg_mask_channel, ...])
-        data_out = torch.Tensor(np.stack(arrays, axis=1)[0])
-
-        # mask output with bg_mask
-        num_channels = len(self.output_channels)
-        output = data_out[0:num_channels, :, :, :]
-        if self.bg_mask_channel is not None:
-            if self.bg_mask_invert_threshold:
-                bg_mask = data_out[num_channels:, :, :, :] > self.bg_mask_threshold
-            else:
-                bg_mask = data_out[num_channels:, :, :, :] < self.bg_mask_threshold
-            output = torch.Tensor(output) * bg_mask
-
-        # mask output
-        output = einops.rearrange(output, "C Z Y X -> C X Y Z")
-        output = output * output_mask
-
-        return output
diff --git a/zetta_utils/segmentation/balance.py b/zetta_utils/segmentation/balance.py
deleted file mode 100644
index 2dcf65c10..000000000
--- a/zetta_utils/segmentation/balance.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from __future__ import annotations
-
-from functools import partial
-
-import numpy as np
-import torch
-from torch import nn
-from typeguard import typechecked
-
-from zetta_utils import builder
-
-
-@builder.register("BinaryClassBalancer")
-@typechecked
-class BinaryClassBalancer(nn.Module):  # pragma: no cover
-    """
-    Computes a weight map by balancing foreground/background.
-
-    :param weight0:
-    :param weight1:
-    :param clipmin:
-    :param clipmax:
-    :param group:
-    """
-
-    def __init__(
-        self,
-        weight0: float | None = None,
-        weight1: float | None = None,
-        clipmin: float = 0.01,
-        clipmax: float = 0.99,
-        group: int = 0,
-    ):
-        super().__init__()
-        assert weight0 > 0 if weight0 is not None else True
-        assert weight1 > 0 if weight1 is not None else True
-        self.weight0 = weight0
-        self.weight1 = weight1
-        self.dynamic = (weight0 is None) and (weight1 is None)
-        self.clip = partial(np.clip, a_min=clipmin, a_max=clipmax)
-        self.group = group
-
-    def forward(
-        self,
-        trgt: torch.Tensor,
-        mask: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if mask is None:
-            mask = torch.ones_like(trgt, dtype=torch.float32)
-
-        num_channels = trgt.shape[-4]
-        group = self.group if self.group > 0 else num_channels
-
-        balanced = []
-        for i in range(0, num_channels, group):
-            start = i
-            end = min(i + group, num_channels)
-            balanced.append(
-                self._balance(
-                    trgt[..., start:end, :, :, :],
-                    mask[..., start:end, :, :, :],
-                )
-            )
-        result = torch.cat(balanced, dim=-4)
-        return result
-
-    def _balance(self, trgt: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-        dtype = mask.dtype
-        ones = mask * torch.eq(trgt, 1).type(dtype)
-        zeros = mask * torch.eq(trgt, 0).type(dtype)
-
-        # Dynamic balancing
-        if self.dynamic:
-
-            n_ones = ones.sum().item()
-            n_zeros = zeros.sum().item()
-            if (n_ones + n_zeros) > 0:
-                ones *= self.clip(n_zeros / (n_ones + n_zeros))
-                zeros *= self.clip(n_ones / (n_ones + n_zeros))
-
-        # Static weighting
-        else:
-
-            if self.weight1 is not None:
-                ones *= self.weight1
-
-            if self.weight0 is not None:
-                zeros *= self.weight0
-
-        return (ones + zeros).type(dtype)
diff --git a/zetta_utils/segmentation/common.py b/zetta_utils/segmentation/common.py
deleted file mode 100644
index 0787485b1..000000000
--- a/zetta_utils/segmentation/common.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from __future__ import annotations
-
-import attrs
-import torch
-from typeguard import typechecked
-
-from zetta_utils import builder
-from zetta_utils.layer import DataProcessor
-
-
-@builder.register("MultiHeadedProcessor")
-@typechecked
-@attrs.frozen
-class MultiHeadedProcessor(DataProcessor):  # pragma: no cover
-    spec: dict[str, list[str]]
-
-    def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-        for source, targets in self.spec.items():
-            for target in targets:
-                data[target] = data[source]
-                if source + "_mask" in data:
-                    data[target + "_mask"] = data[source + "_mask"]
-        return data
diff --git a/zetta_utils/segmentation/embedding/__init__.py b/zetta_utils/segmentation/embedding/__init__.py
deleted file mode 100644
index 3a2c56991..000000000
--- a/zetta_utils/segmentation/embedding/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .common import EmbeddingProcessor
-from .edge import EdgeLoss
-from .mean import MeanLoss
-from .utils import vec_to_pca, vec_to_rgb
diff --git a/zetta_utils/segmentation/embedding/common.py b/zetta_utils/segmentation/embedding/common.py
deleted file mode 100644
index 72eb843ab..000000000
--- a/zetta_utils/segmentation/embedding/common.py
+++ /dev/null
@@ -1,39 +0,0 @@
-from __future__ import annotations
-
-import attrs
-import cc3d
-import numpy as np
-import torch
-from typeguard import typechecked
-
-from zetta_utils import builder
-from zetta_utils.layer import DataProcessor
-from zetta_utils.tensor_ops import convert
-
-
-@builder.register("EmbeddingProcessor")
-@typechecked
-@attrs.mutable
-class EmbeddingProcessor(DataProcessor):  # pragma: no cover
-    source: str
-    target: str
-    split_label: bool = False
-
-    def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
-
-        # Segmentation and mask
-        seg = data[self.source]
-        msk = data[self.source + "_mask"]
-
-        # Target
-        data[self.target] = seg
-        data[self.target + "_mask"] = msk
-
-        # Split label
-        if self.split_label:
-            seg_np = convert.to_np(seg).astype(np.uint64)
-            seg_cc = cc3d.connected_components(seg_np)
-            seg_split = convert.astype(seg_cc, seg, cast=True)
-            data[self.target + "_split"] = seg_split
-
-        return data
diff --git a/zetta_utils/segmentation/embedding/edge.py b/zetta_utils/segmentation/embedding/edge.py
deleted file mode 100644
index fbd7eadad..000000000
--- a/zetta_utils/segmentation/embedding/edge.py
+++ /dev/null
@@ -1,141 +0,0 @@
-from __future__ import annotations
-
-import random
-from functools import partial
-from typing import Literal, Sequence
-
-import torch
-from torch import nn
-from typeguard import typechecked
-
-from zetta_utils import builder, tensor_ops
-from zetta_utils.geometry import Vec3D
-
-from ..affinity import EdgeCRF
-from ..loss import LossWithMask
-
-NDIM = 3
-
-
-@typechecked
-class EdgeSampler:
-    def __init__(
-        self,
-        edges: Sequence[Sequence[int]] | None,
-        bounds: Sequence[Sequence[int]],
-    ) -> None:
-        edges = [] if edges is None else edges
-        assert len(edges) > 0 or len(bounds) > 0
-        assert all(len(edge) == NDIM for edge in edges)
-        assert all(len(bound) == NDIM for bound in bounds)
-        assert all(bound != (0, 0, 0) for bound in bounds)
-        self.edges = list(edges)
-        self.bounds = list(bounds)
-
-    def generate_edges(
-        self,
-        num_edges: Sequence[int],
-    ) -> list[Vec3D]:
-        """Generate `self.edges` if any, and then additionally generate random
-        edges for each prespecified bound.
- """ - assert len(num_edges) == len(self.bounds) - assert all(n_edge >= 0 for n_edge in num_edges) - - result = [Vec3D(*edge) for edge in self.edges] - - # Sample random edges for each bound - for bound, n_edge in zip(self.bounds, num_edges): - sampled = 0 - while sampled < n_edge: - edge = Vec3D(*[random.randint(-abs(bnd), abs(bnd)) for bnd in bound]) - if edge == Vec3D(0, 0, 0): - continue - result.append(edge) - sampled += 1 - - return result - - -@typechecked -def compute_affinity( - data1: torch.Tensor, - data2: torch.Tensor, - dim: int = -4, - keepdims: bool = True, -) -> torch.Tensor: - """Compute an affinity map from a pair of embeddings based on l2 distance.""" - dist = torch.sum((data1 - data2) ** 2, dim=dim, keepdim=keepdims) - result = torch.exp(-dist) - return result - - -@typechecked -class EdgeDecoder: - def __call__(self, data: torch.Tensor, disp: Vec3D) -> torch.Tensor: - pair = tensor_ops.get_disp_pair(data, disp) - result = compute_affinity(pair[0], pair[1]) - return result - - -@builder.register("EdgeLoss") -@typechecked -class EdgeLoss(nn.Module): - def __init__( - self, - bounds: Sequence[Sequence[int]], - num_edges: Sequence[int], - edges: Sequence[Sequence[int]] | None = None, - reduction: Literal["mean", "sum", "none"] = "none", - balancer: nn.Module | None = None, - ) -> None: - super().__init__() - assert len(num_edges) == len(bounds) - self.num_edges = list(num_edges) - self.sampler = EdgeSampler(edges, bounds) - self.decoder = EdgeDecoder() - self.criterion = EdgeCRF( - partial( - LossWithMask, - criterion=nn.BCEWithLogitsLoss, - reduction="none", - ), - reduction, - balancer, - ) - - def forward( - self, - pred: torch.Tensor, - trgt: torch.Tensor, - mask: torch.Tensor, - splt: torch.Tensor | None = None, - ) -> torch.Tensor | list[torch.Tensor] | None: - preds = [] # type: list[torch.Tensor] - trgts = [] - masks = [] - edges = self.sampler.generate_edges(self.num_edges) - for edge in edges: - aff, msk = self._generate_target(trgt, mask, edge, splt) - preds.append(self.decoder(pred, edge)) - trgts.append(aff) - masks.append(msk) - return self.criterion(preds, trgts, masks) - - @staticmethod - def _generate_target( - trgt: torch.Tensor, - mask: torch.Tensor, - edge: Vec3D, - splt: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - # Ignore background - mask *= (trgt != 0).to(mask.dtype) - aff, msk = tensor_ops.seg_to_aff(trgt, edge, mask=mask) - - # Mask out interactions between local split by connected components - if splt is not None: - splt_aff = tensor_ops.seg_to_aff(splt, edge) - msk[aff != splt_aff] = 0 - - return aff, msk diff --git a/zetta_utils/segmentation/embedding/mean.py b/zetta_utils/segmentation/embedding/mean.py deleted file mode 100644 index fb52ebd3b..000000000 --- a/zetta_utils/segmentation/embedding/mean.py +++ /dev/null @@ -1,179 +0,0 @@ -# pylint: disable = no-self-use -from __future__ import annotations - -from typing import Sequence - -import attrs -import numpy as np -import numpy.typing as npt -import torch -from torch import nn -from typeguard import typechecked - -from zetta_utils import builder -from zetta_utils.tensor_ops import convert - - -@typechecked -def create_mapping(trgt: npt.NDArray, splt: npt.NDArray, mask: npt.NDArray) -> list[list[int]]: - trgt, splt = trgt.astype(np.uint64), splt.astype(np.uint64) - encoded = (2 ** 32) * trgt + splt - encoded[mask == 0] = 0 - unq = np.unique(encoded) - mapping: dict[int, list[int]] = {} - for unq_id in unq: - trgt_id, splt_id = int(unq_id // (2 ** 32)), int(unq_id 
% (2 ** 32)) - mapping[trgt_id] = mapping.get(trgt_id, []) + [splt_id] - result = list(mapping.values()) - return result - - -@builder.register("MeanLoss") -@typechecked -@attrs.mutable -class MeanLoss(nn.Module): - alpha: float = 1.0 - beta: float = 1.0 - gamma: float = 0.001 - delta_v: float = 0.0 - delta_d: float = 1.5 - recompute_ext: bool = False - - def __attrs_pre_init__(self): - super().__init__() - - def forward( - self, - embd: torch.Tensor, - trgt: torch.Tensor, - mask: torch.Tensor, - splt: torch.Tensor | None = None, - ) -> torch.Tensor | None: - """ - :param embd: Embeddings - :param trgt: Target segmentation - :param mask: Segmentation mask - :param splt: Connected components of the target segmentation - """ - groups = None - if self.recompute_ext: - assert splt is not None - trgt_np = np.squeeze(convert.to_np(trgt)) - splt_np = np.squeeze(convert.to_np(splt)) - mask_np = np.squeeze(convert.to_np(mask)) - groups = create_mapping(trgt_np, splt_np, mask_np) - trgt = splt - - trgt = trgt.to(torch.int) - trgt *= (mask > 0).to(torch.int) - - # Unique nonzero IDs - ids = np.unique(convert.to_np(trgt)) - ids = ids[ids != 0].tolist() - - mext = self.compute_ext_matrix(ids, groups, self.recompute_ext) - vecs = self.generate_vecs(embd, trgt, ids) - means = [torch.mean(vec, dim=0) for vec in vecs] - weights = [1.0] * len(means) - - # Compute loss - loss_int = self.compute_loss_int(vecs, means, weights, embd.device) - loss_ext = self.compute_loss_ext(means, weights, mext, embd.device) - loss_nrm = self.compute_loss_nrm(means, embd.device) - - result = (self.alpha * loss_int) + (self.beta * loss_ext) + (self.gamma * loss_nrm) - return result - - def compute_loss_int( - self, - vecs: list[torch.Tensor], - means: list[torch.Tensor], - weights: list[float], - device: torch.device, - ) -> torch.Tensor: - """Compute the internal term of the loss.""" - assert len(vecs) == len(means) == len(weights) - zero = lambda: torch.zeros(1).to(device).squeeze() - loss = zero() - for vec, mean, weight in zip(vecs, means, weights): - margin = torch.norm(vec - mean, p=1, dim=1) - self.delta_v - loss += weight * torch.mean(torch.max(margin, zero()) ** 2) - result = loss / max(1.0, len(vecs)) - return result - - def compute_loss_ext( - self, - means: list[torch.Tensor], - weights: list[float], - mext: torch.Tensor | None, - device: torch.device, - ) -> torch.Tensor: - """Compute the external term of the loss.""" - assert len(means) == len(weights) - zero = lambda: torch.zeros(1).to(device).squeeze() - loss = zero() - count = len(means) - if (count > 1) and (mext is not None): - means1 = torch.stack(means).unsqueeze(0) # 1 x N x Dim - means2 = torch.stack(means).unsqueeze(1) # N x 1 x Dim - margin = 2 * self.delta_d - torch.norm(means2 - means1, p=1, dim=2) - margin = margin[mext] - loss = torch.sum(torch.max(margin, zero()) ** 2) - result = loss / max(1.0, count * (count - 1)) - return result - - def compute_loss_nrm(self, means: list[torch.Tensor], device: torch.device) -> torch.Tensor: - """Compute the regularization term of the loss.""" - zero = lambda: torch.zeros(1).to(device).squeeze() - loss = zero() - if len(means) > 0: - loss += torch.mean(torch.norm(torch.stack(means), p=1, dim=1)) - result = loss - return result - - def generate_vecs( - self, - embd: torch.Tensor, - trgt: torch.Tensor, - ids: Sequence[int], - ) -> list[torch.Tensor]: - """ - Generate a list of vectorized embeddings for each ground truth object. 
- """ - result = [] - for obj_id in ids: - obj = torch.nonzero(trgt == obj_id) - z, y, x = obj[:, -3], obj[:, -2], obj[:, -1] - vec = embd[0, :, z, y, x].transpose(0, 1) # Count x Dim - result.append(vec) - return result - - def compute_ext_matrix( - self, - ids: Sequence[int], - groups: Sequence[Sequence[int]] | None = None, - recompute_ext: bool = False, - ) -> torch.Tensor | None: - """ - Compute a matrix that indicates the presence of 'external' interaction - between objects. - """ - num_ids = len(ids) - mext = torch.ones((num_ids, num_ids)) - torch.eye(num_ids) - - # Recompute external matrix - if recompute_ext: - assert groups is not None - idmap = {x: i for i, x in enumerate(ids)} - for group in groups: - for i, id_i in enumerate(group): - for id_j in group[i + 1 :]: - mext[idmap[id_i], idmap[id_j]] = 0 - mext[idmap[id_j], idmap[id_i]] = 0 - - # Safeguard - if mext.sum().item() == 0: - return None - - result = mext.to(torch.bool) - return result diff --git a/zetta_utils/segmentation/embedding/utils.py b/zetta_utils/segmentation/embedding/utils.py deleted file mode 100644 index 4da0ef0e2..000000000 --- a/zetta_utils/segmentation/embedding/utils.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -from typing import Sequence - -import numpy as np -import torch -from sklearn.decomposition import PCA -from torch.nn import functional as F -from typeguard import typechecked - -from zetta_utils import builder, tensor_ops -from zetta_utils.tensor_ops import convert -from zetta_utils.tensor_typing import TensorTypeVar - - -@typechecked -def vec_to_pca(data: TensorTypeVar) -> TensorTypeVar: - """ - Transform feature vectors into an RGB map with PCA dimensionality reduction. - - :param data: Feature vectors - """ - assert (data.ndim == 5) and (data.shape[0] == 1) - data_np = convert.to_np(data).astype(np.float32) - - # pylint: disable=invalid-name - dim = data_np.shape[-4] - data_tp = data_np.transpose(0, 2, 3, 4, 1) - X = data_tp.reshape(-1, dim) # (n_samples, n_features) - pca = PCA(dim).fit_transform(X) - pca_tp = pca.reshape(data_tp.shape) - pca_np = pca_tp.transpose(0, 4, 1, 2, 3) - result = convert.astype(pca_np, data) - return result - - -@typechecked -def vec_to_rgb(data: TensorTypeVar) -> TensorTypeVar: - """ - Transform feature vectors into an RGB map by slicing the first three - channels and then rescale them. - - :param data: Feature vectors - :return: RGB map - """ - assert (data.ndim == 5) and (data.shape[0] == 1) - data_np = convert.to_np(data).astype(np.float32) - rgbmap = data_np[:, 0:3, ...] 
-    rgbmap -= np.min(rgbmap)
-    rgbmap /= np.max(rgbmap)
-    result = convert.astype(rgbmap, data)
-    return result
-
-
-@builder.register("vec_to_affs")
-@typechecked
-def vec_to_affs(
-    vec: torch.Tensor,
-    edges: Sequence[Sequence[int]] = ((1, 0, 0), (0, 1, 0), (0, 0, 1)),  # assume CXYZ
-    delta_d: float = 1.5,
-) -> torch.Tensor:
-
-    assert vec.ndimension() >= 4
-    assert len(edges) > 0
-
-    affs = []
-    for edge in edges:
-        pair = tensor_ops.get_disp_pair(vec.numpy(), edge)
-        aff = _compute_affinity(pair[0], pair[1], delta_d=delta_d)  #
-        pad = []
-        for e in reversed(edge):
-            if e > 0:
-                pad.extend([e, 0])
-            else:
-                pad.extend([0, abs(e)])
-        affs.append(F.pad(aff, pad))
-
-    assert len(affs) > 0
-    for aff in affs:
-        assert affs[0].size() == aff.size()
-
-    return torch.cat(affs, dim=-4)
-
-
-def _compute_affinity(
-    embd1: torch.Tensor,
-    embd2: torch.Tensor,
-    dim: int = -4,
-    keepdims: bool = True,
-    delta_d: float = 1.5,
-) -> torch.Tensor:
-    """Compute an affinity map from a pair of embeddings."""
-    norm = torch.norm(embd1 - embd2, p=1, dim=dim, keepdim=keepdims)
-    margin = (2 * delta_d - norm) / (2 * delta_d)
-    zero = torch.zeros(1).to(embd1.device, dtype=embd1.dtype)
-    result = torch.max(zero, margin) ** 2
-    return result
-
-
-@builder.register("vec_to_affs_v1")
-@typechecked
-def vec_to_affs_v1(
-    embeddings: torch.Tensor,
-    offsets: Sequence[int] = (1, 1, 1),
-    delta_mult: int = 15000,
-) -> torch.Tensor:
-    # Tri's naive implementation
-    metric_out = np.zeros(shape=(3,) + embeddings.shape[1:])
-    # compute mean-square
-    metric_out[0, offsets[0] :, :, :] = (
-        (embeddings[:, offsets[0] :, :, :] - embeddings[:, : -offsets[0], :, :]) ** 2
-    ).mean(axis=0)
-    metric_out[1, :, offsets[1] :, :] = (
-        (embeddings[:, :, offsets[1] :, :] - embeddings[:, :, : -offsets[1], :]) ** 2
-    ).mean(axis=0)
-    metric_out[2, :, :, offsets[2] :] = (
-        (embeddings[:, :, :, offsets[2] :] - embeddings[:, :, :, : -offsets[2]]) ** 2
-    ).mean(axis=0)
-    metric_out *= delta_mult
-    # convert to affinities and torch.tensor
-    metric_out[metric_out > 1] = 1
-    metric_out_ = torch.Tensor(1.0 - metric_out)
-    return metric_out_
diff --git a/zetta_utils/segmentation/inference.py b/zetta_utils/segmentation/inference.py
deleted file mode 100644
index 2a2d0eff9..000000000
--- a/zetta_utils/segmentation/inference.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import cachetools
-import fsspec
-import numpy as np
-import onnx
-import onnxruntime as ort
-import torch
-
-from zetta_utils import builder
-
-_session_cache: cachetools.LRUCache = cachetools.LRUCache(maxsize=1)
-
-
-@cachetools.cached(_session_cache)
-def _get_session(model_path: str):  # pragma: no cover
-    with fsspec.open(model_path, "rb") as f:
-        if model_path.endswith(".onnx"):
-            onnx_model = onnx.load(f)
-            return ort.InferenceSession(
-                onnx_model.SerializeToString(), providers=["CUDAExecutionProvider"]
-            )
-
-
-@builder.register("run_affinities_inference_onnx")
-def run_affinities_inference_onnx(
-    image: torch.Tensor,
-    image_mask: torch.Tensor,
-    output_mask: torch.Tensor,
-    model_path: str,
-    myelin_mask_threshold: float,
-) -> torch.Tensor:  # pragma: no cover
-
-    session = _get_session(model_path)
-    output = session.run(None, {"input": (image * image_mask).unsqueeze(0).float().numpy()})[0]
-
-    aff = output[:, :3, ...]
-    msk = np.amax(output[:, 3:, ...], axis=-4, keepdims=True)
-    output = np.concatenate((aff, msk), axis=-4)
-    output = torch.Tensor(output[0])
-    output_aff = output[0:3, :, :, :]
-    output_mye_mask = output[3:, :, :, :] < myelin_mask_threshold
-    output = torch.permute(torch.Tensor(output_aff) * output_mask * output_mye_mask, (0, 2, 3, 1))
-
-    return output
diff --git a/zetta_utils/segmentation/loss.py b/zetta_utils/segmentation/loss.py
deleted file mode 100644
index 5abb510f1..000000000
--- a/zetta_utils/segmentation/loss.py
+++ /dev/null
@@ -1,120 +0,0 @@
-from __future__ import annotations
-
-from typing import Callable, Literal
-
-import numpy as np
-import torch
-from torch import nn
-from typeguard import typechecked
-
-from zetta_utils import builder
-
-
-@builder.register("LossWithMask")
-@typechecked
-class LossWithMask(nn.Module):  # pragma: no cover
-    def __init__(
-        self,
-        criterion: Callable[..., nn.Module],
-        reduction: Literal["mean", "sum", "none"] = "sum",
-        balancer: nn.Module | None = None,
-    ) -> None:
-        super().__init__()
-        try:
-            self.criterion = criterion(reduction="none")
-        except (KeyError, TypeError):
-            self.criterion = criterion()
-            assert self.criterion.reduction == "none"
-        self.reduction = reduction
-        self.balancer = balancer
-        self.balanced = False
-
-    def forward(
-        self,
-        pred: torch.Tensor,
-        trgt: torch.Tensor,
-        mask: torch.Tensor,
-    ) -> torch.Tensor | None:
-        nmsk = torch.count_nonzero(mask)
-        if nmsk.item() == 0:
-            return None
-
-        # Optional class balancing
-        if (not self.balanced) and (self.balancer is not None):
-            mask = self.balancer(trgt, mask)
-
-        loss = mask * self.criterion(pred, trgt)
-        if self.reduction == "none":
-            return loss
-
-        loss = torch.sum(loss)
-
-        if self.reduction == "mean":
-            assert nmsk.item() > 0
-            loss /= nmsk.to(loss.dtype).item()
-
-        return loss
-
-
-@builder.register("BinaryLossWithMargin")
-@typechecked
-class BinaryLossWithMargin(LossWithMask):
-    def __init__(
-        self,
-        criterion: Callable[..., nn.Module],
-        reduction: Literal["mean", "sum", "none"] = "sum",
-        balancer: nn.Module | None = None,
-        margin: float = 0,
-        logits: bool = False,
-    ) -> None:
-        super().__init__(criterion, reduction, balancer)
-        self.margin = np.clip(margin, 0, 1)
-        self.logits = logits
-
-    def forward(
-        self,
-        pred: torch.Tensor,
-        trgt: torch.Tensor,
-        mask: torch.Tensor,
-    ) -> torch.Tensor | None:
-        # Optional class balancing
-        if self.balancer is not None:
-            mask = self.balancer(trgt, mask)
-            self.balanced = True
-
-        high = 1 - self.margin
-        low = self.margin
-        activ = torch.sigmoid(pred) if self.logits else pred
-        hmask = torch.ge(activ, high) * torch.eq(trgt, 1)
-        lmask = torch.le(activ, low) * torch.eq(trgt, 0)
-        mask *= 1 - (hmask | lmask).to(mask.dtype)
-        return super().forward(pred, trgt, mask)
-
-
-@builder.register("BinaryLossWithInverseMargin")
-@typechecked
-class BinaryLossWithInverseMargin(LossWithMask):
-    def __init__(
-        self,
-        criterion: Callable[..., nn.Module],
-        reduction: Literal["mean", "sum", "none"] = "sum",
-        balancer: nn.Module | None = None,
-        margin: float = 0,
-    ) -> None:
-        super().__init__(criterion, reduction, balancer)
-        self.margin = np.clip(margin, 0, 1)
-
-    def forward(
-        self,
-        pred: torch.Tensor,
-        trgt: torch.Tensor,
-        mask: torch.Tensor,
-    ) -> torch.Tensor | None:
-        # Optional class balancing
-        if self.balancer is not None:
-            mask = self.balancer(trgt, mask)
-            self.balanced = True
-
-        trgt[torch.eq(trgt, 1)] = 1 - self.margin
-        trgt[torch.eq(trgt, 0)] = self.margin
-        return super().forward(pred, trgt, mask)
diff --git a/zetta_utils/training/lightning/__init__.py b/zetta_utils/training/lightning/__init__.py
index 7142ef097..76e8338b9 100644
--- a/zetta_utils/training/lightning/__init__.py
+++ b/zetta_utils/training/lightning/__init__.py
@@ -1 +1 @@
-from . import regimes, train, trainers
+from . import train, trainers
diff --git a/zetta_utils/training/lightning/regimes/__init__.py b/zetta_utils/training/lightning/regimes/__init__.py
deleted file mode 100644
index e17ac8ad0..000000000
--- a/zetta_utils/training/lightning/regimes/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from . import common
-
-from . import naive_supervised
-from . import noop
-
-from . import alignment
-from . import segmentation
diff --git a/zetta_utils/training/lightning/regimes/alignment/__init__.py b/zetta_utils/training/lightning/regimes/alignment/__init__.py
deleted file mode 100644
index e645e5759..000000000
--- a/zetta_utils/training/lightning/regimes/alignment/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from . import base_encoder, misalignment_detector_aced
-from .deprecated import encoding_coarsener
diff --git a/zetta_utils/training/lightning/regimes/alignment/base_encoder.py b/zetta_utils/training/lightning/regimes/alignment/base_encoder.py
deleted file mode 100644
index b3c63da7c..000000000
--- a/zetta_utils/training/lightning/regimes/alignment/base_encoder.py
+++ /dev/null
@@ -1,303 +0,0 @@
-# pragma: no cover
-# pylint: disable=too-many-locals
-
-import os
-from math import log2
-from typing import Optional
-
-import attrs
-import einops
-import numpy as np
-import pytorch_lightning as pl
-import torch
-import torchfields
-import wandb
-from PIL import Image as PILImage
-from pytorch_lightning import seed_everything
-
-from zetta_utils import builder, distributions, tensor_ops, viz
-
-
-@builder.register("BaseEncoderRegime", versions="==0.0.2")
-@attrs.mutable(eq=False)
-class BaseEncoderRegime(pl.LightningModule):  # pylint: disable=too-many-ancestors
-    model: torch.nn.Module
-    lr: float
-    train_log_row_interval: int = 200
-    val_log_row_interval: int = 25
-    max_displacement_px: float = 16.0
-    l1_weight_start_val: float = 0.0
-    l1_weight_end_val: float = 0.0
-    l1_weight_start_epoch: int = 0
-    l1_weight_end_epoch: int = 0
-    locality_weight: float = 1.0
-    similarity_weight: float = 0.0
-    zero_value: float = 0
-    ds_factor: int = 1
-    worst_val_loss: float = attrs.field(init=False, default=0)
-    worst_val_sample: dict = attrs.field(init=False, factory=dict)
-    worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None)
-
-    equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360)
-    equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10)
-    equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10)
-    equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1)
-    empty_tissue_threshold: float = 0.4
-
-    def __attrs_pre_init__(self):
-        super().__init__()
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
-        return optimizer
-
-    def log_results(self, mode: str, title_suffix: str = "", **kwargs):
-        if not self.logger:
-            return
-        images = []
-        for k, v in kwargs.items():
-            for b in range(1):
-                if v.dtype in (np.uint8, torch.uint8):
-                    img = v[b].squeeze()
-                    img[-1, -1] = 255
-                    img[-2, -2] = 255
-                    img[-1, -2] = 0
-                    img[-2, -1] = 0
-                    images.append(
-                        wandb.Image(
-                            PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"),
-                            caption=f"{k}_b{b}",
-                        )
-                    )
-                elif v.dtype in (torch.int8, np.int8):
-                    img = v[b].squeeze().byte() + 127
-                    img[-1, -1] = 255
-                    img[-2, -2] = 255
-                    img[-1, -2] = 0
-                    img[-2, -1] = 0
-                    images.append(
-                        wandb.Image(
-                            PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"),
-                            caption=f"{k}_b{b}",
-                        )
-                    )
-                elif v.dtype in (torch.bool, bool):
-                    img = v[b].squeeze().byte() * 255
-                    img[-1, -1] = 255
-                    img[-2, -2] = 255
-                    img[-1, -2] = 0
-                    img[-2, -1] = 0
-                    images.append(
-                        wandb.Image(
-                            PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"),
-                            caption=f"{k}_b{b}",
-                        )
-                    )
-                else:
-                    if v.size(1) == 2 and k != "field":
-                        img = torch.cat([v, torch.zeros_like(v[:, :1])], dim=1)
-                    else:
-                        img = v
-                    v_min = img[b].min().round(decimals=4)
-                    v_max = img[b].max().round(decimals=4)
-                    images.append(
-                        wandb.Image(
-                            viz.rendering.Renderer()(img[b].squeeze()),
-                            caption=f"{k}_b{b} | min: {v_min} | max: {v_max}",
-                        )
-                    )
-
-        self.logger.log_image(f"results/{mode}_{title_suffix}_slider", images=images)
-
-    def validation_epoch_start(self, _):  # pylint: disable=no-self-use
-        seed_everything(42)
-
-    def on_validation_epoch_end(self):
-        env_seed = os.environ.get("PL_GLOBAL_SEED")
-        if env_seed is not None:
-            seed_everything(int(env_seed) + self.current_epoch)
-        else:
-            seed_everything(None)
-
-    def training_step(self, batch, batch_idx):  # pylint: disable=arguments-differ
-        log_row = batch_idx % self.train_log_row_interval == 0
-
-        with torchfields.set_identity_mapping_cache(True, clear_cache=False):
-            loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row)
-
-        return loss
-
-    @staticmethod
-    def _get_warped(img, field=None):
-        if field is not None:
-            img_warped = field.from_pixels()(img)
-        else:
-            img_warped = img
-
-        return img_warped
-
-    @staticmethod
-    def _down_zeros_mask(zeros_mask, count):
-        if count <= 0:
-            return zeros_mask
-
-        scale_factor = 0.5 ** count
-        return (
-            torch.nn.functional.interpolate(
-                zeros_mask.float(), scale_factor=scale_factor, mode="bilinear"
-            )
-            > 0.99
-        )
-
-    def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""):
-        src = batch["images"]["src_img"]
-        tgt = batch["images"]["tgt_img"]
-
-        # if (
-        #     (src == self.zero_value) + (tgt == self.zero_value)
-        # ).bool().sum() / src.numel() > self.empty_tissue_threshold:
-        #     return None  # Can't return None with DDP!
-
-        # Get random field - combination of pregenerated Perlin noise and a random affine transform
-        seed_field = batch["field"].field_()
-        f_warp = seed_field * self.max_displacement_px
-        f_aff = (
-            einops.rearrange(
-                tensor_ops.transform.get_affine_field(
-                    size=src.shape[-1],
-                    rot_deg=self.equivar_rot_deg_distr() if mode == "train" else 90.0,
-                    scale=self.equivar_scale_distr() if mode == "train" else 1.0,
-                    shear_x_deg=self.equivar_shear_deg_distr() if mode == "train" else 0.0,
-                    shear_y_deg=self.equivar_shear_deg_distr() if mode == "train" else 0.0,
-                    trans_x_px=self.equivar_trans_px_distr() if mode == "train" else 0.0,
-                    trans_y_px=self.equivar_trans_px_distr() if mode == "train" else 0.0,
-                ),
-                "C X Y Z -> Z C X Y",
-            )
-            .pixels()  # type: ignore[attr-defined]
-            .to(seed_field.device)
-        ).repeat_interleave(src.size(0), dim=0)
-        f1_transform = f_aff.from_pixels()(f_warp.from_pixels()).pixels()
-
-        # Warp Images and Tissue mask
-        src_f1 = self._get_warped(src, field=f1_transform)
-        tgt_f1 = self._get_warped(tgt, field=f1_transform)
-
-        # Generate encodings: src, src_f1_enc, src_enc_f1, tgt_f1_enc
-        src_enc = self.model(src)
-        src_enc_f1 = torch.nn.functional.pad(src_enc, (1, 1, 1, 1), mode="replicate")
-        src_enc_f1 = (
-            torch.nn.functional.pad(f1_transform, (self.ds_factor,) * 4, mode="replicate")
-            .from_pixels()  # type: ignore[attr-defined]
-            .down(int(log2(self.ds_factor)))
-            .sample(src_enc_f1, padding_mode="border")
-        )
-        src_enc_f1 = torch.nn.functional.pad(src_enc_f1, (-1, -1, -1, -1))
-        tgt_f1_enc = self.model(tgt_f1)
-
-        crop = 256 // self.ds_factor
-        src_f1 = src_f1[..., 256:-256, 256:-256]
-        tgt_f1 = tgt_f1[..., 256:-256, 256:-256]
-        src_enc_f1 = src_enc_f1[..., crop:-crop, crop:-crop]
-        tgt_f1_enc = tgt_f1_enc[..., crop:-crop, crop:-crop]
-
-        # Alignment loss: Ensure even close to local optima solutions produce larger errors
-        # than the local optimum solution
-        abs_error_local_opt = (
-            (src_enc_f1 - tgt_f1_enc)[:, :, 1:-1, 1:-1].pow(2).mean(dim=-3, keepdim=True)
-        )
-        abs_error_1px_shift = torch.stack(
-            [
-                (src_enc_f1[:, :, 2:, 1:-1] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2),
-                (src_enc_f1[:, :, :-2, 1:-1] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2),
-                (src_enc_f1[:, :, 1:-1, 2:] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2),
-                (src_enc_f1[:, :, 1:-1, :-2] - tgt_f1_enc[:, :, 1:-1, 1:-1]).pow(2),
-                (tgt_f1_enc[:, :, 2:, 2:] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2),
-                (tgt_f1_enc[:, :, :-2, 2:] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2),
-                (tgt_f1_enc[:, :, 2:, :-2] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2),
-                (tgt_f1_enc[:, :, :-2, :-2] - src_enc_f1[:, :, 1:-1, 1:-1]).pow(2),
-            ]
-        ).mean(dim=-3, keepdim=True)
-
-        locality_error_map = (
-            ((abs_error_local_opt - abs_error_1px_shift + 4.0) * 0.2)
-            .pow(
-                8.0  # increase to put more focus on locations where bad alignment
-                # still produces similar encodings - try 8? -> 42
-            )
-            .logsumexp(dim=0)
-        )
-
-        locality_loss = (
-            locality_error_map.sum() / locality_error_map.size(0) * self.ds_factor * self.ds_factor
-        )
-
-        l1_loss_map = (tgt_f1_enc.abs() + src_enc_f1.abs())[:, :, 1:-1, 1:-1]
-        l1_loss = (
-            l1_loss_map.sum()
-            / (2 * tgt_f1_enc.size(0) * tgt_f1_enc.size(1))
-            * self.ds_factor
-            * self.ds_factor
-        )
-
-        l1_weight_ratio = min(
-            1.0,
-            max(0, self.current_epoch - self.l1_weight_start_epoch)
-            / max(1, self.l1_weight_end_epoch - self.l1_weight_start_epoch),
-        )
-        l1_weight = (
-            l1_weight_ratio * self.l1_weight_end_val
-            + (1.0 - l1_weight_ratio) * self.l1_weight_start_val
-        )
-
-        loss = locality_loss * self.locality_weight + l1_loss * l1_weight
-        self.log(
-            f"loss/{mode}", loss, on_step=True, on_epoch=True, sync_dist=True, rank_zero_only=False
-        )
-        self.log(
-            f"loss/{mode}_l1_weight",
-            l1_weight,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=False,
-            sync_dist=False,
-            rank_zero_only=True,
-        )
-        self.log(
-            f"loss/{mode}_l1",
-            l1_loss,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=False,
-            sync_dist=True,
-            rank_zero_only=False,
-        )
-        if log_row:
-            self.log_results(
-                mode,
-                sample_name,
-                src=src,
-                src_enc=src_enc,
-                src_f1=src_f1,
-                src_enc_f1=src_enc_f1,
-                tgt_f1=tgt_f1,
-                tgt_f1_enc=tgt_f1_enc,
-                field=f_warp.tensor_(),
-                locality_error_map=locality_error_map,
-                l1_loss_map=l1_loss_map,
-                weighted_loss_map=(
-                    locality_error_map / locality_error_map.size(0) * self.locality_weight
-                    + l1_loss_map / (2 * tgt_f1_enc.size(0)) * l1_weight
-                ),
-            )
-        return loss
-
-    def validation_step(self, batch, batch_idx):  # pylint: disable=arguments-differ
-        log_row = batch_idx % self.val_log_row_interval == 0
-        sample_name = f"{batch_idx // self.val_log_row_interval}"
-
-        with torchfields.set_identity_mapping_cache(True, clear_cache=False):
-            loss = self.compute_metroem_loss(
-                batch=batch, mode="val", log_row=log_row, sample_name=sample_name
-            )
-        return loss
diff --git a/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder.py
deleted file mode 100644
index c9b12baef..000000000
--- a/zetta_utils/training/lightning/regimes/alignment/deprecated/base_encoder.py
+++ /dev/null
@@ -1,524 +0,0 @@
-# type: ignore
-# pragma: no cover
-# pylint: disable=too-many-locals, function-redefined
-
-from typing import Optional
-
-import attrs
-import cc3d
-import einops
-import numpy as np
-import pytorch_lightning as pl
-import torch
-import torchfields
-import wandb
-from PIL import Image as PILImage
-from pytorch_lightning import seed_everything
-
-from zetta_utils import builder, distributions, tensor_ops, viz
-from zetta_utils.training.lightning.regimes.common import log_results
-
-
-@builder.register("BaseEncoderRegime", versions="==0.0.1")
-@attrs.mutable(eq=False)
-class BaseEncoderRegime(pl.LightningModule):  # pylint: disable=too-many-ancestors
-    model: torch.nn.Module
-    lr: float
-    train_log_row_interval: int = 200
-    val_log_row_interval: int = 25
-    field_magn_thr: float = 1
-    max_displacement_px: float = 16.0
-    post_weight: float = 1.5
-    zero_value: float = 0
-    worst_val_loss: float = attrs.field(init=False, default=0)
-    worst_val_sample: dict = attrs.field(init=False, factory=dict)
-    worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None)
-
-    equivar_weight: float = 1.0
-    equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360)
-    equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10)
-    equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10)
-    equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1)
-    empty_tissue_threshold: float = 0.4
-
-    def __attrs_pre_init__(self):
-        super().__init__()
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
-        return optimizer
-
-    def log_results(self, mode: str, title_suffix: str = "", **kwargs):
-        if self.logger is None:
-            return
-        images = []
-        for k, v in kwargs.items():
-            for b in range(1):
-                if v.dtype in (np.uint8, torch.uint8):
-                    img = v[b].squeeze()
-                    img[-1, -1] = 255
-                    img[-2, -2] = 255
-                    img[-1, -2] = 0
-                    img[-2, -1] = 0
-                    images.append(
-                        wandb.Image(
-                            PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"),
-                            caption=f"{k}_b{b}",
-                        )
-                    )
-                elif v.dtype in (torch.int8, np.int8):
-                    img = v[b].squeeze().byte() + 127
-                    img[-1, -1] = 255
-                    img[-2, -2] = 255
-                    img[-1, -2] = 0
-                    img[-2, -1] = 0
-                    images.append(
-                        wandb.Image(
-                            PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"),
-                            caption=f"{k}_b{b}",
-                        )
-                    )
-                elif v.dtype in (torch.bool, bool):
-                    img = v[b].squeeze().byte() * 255
-                    img[-1, -1] = 255
-                    img[-2, -2] = 255
-                    img[-1, -2] = 0
-                    img[-2, -1] = 0
-                    images.append(
-                        wandb.Image(
-                            PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"),
-                            caption=f"{k}_b{b}",
-                        )
-                    )
-                else:
-                    v_min = v[b].min().round(decimals=4)
-                    v_max = v[b].max().round(decimals=4)
-                    images.append(
-                        wandb.Image(
-                            viz.rendering.Renderer()(v[b].squeeze()),
-                            caption=f"{k}_b{b} | min: {v_min} | max: {v_max}",
-                        )
-                    )
-
-        self.logger.log_image(f"results/{mode}_{title_suffix}_slider", images=images)
-
-    def validation_epoch_start(self, _):  # pylint: disable=no-self-use
-        seed_everything(42)
-
-    def on_validation_epoch_end(self):
-        self.log_results(
-            "val",
-            "worst",
-            **self.worst_val_sample,
-        )
-        self.worst_val_loss = 0
-        self.worst_val_sample = {}
-        self.worst_val_sample_idx = None
-        seed_everything(None)
-
-    def training_step(self, batch, batch_idx):  # pylint: disable=arguments-differ
-        log_row = batch_idx % self.train_log_row_interval == 0
-
-        with torchfields.set_identity_mapping_cache(True, clear_cache=False):
-            loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row)
-
-        return loss
-
-    def _get_warped(self, img, field=None):
-        img_padded = torch.nn.functional.pad(img, (1, 1, 1, 1), value=self.zero_value)
-        if field is not None:
-            img_warped = field.from_pixels()(img)
-        else:
-            img_warped = img
-
-        zeros_padded = img_padded == self.zero_value
-        zeros_padded_cc = np.array(
-            [
-                cc3d.connected_components(
-                    x.detach().squeeze().cpu().numpy(), connectivity=4
-                ).reshape(zeros_padded[0].shape)
-                for x in zeros_padded
-            ]
-        )
-
-        non_tissue_zeros_padded = zeros_padded.clone()
-        non_tissue_zeros_padded[
-            torch.tensor(zeros_padded_cc != zeros_padded_cc.ravel()[0], device=zeros_padded.device)
-        ] = False  # keep masking resin, restore somas in center
-
-        if field is not None:
-            zeros_warped = (
-                torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate")
-                .from_pixels()  # type: ignore[attr-defined]
-                .sample((~zeros_padded).float(), padding_mode="border")
-                <= 0.1
-            )
-            non_tissue_zeros_warped = (
-                torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate")
-                .from_pixels()  # type: ignore[attr-defined]
-                .sample((~non_tissue_zeros_padded).float(), padding_mode="border")
-                <= 0.1
-            )
-        else:
-            zeros_warped = zeros_padded
-            non_tissue_zeros_warped =
non_tissue_zeros_padded - - zeros_warped = torch.nn.functional.pad(zeros_warped, (-1, -1, -1, -1)) - non_tissue_zeros_warped = torch.nn.functional.pad( - non_tissue_zeros_warped, (-1, -1, -1, -1) - ) - - img_warped[zeros_warped] = self.zero_value - return img_warped, ~zeros_warped, ~non_tissue_zeros_warped - - def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): - src = batch["images"]["src_img"] - tgt = batch["images"]["tgt_img"] - - if ( - (src == self.zero_value) + (tgt == self.zero_value) - ).bool().sum() / src.numel() > self.empty_tissue_threshold: - return None - - seed_field = batch["field"].field_() - f_warp_large = seed_field * self.max_displacement_px - f_warp_small = ( - seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) - ) - - f_aff = ( - einops.rearrange( - tensor_ops.transform.get_affine_field( - size=src.shape[-1], - rot_deg=self.equivar_rot_deg_distr(), - scale=self.equivar_scale_distr(), - shear_x_deg=self.equivar_shear_deg_distr(), - shear_y_deg=self.equivar_shear_deg_distr(), - trans_x_px=self.equivar_trans_px_distr(), - trans_y_px=self.equivar_trans_px_distr(), - ), - "C X Y Z -> Z C X Y", - ) - .pixels() # type: ignore[attr-defined] - .to(seed_field.device) - ).repeat_interleave(src.size(0), dim=0) - f1_trans = f_aff.from_pixels()(f_warp_large.from_pixels()).pixels() - f2_trans = f_warp_small.from_pixels()(f1_trans.from_pixels()).pixels() - - magn_field = f_warp_small - - src_f1, _, src_nonzeros_f1 = self._get_warped(src, f1_trans) - src_f2, _, src_nonzeros_f2 = self._get_warped(src, f2_trans) - tgt_f1, _, tgt_nonzeros_f1 = self._get_warped(tgt, f1_trans) - - src_zeros_f1 = ~src_nonzeros_f1 - src_zeros_f2 = ~src_nonzeros_f2 - tgt_zeros_f1 = ~tgt_nonzeros_f1 - - src_enc = self.model(src) - src_f1_enc = self.model(src_f1) - - src_enc_f1 = torch.nn.functional.pad(src_enc, (1, 1, 1, 1), value=0.0) - src_enc_f1 = ( - torch.nn.functional.pad(f1_trans, (1, 1, 1, 1), mode="replicate") # type: ignore - .from_pixels() - .sample(src_enc_f1, padding_mode="border") - ) - src_enc_f1 = torch.nn.functional.pad(src_enc_f1, (-1, -1, -1, -1), value=0.0) - - equi_diff = (src_enc_f1 - src_f1_enc).abs() - equi_loss = equi_diff[src_zeros_f1 != 0].sum() - equi_loss = equi_diff.sum() / equi_diff.size(0) - equi_diff_map = equi_diff.clone() - equi_diff_map[src_zeros_f1] = 0 - - src_f2_enc = self.model(src_f2) - tgt_f1_enc = self.model(tgt_f1) - - pre_diff = (src_f1_enc - tgt_f1_enc).abs() - - pre_tissue_mask = ~(tgt_zeros_f1 | src_zeros_f1) - pre_loss = pre_diff[..., pre_tissue_mask].sum() / pre_diff.size(0) - pre_diff_masked = pre_diff.clone() - pre_diff_masked[..., pre_tissue_mask == 0] = 0 - - post_tissue_mask = ~(tgt_zeros_f1 | src_zeros_f2) - post_magn_mask = (magn_field.abs().max(1, keepdim=True)[0] > self.field_magn_thr).tensor_() - - post_diff_map = (src_f2_enc - tgt_f1_enc).abs() - post_mask = post_magn_mask * post_tissue_mask - if post_mask.sum() < 256: - return None - - post_loss = post_diff_map[..., post_mask].sum() / post_diff_map.size(0) - - post_diff_masked = post_diff_map.clone() - post_diff_masked[..., post_mask == 0] = 0 - - loss = pre_loss - post_loss * self.post_weight + equi_loss * self.equivar_weight - self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True) - if log_row: - 
self.log_results( - mode, - sample_name, - src=src, - src_enc=src_enc, - src_f1=src_f1, - src_enc_f1=src_enc_f1, - src_f1_enc=src_f1_enc, - src_f2_enc=src_f2_enc, - tgt_f1=tgt_f1, - tgt_f1_enc=tgt_f1_enc, - field=seed_field.tensor_(), - equi_diff_map=equi_diff_map, - post_diff_masked=post_diff_masked, - pre_diff_masked=pre_diff_masked, - ) - return loss - - def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ - log_row = batch_idx % self.val_log_row_interval == 0 - sample_name = f"{batch_idx // self.val_log_row_interval}" - - with torchfields.set_identity_mapping_cache(True, clear_cache=False): - loss = self.compute_metroem_loss( - batch=batch, mode="val", log_row=log_row, sample_name=sample_name - ) - return loss - - -@builder.register("BaseEncoderRegime", versions="==0.0.0") -@attrs.mutable(eq=False) -class BaseEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors - model: torch.nn.Module - lr: float - train_log_row_interval: int = 200 - val_log_row_interval: int = 25 - field_magn_thr: float = 1 - post_weight: float = 0.5 - zero_value: float = 0 - zero_conserve_weight: float = 0.5 - worst_val_loss: float = attrs.field(init=False, default=0) - worst_val_sample: dict = attrs.field(init=False, factory=dict) - worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) - - equivar_weight: float = 1.0 - equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) - equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) - equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) - equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1) - - def __attrs_pre_init__(self): - super().__init__() - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - return optimizer - - def validation_epoch_end(self, _): - log_results( - "val", - "worst", - **self.worst_val_sample, - ) - self.worst_val_loss = 0 - self.worst_val_sample = {} - self.worst_val_sample_idx = None - - def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ - log_row = batch_idx % self.train_log_row_interval == 0 - loss = self.compute_metroem_loss(batch=batch, mode="train", log_row=log_row) - return loss - - def _get_warped(self, img, field): - img_warped = field.field().from_pixels()(img) - zeros_warped = field.field().from_pixels()((img == self.zero_value).float()) > 0.1 - img_warped[zeros_warped] = 0 - return img_warped, zeros_warped - - def compute_metroem_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): - src = batch["images"]["src"] - tgt = batch["images"]["tgt"] - - if ((src == self.zero_value) + (tgt == self.zero_value)).bool().sum() / src.numel() > 0.4: - return None - - seed_field = batch["field"] - seed_field = ( - seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) - ) - - f_aff = ( - einops.rearrange( - tensor_ops.transform.get_affine_field( - size=src.shape[-1], - rot_deg=self.equivar_rot_deg_distr(), - scale=self.equivar_scale_distr(), - shear_x_deg=self.equivar_shear_deg_distr(), - shear_y_deg=self.equivar_shear_deg_distr(), - trans_x_px=self.equivar_trans_px_distr(), - trans_y_px=self.equivar_trans_px_distr(), - ), - "C X Y Z -> Z C X Y", - ) - .field() # type: ignore - .pixels() - .to(seed_field.device) - ) - f1_trans = torch.tensor(f_aff.from_pixels()(seed_field.field().from_pixels()).pixels()) - f2_trans = 
torch.tensor( - seed_field.field() - .from_pixels()(f1_trans.field().from_pixels()) # type: ignore - .pixels() - ) - - src_f1, src_zeros_f1 = self._get_warped(src, f1_trans) - src_f2, src_zeros_f2 = self._get_warped(src, f2_trans) - tgt_f1, tgt_zeros_f1 = self._get_warped(tgt, f1_trans) - - src_enc = self.model(src) - src_enc_f1 = f1_trans.field().from_pixels()(src_enc) # type: ignore - src_f1_enc = self.model(src_f1) - - equi_diff = (src_enc_f1 - src_f1_enc).abs() - equi_loss = equi_diff[src_zeros_f1 == 0].sum() - equi_diff_map = equi_diff.clone() - equi_diff_map[src_zeros_f1] = 0 - - src_f2_enc = self.model(src_f2) - tgt_f1_enc = self.model(tgt_f1) - - pre_diff = (src_f1_enc - tgt_f1_enc).abs() - - pre_tissue_mask = ( - tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f1, width=5) == 0 - ) - pre_loss = pre_diff[..., pre_tissue_mask].sum() - pre_diff_masked = pre_diff.clone() - pre_diff_masked[..., pre_tissue_mask == 0] = 0 - - post_tissue_mask = ( - tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f2, width=5) == 0 - ) - - post_magn_mask = seed_field.abs().max(1)[0] > self.field_magn_thr - post_magn_mask[..., 0:10, :] = 0 - post_magn_mask[..., -10:, :] = 0 - post_magn_mask[..., :, 0:10] = 0 - post_magn_mask[..., :, -10:] = 0 - - post_diff_map = (src_f2_enc - tgt_f1_enc).abs() - post_mask = post_magn_mask * post_tissue_mask - if post_mask.sum() < 256: - return None - - post_loss = post_diff_map[..., post_mask].sum() - - post_diff_masked = post_diff_map.clone() - post_diff_masked[..., post_mask == 0] = 0 - - loss = pre_loss - post_loss * self.post_weight + equi_loss * self.equivar_weight - self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True) - if log_row: - log_results( - mode, - sample_name, - src=src, - src_enc=src_enc, - src_f1=src_f1, - src_enc_f1=src_enc_f1, - src_f1_enc=src_f1_enc, - src_f2_enc=src_f2_enc, - tgt_f1=tgt_f1, - tgt_f1_enc=tgt_f1_enc, - field=torch.tensor(seed_field), - equi_diff_map=equi_diff_map, - post_diff_masked=post_diff_masked, - pre_diff_masked=pre_diff_masked, - ) - return loss - - def compute_metroem_loss_old( - self, batch: dict, mode: str, log_row: bool, sample_name: str = "" - ): - src = batch["images"]["src"] - tgt = batch["images"]["tgt"] - - field = batch["field"] - - tgt_zeros = tensor_ops.mask.kornia_dilation(tgt == self.zero_value, width=3) - src_zeros = tensor_ops.mask.kornia_dilation(src == self.zero_value, width=3) - - pre_tissue_mask = (src_zeros + tgt_zeros) == 0 - if pre_tissue_mask.sum() / src.numel() < 0.4: - return None - - zero_magns = 0 - tgt_enc = self.model(tgt) - zero_magns += tgt_enc[tgt_zeros].abs().sum() - - src_warped = field.field().from_pixels()(src) - src_warped_enc = self.model(src_warped) - src_zeros_warped = field.field().from_pixels()(src_zeros.float()) > 0.1 - - zero_magns += src_warped_enc[src_zeros_warped].abs().sum() - - # src_enc = (~(field.field().from_pixels()))(src_warped_enc) - src_enc = self.model(src) - - pre_diff = (src_enc - tgt_enc).abs() - pre_loss = pre_diff[..., pre_tissue_mask].sum() - pre_diff_masked = pre_diff.clone() - pre_diff_masked[..., pre_tissue_mask == 0] = 0 - - post_tissue_mask = ( - tensor_ops.mask.kornia_dilation(src_zeros_warped + tgt_zeros, width=5) == 0 - ) - post_magn_mask = field.abs().sum(1) > self.field_magn_thr - - post_magn_mask[..., 0:10, :] = 0 - 
post_magn_mask[..., -10:, :] = 0 - post_magn_mask[..., :, 0:10] = 0 - post_magn_mask[..., :, -10:] = 0 - post_diff_map = (src_warped_enc - tgt_enc).abs() - post_mask = post_magn_mask * post_tissue_mask - post_diff_masked = post_diff_map.clone() - post_diff_masked[..., post_tissue_mask == 0] = 0 - if post_mask.sum() < 256: - return None - - post_loss = post_diff_map[..., post_mask].sum() - loss = pre_loss - post_loss * self.post_weight + zero_magns * self.zero_conserve_weight - self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_zcons", zero_magns, on_step=True, on_epoch=True) - if log_row: - log_results( - mode, - sample_name, - src=src, - src_enc=src_enc, - src_warped_enc=src_warped_enc, - tgt=tgt, - tgt_enc=tgt_enc, - field=field, - post_diff_masked=post_diff_masked, - pre_diff_masked=pre_diff_masked, - ) - return loss - - def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ - log_row = batch_idx % self.val_log_row_interval == 0 - sample_name = f"{batch_idx // self.val_log_row_interval}" - - loss = self.compute_metroem_loss( - batch=batch, mode="val", log_row=log_row, sample_name=sample_name - ) - return loss diff --git a/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py deleted file mode 100644 index d2823c8a6..000000000 --- a/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener.py +++ /dev/null @@ -1,273 +0,0 @@ -# pragma: no cover - -import random -from typing import List, Optional, Union - -import attrs -import pytorch_lightning as pl -import torch -import torchvision -import wandb -from PIL import Image - -from zetta_utils import builder, tensor_ops -from zetta_utils.training.lightning.train import distributed_available - - -@builder.register("EncodingCoarsenerRegime", versions="==0.0.0") -@attrs.mutable(eq=False) -class EncodingCoarsenerRegime(pl.LightningModule): # pylint: disable=too-many-ancestors - encoder: torch.nn.Module - decoder: torch.nn.Module - lr: float - apply_counts: List[int] = [1] - invar_angle_range: List[Union[int, float]] = [1, 180] - invar_mse_weight: float = 0.0 - diffkeep_angle_range: List[Union[int, float]] = [1, 180] - diffkeep_weight: float = 0.0 - min_nonz_frac: float = 0.2 - worst_val_loss: float = attrs.field(init=False, default=0) - worst_val_sample: dict = attrs.field(init=False, factory=dict) - worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) - - def __attrs_pre_init__(self): - super().__init__() - - def log_results(self, mode: str, title_suffix: str = "", **kwargs): - if not self.logger: - return - self.logger.log_image( - f"results/{mode}_{title_suffix}_slider", - images=[wandb.Image(v.squeeze(), caption=k) for k, v in kwargs.items()], - ) - # images = torchvision.utils.make_grid([img[0] for img, _ in img_spec]) - # caption = ",".join(cap for _, cap in img_spec) + title_suffix - # wandb.log({f"results/{mode}_row": [wandb.Image(images, caption)]}) - - def on_validation_epoch_end(self): - self.log_results( - "val", - "worst", - **self.worst_val_sample, - ) - self.worst_val_loss = 0 - self.worst_val_sample = {} - self.worst_val_sample_idx = None - - def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ - interval = 25 - log_row = batch_idx % interval == 0 - 
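For context on the per-apply_count losses computed below: each count nests the same encoder and decoder that many times and scores reconstruction at that depth, so one model learns to coarsen recursively. A simplified sketch of that objective, distilled from compute_loss (illustrative only, not part of this file):

import torch

def recursive_recons_loss(
    encoder: torch.nn.Module,
    decoder: torch.nn.Module,
    data_in: torch.Tensor,
    apply_count: int,
) -> torch.Tensor:
    enc = data_in
    for _ in range(apply_count):
        enc = encoder(enc)  # coarsen apply_count times
    recons = enc
    for _ in range(apply_count):
        recons = decoder(recons)  # decode back up the same number of times
    return ((data_in - recons) ** 2).mean()  # plain MSE reconstruction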
sample_name = f"{batch_idx // interval}" - data_in = batch["data_in"] - - losses = [ - self.compute_loss(data_in, count, "val", log_row, sample_name=sample_name) - for count in self.apply_counts - ] - losses_clean = [l for l in losses if l is not None] - loss = sum(losses_clean) - self.log( - "loss/train", - loss, - on_step=True, - on_epoch=True, - sync_dist=distributed_available(), - rank_zero_only=True, - ) - return loss - - def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ - data_in = batch["data_in"] - log_row = batch_idx % 100 == 0 - - losses = [ - self.compute_loss(data_in, count, "train", log_row) for count in self.apply_counts - ] - losses_clean = [l for l in losses if l is not None] - loss = sum(losses_clean) - self.log( - "loss/train", - loss, - on_step=True, - on_epoch=True, - sync_dist=distributed_available(), - rank_zero_only=True, - ) - return loss - - def compute_loss( - self, - data_in: torch.Tensor, - apply_count: int, - mode: str, - log_row: bool, - sample_name: str = "", - ): - setting_name = f"{mode}_apply{apply_count}" - - enc = data_in - for _ in range(apply_count): - enc = self.encoder(enc) - - recons = enc - for _ in range(apply_count): - recons = self.decoder(recons) - - loss_map_recons = (data_in - recons) ** 2 - - loss_recons = loss_map_recons.mean() - if log_row: - self.log_results( - f"{setting_name}_recons", - sample_name, - data_in=data_in, - naive=tensor_ops.interpolate( - data_in, - size=(enc.shape[-2], enc.shape[-1]), - mode="img", - unsqueeze_input_to=4, - ), - enc=enc, - recons=recons, - loss_map_recons=loss_map_recons, - ) - - self.log( - f"loss/{setting_name}_recons", - loss_recons, - on_step=True, - on_epoch=True, - sync_dist=distributed_available(), - rank_zero_only=True, - ) - - if self.invar_mse_weight > 0: - loss_inv = self.compute_invar_loss( - data_in, enc, apply_count, log_row, setting_name, sample_name - ) - self.log( - f"loss/{setting_name}_inv", - loss_inv, - on_step=True, - on_epoch=True, - sync_dist=distributed_available(), - rank_zero_only=True, - ) - else: - loss_inv = 0 - - if self.diffkeep_weight > 0: - loss_diffkeep = self.compute_diffkeep_loss(data_in, enc, log_row, sample_name) - self.log( - f"loss/{setting_name}_diffkeep", - loss_diffkeep, - on_step=True, - on_epoch=True, - sync_dist=distributed_available(), - rank_zero_only=True, - ) - else: - loss_diffkeep = 0 - - loss = ( - loss_recons + self.invar_mse_weight * loss_inv + self.diffkeep_weight * loss_diffkeep - ) - - self.log( - f"loss/{setting_name}", - loss, - on_step=True, - on_epoch=True, - sync_dist=distributed_available(), - rank_zero_only=True, - ) - - if mode == "val": - if loss > self.worst_val_loss: - self.worst_val_loss = loss - self.worst_val_sample = { - "data_in": data_in, - "enc": enc, - "recons": recons, - "loss_map": loss_map_recons, - } - - return loss - - def compute_invar_loss( - self, - data_in: torch.Tensor, - enc: torch.Tensor, - apply_count: int, - log_row: bool, - setting_name: str, - sample_name: str = "", - ): - angle = random.uniform(self.invar_angle_range[0], self.invar_angle_range[1]) - data_in_rot = torchvision.transforms.functional.rotate( - img=data_in, - angle=angle, - interpolation=Image.BILINEAR, - ) - enc_rot = torchvision.transforms.functional.rotate( - img=enc, - angle=angle, - interpolation=Image.BILINEAR, - ) - rot_input_enc = data_in_rot - for _ in range(apply_count): - rot_input_enc = self.encoder(rot_input_enc) - loss_map_invar = (enc_rot - rot_input_enc) ** 2 - result = loss_map_invar.mean() - if log_row: - 
self.log_results( - f"{setting_name}_inv", - title_suffix=sample_name, - data_in=data_in, - rot_input_enc=rot_input_enc, - enc_rot=enc_rot, - loss_map_invar=loss_map_invar, - ) - return result - - def compute_diffkeep_loss( - self, - data_in: torch.Tensor, - enc: torch.Tensor, - log_row: bool, - setting_name: str, - sample_name: str = "", - ): - angle = random.uniform(self.diffkeep_angle_range[0], self.diffkeep_angle_range[1]) - data_in_rot = torchvision.transforms.functional.rotate( - img=data_in, - angle=angle, - interpolation=Image.BILINEAR, - ) - enc_rot = torchvision.transforms.functional.rotate( - img=enc, - angle=angle, - interpolation=Image.BILINEAR, - ) - data_in_diff = (data_in - data_in_rot) ** 2 - enc_diff = (enc - enc_rot) ** 2 - data_in_diff_downs = tensor_ops.interpolate( - data_in_diff, size=enc_diff.shape[-2:], mode="img", unsqueeze_input_to=4 - ) - loss_map_diffkeep = (data_in_diff_downs - enc_diff).abs() - - result = loss_map_diffkeep.mean() - if log_row: - self.log_results( - f"{setting_name}_diffkeep", - title_suffix=sample_name, - data_in=data_in, - enc_diff=enc_diff, - data_in_diff_downs=data_in_diff_downs, - loss_map_diffkeep=loss_map_diffkeep, - ) - return result - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - return optimizer diff --git a/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py deleted file mode 100644 index fdbd5f1fd..000000000 --- a/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_gen_x1.py +++ /dev/null @@ -1,179 +0,0 @@ -# pragma: no cover -# pylint: disable=too-many-locals, no-self-use - -from typing import Optional - -import attrs -import einops -import pytorch_lightning as pl -import torch -import wandb -from pytorch_lightning import seed_everything - -from zetta_utils import builder, distributions, tensor_ops, viz - - -@builder.register("EncodingCoarsenerGenX1Regime", versions="==0.0.0") -@attrs.mutable(eq=False) -class EncodingCoarsenerGenX1Regime(pl.LightningModule): # pylint: disable=too-many-ancestors - encoder: torch.nn.Module - decoder: torch.nn.Module - lr: float - train_log_row_interval: int = 200 - val_log_row_interval: int = 25 - field_magn_thr: float = 1 - zero_value: float = 0 - worst_val_loss: float = attrs.field(init=False, default=0) - worst_val_sample: dict = attrs.field(init=False, factory=dict) - worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) - - min_data_thr: float = 0.85 - equivar_weight: float = 1.0 - significance_weight: float = 0.5 - centering_weight: float = 0.5 - equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) - equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) - equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) - equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1) - - def __attrs_pre_init__(self): - super().__init__() - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - return optimizer - - def log_results(self, mode: str, title_suffix: str = "", **kwargs): - if not self.logger: - return - self.logger.log_image( - f"results/{mode}_{title_suffix}_slider", - images=[ - wandb.Image(viz.rendering.Renderer()(v.squeeze()), caption=k) - for k, v in kwargs.items() - ], - ) - - def 
validation_epoch_start(self, _): - seed_everything(42) - - def on_validation_epoch_end(self): - self.log_results( - "val", - "worst", - **self.worst_val_sample, - ) - self.worst_val_loss = 0 - self.worst_val_sample = {} - self.worst_val_sample_idx = None - seed_everything(None) - - def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ - log_row = batch_idx % self.train_log_row_interval == 0 - loss = self.compute_gen_x1_loss(batch=batch, mode="train", log_row=log_row) - return loss - - def compute_gen_x1_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): - src = batch["src"] - seed_field = batch["field"] - seed_field = ( - seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) - ).field() - if ((src == self.zero_value)).bool().sum() / src.numel() > self.min_data_thr: - return None - equivar_rot = self.equivar_rot_deg_distr() - - equivar_field = ( - einops.rearrange( - tensor_ops.transform.get_affine_field( - size=src.shape[-1], - rot_deg=equivar_rot, - ), - "C X Y Z -> Z C X Y", - ) - .field() # type: ignore - .to(src.device)(seed_field.from_pixels()) - .pixels() - ) - - equivar_field_inv = (~seed_field.from_pixels())( - tensor_ops.transform.get_affine_field( - size=src.shape[-1], - rot_deg=-equivar_rot, - ) - .field() # type: ignore - .to(src.device) - ).pixels() - - src_warped = equivar_field.from_pixels()(src) - enc_warped = self.encoder(src_warped) - enc = tensor_ops.interpolate( - equivar_field_inv, scale_factor=enc_warped.shape[-1] / src.shape[-1], mode="field" - ).from_pixels()(enc_warped) - dec = self.decoder(enc) - - tissue_final = equivar_field_inv.from_pixels()( - equivar_field.from_pixels()(torch.ones_like(src)) - ) - - diff_map = (src - dec).abs() - diff_loss = diff_map[tissue_final != 0].sum() - diff_map[tissue_final == 0] = 0 - wanted_significance = torch.nn.functional.max_pool2d( - src.abs(), kernel_size=int(src.shape[-1] / enc.shape[-1]) - ) - enc_error = torch.nn.functional.avg_pool2d( - diff_map.abs(), kernel_size=int(src.shape[-1] / enc.shape[-1]) - ) - significance_loss_map = (enc_error - (wanted_significance - enc.abs().mean(1))).abs() - tissue_final_downs = torch.nn.functional.max_pool2d( - tissue_final, kernel_size=int(src.shape[-1] / enc.shape[-1]) - ) - significance_loss = significance_loss_map[..., tissue_final_downs.squeeze() == 1].sum() - significance_loss_map[..., tissue_final_downs.squeeze() == 0] = 0 - centering_loss = enc.sum((0, 2, 3)).abs().sum() - loss = ( - diff_loss - + self.significance_weight * significance_loss - + self.centering_weight * centering_loss - ) - self.log(f"loss/{mode}_significance", significance_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_diff", diff_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_centering", centering_loss, on_step=True, on_epoch=True) - - if log_row: - self.log_results( - mode, - sample_name, - src=src, - src_warped=src_warped, - src_warped_abs=src_warped.abs(), - enc_warped_naive=tensor_ops.interpolate( - src_warped, - size=(enc_warped.shape[-2], enc_warped.shape[-1]), - mode="img", - unsqueeze_input_to=4, - ), - enc_warped=enc_warped, - enc_warped_abs=enc_warped.abs().mean(1), - enc=enc, - dec=dec, - diff_map=diff_map, - significance_loss_map=significance_loss_map, - enc_error=enc_error, - waned_significance=wanted_significance, - waned_zeros=wanted_significance == 0, - enc_zeros=enc.abs().mean(1) < 0.01, - # tissue_final_downs=tissue_final_downs, - ) - - return loss - - def validation_step(self, batch, batch_idx): 
# pylint: disable=arguments-differ - log_row = batch_idx % self.val_log_row_interval == 0 - sample_name = f"{batch_idx // self.val_log_row_interval}" - - loss = self.compute_gen_x1_loss( - batch=batch, mode="val", log_row=log_row, sample_name=sample_name - ) - return loss diff --git a/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py deleted file mode 100644 index 08a8f495f..000000000 --- a/zetta_utils/training/lightning/regimes/alignment/deprecated/encoding_coarsener_highres.py +++ /dev/null @@ -1,356 +0,0 @@ -# pragma: no cover - -import random -from typing import List, Optional - -import attrs -import PIL -import pytorch_lightning as pl -import torch -import torchfields -import torchvision -import wandb - -import zetta_utils as zu -from zetta_utils import builder, convnet, tensor_ops # pylint: disable=unused-import - - -# TODO: Refactor function -def warp_by_px(image, direction, pixels): - - fields = torch.zeros( - 1, 2, image.shape[-2], image.shape[-1], device=image.device - ).field() # type: ignore - - if direction == 0: - fields[0, 0, :, :] = 0 - fields[0, 1, :, :] = pixels - elif direction == 1: - fields[0, 0, :, :] = pixels - fields[0, 1, :, :] = 0 - elif direction == 2: - fields[0, 0, :, :] = 0 - fields[0, 1, :, :] = -pixels - elif direction == 3: - fields[0, 0, :, :] = -pixels - fields[0, 1, :, :] = 0 - elif direction == 4: - fields[0, 0, :, :] = pixels ** 0.5 - fields[0, 1, :, :] = pixels ** 0.5 - elif direction == 5: - fields[0, 0, :, :] = -(pixels ** 0.5) - fields[0, 1, :, :] = pixels ** 0.5 - elif direction == 6: - fields[0, 0, :, :] = pixels ** 0.5 - fields[0, 1, :, :] = -(pixels ** 0.5) - elif direction == 7: - fields[0, 0, :, :] = -(pixels ** 0.5) - fields[0, 1, :, :] = -(pixels ** 0.5) - return ( - fields.from_pixels().expand( - image.shape[0], fields.shape[1], fields.shape[2], fields.shape[3] - ) - )(image) - - -# TODO: Refactor function -def center_crop_norm(image): - norm = torchvision.transforms.Normalize(0, 1) - crop = torchvision.transforms.CenterCrop(image.shape[-2] // 2) - return crop(norm(image)) - - -@builder.register("EncodingCoarsenerHighRes", versions="==0.0.0") -@attrs.mutable(eq=False) -class EncodingCoarsenerHighRes(pl.LightningModule): # pylint: disable=too-many-ancestors - encoder: torch.nn.Module - decoder: torch.nn.Module - lr: float - encoder_ckpt_path: Optional[str] = None - decoder_ckpt_path: Optional[str] = None - apply_counts: List[int] = [1] - residual_range: List[float] = [0.1, 5.0] - residual_weight: float = 0.0 - field_scale: List[float] = [1.0, 1.0] - field_weight: float = 0.0 - meanstd_weight: float = 0.0 - invar_weight: float = 0.0 - min_nonz_frac: float = 0.2 - worst_val_loss: float = attrs.field(init=False, default=0) - worst_val_sample: dict = attrs.field(init=False, default=attrs.Factory(dict)) - worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) - - def __attrs_pre_init__(self): - super().__init__() - - def __attrs_post_init__(self): - if self.encoder_ckpt_path is not None: - convnet.utils.load_weights_file(self, self.encoder_ckpt_path, ["encoder"]) - - if self.decoder_ckpt_path is not None: - convnet.utils.load_weights_file(self, self.decoder_ckpt_path, ["decoder"]) - - def log_results(self, mode: str, title_suffix: str = "", **kwargs): - if not self.logger: - return - self.logger.log_image( - f"results/{mode}_{title_suffix}_slider", - images=[wandb.Image(v.squeeze(), caption=k) 
for k, v in kwargs.items()], - ) - # images = torchvision.utils.make_grid([img[0] for img, _ in img_spec]) - # caption = ",".join(cap for _, cap in img_spec) + title_suffix - # wandb.log({f"results/{mode}_row": [wandb.Image(images, caption)]}) - - def on_validation_epoch_end(self): - self.log_results( - "val", - "worst", - **self.worst_val_sample, - ) - self.worst_val_loss = 0 - self.worst_val_sample = {} - self.worst_val_sample_idx = None - - def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ - interval = 25 - log_row = batch_idx % interval == 0 - sample_name = f"{batch_idx // interval}" - data_in = batch["image"]["data_in"] - field_in = batch["field"]["data_in"] - - losses = [ - self.compute_loss(data_in, field_in, count, "val", log_row, sample_name=sample_name) - for count in self.apply_counts - ] - losses_clean = [l for l in losses if l is not None] - if len(losses_clean) == 0: - loss = None - else: - loss = sum(losses_clean) - self.log("loss/train", loss, on_step=True, on_epoch=True) - return loss - - def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ - data_in = batch["image"]["data_in"] - field_in = batch["field"]["data_in"] - log_row = batch_idx % 100 == 0 - - losses = [ - self.compute_loss(data_in, field_in, count, "train", log_row) - for count in self.apply_counts - ] - losses_clean = [l for l in losses if l is not None] - if len(losses_clean) == 0: - loss = None - else: - loss = sum(losses_clean) - self.log("loss/train", loss, on_step=True, on_epoch=True) - return loss - - def compute_loss( # pylint: disable=too-many-locals, too-many-branches - self, - data_in: torch.Tensor, - field_in: torchfields.Field, - apply_count: int, - mode: str, - log_row: bool, - sample_name: str = "", - ): - setting_name = f"{mode}_apply{apply_count}" - - if (data_in != 0).sum() / data_in.numel() < self.min_nonz_frac: - loss = None - else: - enc = data_in - for _ in range(apply_count): - enc = self.encoder(enc) - - recons = enc - for _ in range(apply_count): - recons = self.decoder(recons) - - mean_loss = (torch.mean(data_in) - torch.mean(enc)) ** 2 - std_loss = (torch.std(data_in) - torch.std(enc)) ** 2 - - loss_map_recons = (data_in - recons) ** 2 - - loss_recons = loss_map_recons.mean() - - if log_row: - self.log_results( - f"{setting_name}_recons", - sample_name, - data_in=data_in[0:1, :, :, :], - naive=zu.tensor_ops.interpolate( - data_in[0:1, :, :, :], size=(enc.shape[-2], enc.shape[-1]), mode="img" - ), - enc=enc[0:1, :, :, :], - recons=recons[0:1, :, :, :], - loss_map_recons=loss_map_recons[0:1, :, :, :], - ) - - self.log(f"loss/{setting_name}_recons", loss_recons, on_step=True, on_epoch=True) - - field_in *= random.uniform(self.field_scale[0], self.field_scale[1]) - - if self.field_weight > 0: - loss_field = self.compute_field_loss( - data_in, field_in, enc, apply_count, log_row, setting_name, sample_name - ) - self.log(f"loss/{setting_name}_field", loss_field, on_step=True, on_epoch=True) - else: - loss_field = 0 - - if self.invar_weight > 0: - loss_invar = self.compute_invar_loss( - data_in, enc, apply_count, log_row, setting_name, sample_name - ) - self.log(f"loss/{setting_name}_invar", loss_invar, on_step=True, on_epoch=True) - else: - loss_invar = 0 - - if self.residual_weight > 0: - loss_res = self.compute_residual_loss( - data_in, enc, apply_count, log_row, setting_name, sample_name - ) - self.log(f"loss/{setting_name}_res", loss_res, on_step=True, on_epoch=True) - else: - loss_res = 0 - - loss = ( - loss_recons - + 
self.meanstd_weight * mean_loss - + self.meanstd_weight * std_loss - + +self.residual_weight * loss_res - + self.field_weight * loss_field - + self.invar_weight * loss_invar - ) - - self.log(f"loss/{setting_name}", loss, on_step=True, on_epoch=True) - - if mode == "val": - if loss > self.worst_val_loss: - self.worst_val_loss = loss - self.worst_val_sample = { - "data_in": data_in, - "enc": enc, - "recons": recons, - "loss_map": loss_map_recons, - } - - return loss - - def compute_field_loss( - self, - data_in: torch.Tensor, - field_in: torchfields.Field, - enc: torch.Tensor, - apply_count: int, - log_row: bool, - setting_name: str, - sample_name: str = "", - ): - field_in_apply = 2 * torchfields.Field(field_in.cpu()).cuda().from_pixels() - enc_warped = field_in_apply(data_in) - for _ in range(apply_count): - enc_warped = self.encoder(enc_warped) - warped_enc = torch.nn.functional.interpolate( - field_in_apply, scale_factor=(1 / 2) ** apply_count, mode="bilinear" - )(enc) - - loss_field = (warped_enc - enc_warped) ** 2 - - result = loss_field.mean() - if log_row: - self.log_results( - f"{setting_name}_field", - title_suffix=sample_name, - data_in=data_in[0], - field_in=field_in[0], - enc=enc[0], - enc_warped=enc_warped[0], - warped_enc=warped_enc[0], - loss_field=loss_field[0], - ) - return result - - def compute_residual_loss( # pylint: disable=too-many-locals - self, - data_in: torch.Tensor, - enc: torch.Tensor, - apply_count: int, - log_row: bool, - setting_name: str, - sample_name: str = "", - ): - px_a = random.uniform(self.residual_range[0], self.residual_range[1]) - px_b = random.uniform(self.residual_range[0], self.residual_range[1]) - - px_a *= 2 ** apply_count - px_b *= 2 ** apply_count - - direction = random.choice([0, 1, 2, 3, 4, 5, 6, 7]) - outputs_a = warp_by_px(data_in, direction, px_a) - outputs_b = warp_by_px(data_in, direction, px_b) - for _ in range(apply_count): - outputs_a = self.encoder(outputs_a) - outputs_b = self.encoder(outputs_b) - outputs_a = center_crop_norm(outputs_a) - outputs_b = center_crop_norm(outputs_b) - encodings = center_crop_norm(enc).expand(outputs_a.shape) - - loss_a = (encodings - outputs_a) ** 2 - loss_b = (encodings - outputs_b) ** 2 - - loss_res = (loss_a.mean() * px_a - loss_b.mean() / px_b) ** 2 - - result = loss_res.mean() - if log_row: - self.log_results( - f"{setting_name}_res", - title_suffix=sample_name, - data_in=data_in, - loss_a=loss_a, - loss_b=loss_b, - ) - return result - - def compute_invar_loss( - self, - data_in: torch.Tensor, - enc: torch.Tensor, - apply_count: int, - log_row: bool, - setting_name: str, - sample_name: str = "", - ): - angle = random.uniform(-180, 180) - data_in_rot = torchvision.transforms.functional.rotate( - img=data_in, - angle=angle, - interpolation=PIL.Image.BILINEAR, - ) - enc_rot = torchvision.transforms.functional.rotate( - img=enc, - angle=angle, - interpolation=PIL.Image.BILINEAR, - ) - rot_input_enc = data_in_rot - for _ in range(apply_count): - rot_input_enc = self.encoder(rot_input_enc) - loss_map_invar = (enc_rot - rot_input_enc) ** 2 - result = loss_map_invar.mean() - if log_row: - self.log_results( - f"{setting_name}_inv", - title_suffix=sample_name, - data_in=data_in, - rot_input_enc=rot_input_enc, - enc_rot=enc_rot, - loss_map_invar=loss_map_invar, - ) - return result - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - return optimizer diff --git a/zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py 
b/zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py deleted file mode 100644 index 39d9e2adc..000000000 --- a/zetta_utils/training/lightning/regimes/alignment/deprecated/minima_encoder.py +++ /dev/null @@ -1,254 +0,0 @@ -# pragma: no cover -# pylint: disable=too-many-locals - -from typing import Optional - -import attrs -import einops -import pytorch_lightning as pl -import torch -import wandb - -from zetta_utils import builder, distributions, tensor_ops, viz - - -@builder.register("MinimaEncoderRegime", versions="==0.0.0") -@attrs.mutable(eq=False) -class MinimaEncoderRegime(pl.LightningModule): # pylint: disable=too-many-ancestors - model: torch.nn.Module - lr: float - train_log_row_interval: int = 200 - val_log_row_interval: int = 25 - field_magn: float = 1 - - worst_val_loss: float = attrs.field(init=False, default=0) - worst_val_sample: dict = attrs.field(init=False, factory=dict) - worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) - - equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) - equivar_shear_deg_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) - equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) - equivar_scale_distr: distributions.Distribution = distributions.uniform_distr(0.9, 1.1) - - def __attrs_pre_init__(self): - super().__init__() - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - return optimizer - - def log_results(self, mode: str, title_suffix: str = "", **kwargs): - if not self.logger: - return - self.logger.log_image( - f"results/{mode}_{title_suffix}_slider", - images=[ - wandb.Image(viz.rendering.Renderer()(v.squeeze()), caption=k) - for k, v in kwargs.items() - ], - ) - - def on_validation_epoch_end(self): - self.log_results( - "val", - "worst", - **self.worst_val_sample, - ) - self.worst_val_loss = 0 - self.worst_val_sample = {} - self.worst_val_sample_idx = None - - def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ - log_row = batch_idx % self.train_log_row_interval == 0 - loss = self.compute_minima_loss(batch=batch, mode="train", log_row=log_row) - return loss - - def _get_warped(self, img, field): - img_warped = field.field().from_pixels()(img) - zeros_warped = field.field().from_pixels()((img == self.zero_value).float()) > 0.1 - img_warped[zeros_warped] = 0 - return img_warped, zeros_warped - - def compute_minima_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""): - src = batch["images"]["src"] - tgt = batch["images"]["tgt"] - - if ((src == self.zero_value) + (tgt == self.zero_value)).bool().sum() / src.numel() > 0.4: - return None - - seed_field = batch["field"] - seed_field = ( - seed_field * self.field_magn_thr / torch.quantile(seed_field.abs().max(1)[0], 0.5) - ) - - f_aff = ( - einops.rearrange( - tensor_ops.transform.get_affine_field( - size=src.shape[-1], - rot_deg=self.equivar_rot_deg_distr(), - scale=self.equivar_scale_distr(), - shear_x_deg=self.equivar_shear_deg_distr(), - shear_y_deg=self.equivar_shear_deg_distr(), - trans_x_px=self.equivar_trans_px_distr(), - trans_y_px=self.equivar_trans_px_distr(), - ), - "C X Y Z -> Z C X Y", - ) - .field() # type: ignore - .pixels() - .to(seed_field.device) - ) - f1_trans = torch.tensor(f_aff.from_pixels()(seed_field.field().from_pixels()).pixels()) - f2_trans = torch.tensor( - seed_field.field() - .from_pixels()(f1_trans.field().from_pixels()) # type: 
ignore - .pixels() - ) - - src_f1, src_zeros_f1 = self._get_warped(src, f1_trans) - src_f2, src_zeros_f2 = self._get_warped(src, f2_trans) - tgt_f1, tgt_zeros_f1 = self._get_warped(tgt, f1_trans) - - src_enc = self.model(src) - src_enc_f1 = f1_trans.field().from_pixels()(src_enc) # type: ignore - src_f1_enc = self.model(src_f1) - - equi_diff = (src_enc_f1 - src_f1_enc).abs() - equi_loss = equi_diff[src_zeros_f1 == 0].sum() - equi_diff_map = equi_diff.clone() - equi_diff_map[src_zeros_f1] = 0 - - src_f2_enc = self.model(src_f2) - tgt_f1_enc = self.model(tgt_f1) - - pre_diff = (src_f1_enc - tgt_f1_enc).abs() - - pre_tissue_mask = ( - tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f1, width=5) == 0 - ) - pre_loss = pre_diff[..., pre_tissue_mask].sum() - pre_diff_masked = pre_diff.clone() - pre_diff_masked[..., pre_tissue_mask == 0] = 0 - - post_tissue_mask = ( - tensor_ops.mask.kornia_dilation(tgt_zeros_f1 + src_zeros_f2, width=5) == 0 - ) - - post_magn_mask = seed_field.abs().max(1)[0] > self.field_magn_thr - post_magn_mask[..., 0:10, :] = 0 - post_magn_mask[..., -10:, :] = 0 - post_magn_mask[..., :, 0:10] = 0 - post_magn_mask[..., :, -10:] = 0 - - post_diff_map = (src_f2_enc - tgt_f1_enc).abs() - post_mask = post_magn_mask * post_tissue_mask - if post_mask.sum() < 256: - return None - - post_loss = post_diff_map[..., post_mask].sum() - - post_diff_masked = post_diff_map.clone() - post_diff_masked[..., post_mask == 0] = 0 - - loss = pre_loss - post_loss * self.post_weight + equi_loss * self.equivar_weight - self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_equi", equi_loss, on_step=True, on_epoch=True) - if log_row: - self.log_results( - mode, - sample_name, - src=src, - src_enc=src_enc, - src_f1=src_f1, - src_enc_f1=src_enc_f1, - src_f1_enc=src_f1_enc, - src_f2_enc=src_f2_enc, - tgt_f1=tgt_f1, - tgt_f1_enc=tgt_f1_enc, - field=torch.tensor(seed_field), - equi_diff_map=equi_diff_map, - post_diff_masked=post_diff_masked, - pre_diff_masked=pre_diff_masked, - ) - return loss - - def compute_minima_loss_old( - self, batch: dict, mode: str, log_row: bool, sample_name: str = "" - ): - src = batch["images"]["src"] - tgt = batch["images"]["tgt"] - - field = batch["field"] - - tgt_zeros = tensor_ops.mask.kornia_dilation(tgt == self.zero_value, width=3) - src_zeros = tensor_ops.mask.kornia_dilation(src == self.zero_value, width=3) - - pre_tissue_mask = (src_zeros + tgt_zeros) == 0 - if pre_tissue_mask.sum() / src.numel() < 0.4: - return None - - zero_magns = 0 - tgt_enc = self.model(tgt) - zero_magns += tgt_enc[tgt_zeros].abs().sum() - - src_warped = field.field().from_pixels()(src) - src_warped_enc = self.model(src_warped) - src_zeros_warped = field.field().from_pixels()(src_zeros.float()) > 0.1 - - zero_magns += src_warped_enc[src_zeros_warped].abs().sum() - - # src_enc = (~(field.field().from_pixels()))(src_warped_enc) - src_enc = self.model(src) - - pre_diff = (src_enc - tgt_enc).abs() - pre_loss = pre_diff[..., pre_tissue_mask].sum() - pre_diff_masked = pre_diff.clone() - pre_diff_masked[..., pre_tissue_mask == 0] = 0 - - post_tissue_mask = ( - tensor_ops.mask.kornia_dilation(src_zeros_warped + tgt_zeros, width=5) == 0 - ) - post_magn_mask = field.abs().sum(1) > self.field_magn_thr - - post_magn_mask[..., 0:10, :] = 0 - post_magn_mask[..., -10:, :] = 0 - post_magn_mask[..., :, 0:10] = 0 - post_magn_mask[..., :, 
-10:] = 0 - post_diff_map = (src_warped_enc - tgt_enc).abs() - post_mask = post_magn_mask * post_tissue_mask - post_diff_masked = post_diff_map.clone() - post_diff_masked[..., post_tissue_mask == 0] = 0 - if post_mask.sum() < 256: - return None - - post_loss = post_diff_map[..., post_mask].sum() - loss = 0 # TODO - self.log(f"loss/{mode}", loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_pre", pre_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_post", post_loss, on_step=True, on_epoch=True) - self.log(f"loss/{mode}_zcons", zero_magns, on_step=True, on_epoch=True) - if log_row: - self.log_results( - mode, - sample_name, - src=src, - src_enc=src_enc, - src_warped_enc=src_warped_enc, - tgt=tgt, - tgt_enc=tgt_enc, - field=field, - post_diff_masked=post_diff_masked, - pre_diff_masked=pre_diff_masked, - ) - return loss - - def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ - log_row = batch_idx % self.val_log_row_interval == 0 - sample_name = f"{batch_idx // self.val_log_row_interval}" - - loss = self.compute_minima_loss( - batch=batch, mode="val", log_row=log_row, sample_name=sample_name - ) - return loss diff --git a/zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py b/zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py deleted file mode 100644 index 920305a8f..000000000 --- a/zetta_utils/training/lightning/regimes/alignment/deprecated/misalignment_detector.py +++ /dev/null @@ -1,185 +0,0 @@ -# pragma: no cover - -import random -from typing import Optional - -import attrs -import pytorch_lightning as pl -import torch -import torchvision -import wandb - -import zetta_utils as zu -from zetta_utils import builder, convnet, tensor_ops # pylint: disable=unused-import - - -@builder.register("MisalignmentDetectorRegime", versions="==0.0.0") -@attrs.mutable(eq=False) -class MisalignmentDetectorRegime(pl.LightningModule): # pylint: disable=too-many-ancestors - detector: torch.nn.Module - lr: float - max_disp: float - downsample_power: int = 0 - min_nonz_frac: float = 0.2 - worst_val_loss: float = attrs.field(init=False, default=0) - worst_val_sample: dict = attrs.field(init=False, default=attrs.Factory(dict)) - worst_val_sample_idx: Optional[int] = attrs.field(init=False, default=None) - - def __attrs_pre_init__(self): - super().__init__() - - # TODO: factor this out - @staticmethod - def augment_field(field): - random_vertical_flip = torchvision.transforms.RandomVerticalFlip() - random_horizontal_flip = torchvision.transforms.RandomHorizontalFlip() - angle = random.choice([0, 90, 180, 270]) - return random_horizontal_flip( - random_vertical_flip(torchvision.transforms.functional.rotate(field, angle)) - ) - - @staticmethod - def norm_field(field, threshold): - field_i = field[:, 0:1, :, :] - field_j = field[:, 1:2, :, :] - field_norm = (field_i ** 2 + field_j ** 2) ** 0.5 - return torch.clamp(field_norm, 0, threshold) - - def log_results(self, mode: str, title_suffix: str = "", **kwargs): - if not self.logger: - return - self.logger.log_image( - f"results/{mode}_{title_suffix}_slider", - images=[wandb.Image(v.squeeze(), caption=k) for k, v in kwargs.items()], - ) - # images = torchvision.utils.make_grid([img[0] for img, _ in img_spec]) - # caption = ",".join(cap for _, cap in img_spec) + title_suffix - # wandb.log({f"results/{mode}_row": [wandb.Image(images, caption)]}) - - def on_validation_epoch_end(self): - self.log_results( - "val", - "worst", - **self.worst_val_sample, - ) - 
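These regimes share the same worst_val_* bookkeeping: track the highest-loss validation sample across the epoch, log it once at epoch end via log_results("val", "worst", ...), then reset. A minimal sketch of the update side of that pattern (hypothetical method name, not part of this file):

def update_worst_val_sample(self, loss: float, sample: dict, batch_idx: int) -> None:
    # Called wherever a validation loss is produced; keeps only the
    # single worst sample for epoch-end logging, then epoch-end resets.
    if loss > self.worst_val_loss:
        self.worst_val_loss = float(loss)
        self.worst_val_sample = sample
        self.worst_val_sample_idx = batch_idx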
self.worst_val_loss = 0 - self.worst_val_sample = {} - self.worst_val_sample_idx = None - - def validation_step(self, batch, batch_idx): # pylint: disable=arguments-differ - interval = 25 - log_row = batch_idx % interval == 0 - - image = batch["image"]["data_in"] - # field0 = torchfields.Field(batch["field0"]["data_in"]) - # field1 = torchfields.Field(batch["field1"]["data_in"]) - field0 = batch["field0"]["data_in"] - field1 = batch["field1"]["data_in"] - - loss = self.compute_loss(image, field0, field1, "validate", log_row) - if loss is not None: - self.log("loss/train", loss, on_step=True, on_epoch=True) - return loss - - def training_step(self, batch, batch_idx): # pylint: disable=arguments-differ - image = batch["image"]["data_in"] - field0 = batch["field0"]["data_in"] - field1 = batch["field1"]["data_in"] - log_row = batch_idx % 100 == 0 - - loss = self.compute_loss(image, field0, field1, "train", log_row) - if loss is not None: - self.log("loss/train", loss, on_step=True, on_epoch=True) - return loss - - def compute_loss( - self, - image: torch.Tensor, - field0: torch.Tensor, - field1: torch.Tensor, - mode: str, - log_row: bool, - sample_name: str = "", - ): - setting_name = f"{mode}" - - if (image != 0).sum() / image.numel() < self.min_nonz_frac: - loss = None - else: - with torch.no_grad(): - field0 = self.augment_field(field0) - field1 = self.augment_field(field1) - labels = self.norm_field(field0 - field1, self.max_disp) - if self.downsample_power != 0: - image = zu.tensor_ops.interpolate( - image, - size=( - image.shape[-2] // 2 ** self.downsample_power, - image.shape[-1] // 2 ** self.downsample_power, - ), - mode="img", - ) - field0 = zu.tensor_ops.interpolate( - field0, - size=( - field0.shape[-2] // 2 ** self.downsample_power, - field0.shape[-1] // 2 ** self.downsample_power, - ), - mode="img", - ) - field1 = zu.tensor_ops.interpolate( - field1, - size=( - field1.shape[-2] // 2 ** self.downsample_power, - field1.shape[-1] // 2 ** self.downsample_power, - ), - mode="img", - ) - labels = zu.tensor_ops.interpolate( - labels, - size=( - labels.shape[-2] // 2 ** self.downsample_power, - labels.shape[-1] // 2 ** self.downsample_power, - ), - mode="img", - ) - # fields are typed as Tensor - assert hasattr(field0, "field_") - assert hasattr(field1, "field_") - image[:, 0:1, :, :] = field0.field_().from_pixels()(image[:, 0:1, :, :]) - image[:, 1:2, :, :] = field1.field_().from_pixels()(image[:, 1:2, :, :]) - pred_labels = self.detector(image) - - loss_map = (pred_labels - labels) ** 2 - loss = loss_map.mean() - - if log_row: - self.log_results( - f"{setting_name}_recons", - sample_name, - image0=image[0, 0, :, :].detach().cpu(), - image1=image[0, 1, :, :].detach().cpu(), - # field0=field0, - # field1=field1, - pred_labels=pred_labels[0, :, :, :].detach().cpu(), - labels=labels[0, :, :, :].detach().cpu(), - loss_map=loss_map[0, :, :, :].detach().cpu(), - ) - - self.log(f"loss/{setting_name}", loss, on_step=True, on_epoch=True) - - if mode == "val": - if loss > self.worst_val_loss: - self.worst_val_loss = loss - self.worst_val_sample = { - "image": image, - "pred_labels": pred_labels, - "labels": labels, - "loss_map": loss_map, - } - - return loss - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) - return optimizer diff --git a/zetta_utils/training/lightning/regimes/alignment/misalignment_detector_aced.py b/zetta_utils/training/lightning/regimes/alignment/misalignment_detector_aced.py deleted file mode 100644 index 7a7e3cb7e..000000000 --- 
a/zetta_utils/training/lightning/regimes/alignment/misalignment_detector_aced.py +++ /dev/null @@ -1,256 +0,0 @@ -# pylint: disable=too-many-locals -import os -from typing import Literal, Optional - -import attrs -import cc3d -import numpy as np -import pytorch_lightning as pl -import torch -import wandb -from PIL import Image as PILImage -from pytorch_lightning import seed_everything - -from zetta_utils import builder, convnet, distributions, viz - - -@builder.register("MisalignmentDetectorAcedRegime") -@attrs.mutable(eq=False) -class MisalignmentDetectorAcedRegime(pl.LightningModule): # pylint: disable=too-many-ancestors - model: torch.nn.Module - lr: float - train_log_row_interval: int = 200 - val_log_row_interval: int = 25 - field_magn_thr: float = 1 - penalize_soma: bool = True - - max_src_displacement_px: distributions.Distribution = distributions.uniform_distr(8.0, 32.0) - equivar_rot_deg_distr: distributions.Distribution = distributions.uniform_distr(0, 360) - equivar_trans_px_distr: distributions.Distribution = distributions.uniform_distr(-10, 10) - - zero_value: float = 0 - output_mode: Literal["binary", "displacement"] = "binary" - - encoder_path: Optional[str] = None - encoder: torch.nn.Module = attrs.field(init=False, default=torch.nn.Identity()) - - def __attrs_pre_init__(self): - super().__init__() - - def __attrs_post_init__(self): - if self.encoder_path is not None: - self.encoder = convnet.utils.load_model(self.encoder_path, use_cache=True).eval() - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) - return optimizer - - def log_results(self, mode: str, title_suffix: str = "", **kwargs): - if not self.logger: - return - images = [] - for k, v in kwargs.items(): - for b in range(1): - if v.dtype in (np.uint8, torch.uint8): - img = v[b].squeeze() - img[-1, -1] = 255 - img[-2, -2] = 255 - img[-1, -2] = 0 - img[-2, -1] = 0 - images.append( - wandb.Image( - PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), - caption=f"{k}_b{b}", - ) - ) - elif v.dtype in (torch.int8, np.int8): - img = v[b].squeeze().byte() + 127 - img[-1, -1] = 255 - img[-2, -2] = 255 - img[-1, -2] = 0 - img[-2, -1] = 0 - images.append( - wandb.Image( - PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), - caption=f"{k}_b{b}", - ) - ) - elif v.dtype in (torch.bool, bool): - img = v[b].squeeze().byte() * 255 - img[-1, -1] = 255 - img[-2, -2] = 255 - img[-1, -2] = 0 - img[-2, -1] = 0 - images.append( - wandb.Image( - PILImage.fromarray(viz.rendering.Renderer()(img), mode="RGB"), - caption=f"{k}_b{b}", - ) - ) - else: - v_min = v[b].min().round(decimals=4) - v_max = v[b].max().round(decimals=4) - images.append( - wandb.Image( - viz.rendering.Renderer()(v[b].squeeze()), - caption=f"{k}_b{b} | min: {v_min} | max: {v_max}", - ) - ) - - self.logger.log_image(f"results/{mode}_{title_suffix}_slider", images=images) - - def validation_epoch_start(self, _): # pylint: disable=no-self-use - seed_everything(42) - - def on_validation_epoch_end(self): - env_seed = os.environ.get("PL_GLOBAL_SEED") - if env_seed is not None: - seed_everything(int(env_seed) + self.current_epoch) - else: - seed_everything(None) - - def _get_warped(self, img, field=None): - img_padded = torch.nn.functional.pad(img, (1, 1, 1, 1), value=self.zero_value) - if field is not None: - assert hasattr(field, "from_pixels") # mypy torchfields compatibility - img_warped = field.from_pixels()(img) - else: - img_warped = img - - zeros_padded = img_padded == self.zero_value - 
-    def _get_warped(self, img, field=None):
-        img_padded = torch.nn.functional.pad(img, (1, 1, 1, 1), value=self.zero_value)
-        if field is not None:
-            assert hasattr(field, "from_pixels")  # mypy torchfields compatibility
-            img_warped = field.from_pixels()(img)
-        else:
-            img_warped = img
-
-        zeros_padded = img_padded == self.zero_value
-        zeros_padded_cc = np.array(
-            [
-                cc3d.connected_components(
-                    x.detach().squeeze().cpu().numpy(), connectivity=4
-                ).reshape(zeros_padded[0].shape)
-                for x in zeros_padded
-            ]
-        )
-
-        non_tissue_zeros_padded = zeros_padded.clone()
-        non_tissue_zeros_padded[
-            torch.tensor(zeros_padded_cc != zeros_padded_cc.ravel()[0], device=zeros_padded.device)
-        ] = False  # keep masking resin, restore somas in center
-
-        if field is not None:
-            zeros_warped = (
-                torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate")
-                .from_pixels()  # type: ignore
-                .sample((~zeros_padded).float(), padding_mode="border")
-                <= 0.1
-            )
-            non_tissue_zeros_warped = (
-                torch.nn.functional.pad(field, (1, 1, 1, 1), mode="replicate")
-                .from_pixels()  # type: ignore
-                .sample((~non_tissue_zeros_padded).float(), padding_mode="border")
-                <= 0.1
-            )
-        else:
-            zeros_warped = zeros_padded
-            non_tissue_zeros_warped = non_tissue_zeros_padded
-
-        zeros_warped = torch.nn.functional.pad(zeros_warped, (-1, -1, -1, -1))
-        non_tissue_zeros_warped = torch.nn.functional.pad(
-            non_tissue_zeros_warped, (-1, -1, -1, -1)
-        )
-
-        img_warped[zeros_warped] = self.zero_value
-        return img_warped, ~zeros_warped, ~non_tissue_zeros_warped
-    def training_step(self, batch, batch_idx):  # pylint: disable=arguments-differ
-        log_row = batch_idx % self.train_log_row_interval == 0
-        losses = [
-            self.compute_misd_loss(
-                batch=batch,
-                mode="train",
-                log_row=log_row,
-            )
-        ]
-        losses_clean = [l for l in losses if l is not None]
-        if len(losses_clean) == 0:
-            return None
-        loss = sum(losses_clean)
-        return loss
-
-    def validation_step(self, batch, batch_idx):  # pylint: disable=arguments-differ
-        with torch.no_grad():
-            log_row = batch_idx % self.val_log_row_interval == 0
-            sample_name = f"{batch_idx // self.val_log_row_interval}"
-
-            losses = [
-                self.compute_misd_loss(
-                    batch=batch, mode="val", log_row=log_row, sample_name=sample_name
-                )
-            ]
-            losses_clean = [l for l in losses if l is not None]
-            if len(losses_clean) == 0:
-                return None
-            loss = sum(losses_clean)
-            return loss
-
-    def compute_misd_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""):
-        src = batch["images"]["src"]
-        tgt = batch["images"]["tgt"]
-
-        if ((src == self.zero_value) + (tgt == self.zero_value)).bool().sum() / src.numel() > 0.7:
-            return None
-
-        gt_displacement = batch["images"]["displacement"]
-        gt_labels = gt_displacement.clone()
-
-        if self.output_mode == "binary":
-            gt_labels = gt_labels > self.field_magn_thr
-
-        src_warped, src_warped_tissue_wo_soma, src_warped_tissue_w_soma = self._get_warped(
-            src, field=None
-        )
-        tgt_warped, tgt_warped_tissue_wo_soma, tgt_warped_tissue_w_soma = self._get_warped(
-            tgt, field=None
-        )
-
-        if self.penalize_soma:
-            intersect_tissue = src_warped_tissue_w_soma & tgt_warped_tissue_w_soma
-        else:
-            # Create mask that excludes soma interior from loss, but keep thin tissue in between
-            # from either section
-            joint_tissue = tgt_warped_tissue_wo_soma + src_warped_tissue_wo_soma
-            # Previous mask also added partial tissue at boundary - don't want to penalize there
-            intersect_tissue = joint_tissue & src_warped_tissue_w_soma & tgt_warped_tissue_w_soma
-
-        if intersect_tissue.sum() == 0:
-            return None
-
-        with torch.no_grad():
-            src_encoded = self.encoder(src_warped)
-            tgt_encoded = self.encoder(tgt_warped)
-        prediction = self.model(torch.cat((src_encoded, tgt_encoded), 1))
-
-        fg_ratio = (gt_labels & intersect_tissue).sum() / intersect_tissue.sum()
-        if 0.0 < fg_ratio < 1.0:
-            weight = (1.0 - fg_ratio) * gt_labels + fg_ratio * ~gt_labels
-        else:
-            weight = torch.ones_like(gt_labels, dtype=torch.float32)
-        weight[intersect_tissue == 0] = 0.0
-
-        loss_map = torch.nn.functional.binary_cross_entropy_with_logits(
-            prediction, gt_labels.float(), weight=weight, reduction="none"
-        )
-
-        loss = loss_map[intersect_tissue].sum() / loss_map.size(0)
-
-        self.log(f"loss/{mode}", loss.item(), on_step=True, on_epoch=True)
-
-        if log_row:
-            self.log_results(
-                mode,
-                sample_name,
-                src=src,
-                tgt=tgt,
-                final_tissue=intersect_tissue,
-                gt_displacement=gt_displacement,
-                gt_labels=gt_labels,
-                weight=weight,
-                prediction=prediction,
-                loss_map=loss_map,
-            )
-        return loss
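The `fg_ratio` logic in `compute_misd_loss` is a standard class-balancing trick for binary cross-entropy: the rarer class receives the larger per-pixel weight, and pixels outside the valid tissue mask receive zero weight. Distilled into a standalone helper (a sketch; tensor names are illustrative):

import torch

def balanced_bce_weight(labels: torch.Tensor, tissue: torch.Tensor) -> torch.Tensor:
    # labels: bool ground-truth mask; tissue: bool mask of pixels that count toward the loss.
    fg_ratio = (labels & tissue).sum() / tissue.sum()
    if 0.0 < fg_ratio < 1.0:
        # Foreground pixels weighted by (1 - fg_ratio), background by fg_ratio,
        # so the rarer class dominates less of the loss.
        weight = (1.0 - fg_ratio) * labels + fg_ratio * ~labels
    else:
        weight = torch.ones_like(labels, dtype=torch.float32)
    weight[~tissue] = 0.0
    return weight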
diff --git a/zetta_utils/training/lightning/regimes/common.py b/zetta_utils/training/lightning/regimes/common.py
deleted file mode 100644
index 45e9f9a5f..000000000
--- a/zetta_utils/training/lightning/regimes/common.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from __future__ import annotations
-
-from functools import reduce
-
-import wandb
-from pytorch_lightning.loggers.logger import Logger
-from pytorch_lightning.loggers.wandb import WandbLogger
-from torchvision.utils import make_grid
-
-from zetta_utils import tensor_ops, viz
-from zetta_utils.geometry import Vec3D
-from zetta_utils.tensor_typing import Tensor
-
-
-def is_2d_image(tensor):
-    return len(tensor.squeeze().shape) == 2 or (
-        len(tensor.squeeze().shape) == 3 and tensor.squeeze().shape[0] <= 3
-    )
-
-
-def log_results(mode: str, title_suffix: str = "", logger: Logger | None = None, **kwargs):
-    if all(is_2d_image(v) for v in kwargs.values()):
-        row = [
-            wandb.Image(viz.rendering.Renderer()(v.squeeze()), caption=k)
-            for k, v in kwargs.items()
-        ]
-        if logger is None:
-            wandb.log({f"results/{mode}_{title_suffix}_slider": row})
-        else:
-            logger.log_image(f"results/{mode}_{title_suffix}_slider", images=row)
-    else:
-        max_z = max(v.shape[-1] for v in kwargs.values())
-
-        for z in range(max_z):
-            row = []
-            for k, v in kwargs.items():
-                if is_2d_image(v):
-                    rendered = viz.rendering.Renderer()(v.squeeze())
-                else:
-                    rendered = viz.rendering.Renderer()(v[..., z].squeeze())
-
-                row.append(wandb.Image(rendered, caption=k))
-
-            if logger is None:
-                wandb.log({f"results/{mode}_{title_suffix}_slider_z{z}": row})
-            else:
-                logger.log_image(f"results/{mode}_{title_suffix}_slider_z{z}", images=row)
-
-
-def render_3d_result(data: Tensor):
-    assert 3 <= data.ndim <= 5
-    data_ = tensor_ops.convert.to_torch(data, device="cpu")
-    data_ = data_[0, ...] if data_.ndim > 4 else data_
-    data_ = data_[0:3, ...] if data_.ndim > 3 else data_
-    depth = data_.shape[-1]
-    imgs = [data_[..., z] for z in range(depth)]
-    return make_grid(imgs, nrow=depth, padding=0)
-
-
-def log_3d_results(
-    logger: Logger | None,
-    mode: str,
-    title_suffix: str = "",
-    **kwargs,
-) -> None:
-    sizes = [Vec3D(*v.shape[-3:]) for v in kwargs.values()]  # type: list[Vec3D]
-    min_s = reduce(lambda acc, cur: cur if acc > cur else acc, sizes)
-
-    row = []
-    for k, v in kwargs.items():
-        data = tensor_ops.crop_center(v, min_s)
-        rendered = render_3d_result(data)
-        row.append(wandb.Image(rendered, caption=k))
-
-    if logger is None:
-        wandb.log({f"results/{mode}_{title_suffix}_slider": row})
-    else:
-        assert isinstance(logger, WandbLogger)
-        logger.experiment.log({f"results/{mode}_{title_suffix}_slider": row})
diff --git a/zetta_utils/training/lightning/regimes/naive_supervised.py b/zetta_utils/training/lightning/regimes/naive_supervised.py
deleted file mode 100644
index 96c1fc2a3..000000000
--- a/zetta_utils/training/lightning/regimes/naive_supervised.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# pylint: disable=arguments-differ,no-self-use,too-many-ancestors
-import attrs
-import pytorch_lightning as pl
-import torch
-from pytorch_lightning import seed_everything
-
-from zetta_utils import builder
-
-from .common import log_results
-
-
-@builder.register("NaiveSupervisedRegime")
-@attrs.mutable(eq=False)
-class NaiveSupervisedRegime(pl.LightningModule):
-    model: torch.nn.Module
-    lr: float
-    min_nonz_frac: float = 0.2
-
-    train_log_row_interval: int = 200
-    val_log_row_interval: int = 25
-
-    def __attrs_pre_init__(self):
-        super().__init__()
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
-        return optimizer
-
-    def validation_epoch_start(self, _):
-        seed_everything(42)
-
-    def on_validation_epoch_end(self):
-        seed_everything(None)
-
-    def validation_step(self, batch, batch_idx):
-        log_row = batch_idx % self.val_log_row_interval == 0
-        sample_name = f"{batch_idx // self.val_log_row_interval}"
-        loss = self.compute_loss(batch=batch, mode="val", log_row=log_row, sample_name=sample_name)
-        return loss
-
-    def compute_loss(self, batch: dict, mode: str, log_row: bool, sample_name: str = ""):
-        data_in = batch["data_in"]
-        target = batch["target"]
-
-        if (data_in != 0).sum() / data_in.numel() < self.min_nonz_frac:
-            return None
-
-        result = self.model(data_in)
-        loss_map = (target - result) ** 2
-
-        if "loss_weights" in batch:
-            loss_weights = batch["loss_weights"]
-            loss = (loss_map * loss_weights).mean()
-        else:
-            loss = loss_map.sum()
-        self.log(f"loss/{mode}", loss.item(), on_step=True, on_epoch=True)
-
-        if log_row:
-            log_results(
-                mode,
-                sample_name,
-                data_in=data_in,
-                target=target,
-                result=result,
-                loss_map=loss_map,
-            )
-        return loss
-
-    def training_step(self, batch, batch_idx):
-        log_row = batch_idx % self.train_log_row_interval == 0
-        sample_name = ""
-
-        loss = self.compute_loss(
-            batch=batch, mode="train", log_row=log_row, sample_name=sample_name
-        )
-        return loss
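NaiveSupervisedRegime skips batches that are mostly background by returning None from compute_loss; PyTorch Lightning skips the optimizer step when training_step returns None. The guard in isolation (a sketch with illustrative names; the plain MSE stands in for the regime's loss):

import torch

def maybe_compute_loss(model: torch.nn.Module, data_in: torch.Tensor,
                       target: torch.Tensor, min_nonz_frac: float = 0.2):
    # Skip batches that are mostly zeros: Lightning treats a None loss as "no step".
    if (data_in != 0).sum() / data_in.numel() < min_nonz_frac:
        return None
    result = model(data_in)
    return ((target - result) ** 2).mean()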
diff --git a/zetta_utils/training/lightning/regimes/noop.py b/zetta_utils/training/lightning/regimes/noop.py
deleted file mode 100644
index 66f4ed2cf..000000000
--- a/zetta_utils/training/lightning/regimes/noop.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# pylint: disable=unused-argument
-import time
-
-import attrs
-import pytorch_lightning as pl
-import torch
-
-from zetta_utils import builder
-
-# import wandb
-
-# from zetta_utils import viz
-
-
-@builder.register("NoOpRegime")
-@attrs.mutable(eq=False)
-class NoOpRegime(pl.LightningModule):  # pylint: disable=too-many-ancestors
-    def __attrs_pre_init__(self):
-        super().__init__()
-
-    def validation_step(self, batch, batch_idx):  # pylint: disable=arguments-differ
-        return None
-
-    def training_step(self, batch, batch_idx):  # pylint: disable=arguments-differ
-        time.sleep(0.1)
-        # self.log('yo', torch.tensor(0.45), on_step=True, on_epoch=True)
-        # viz.rendering.Renderer()(torch.ones((1024, 1024)))
-        # wandb.log(
-        #     {
-        #         "yo": wandb.Image(np.ones(1024, 1024))
-        #     }
-        # )
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam([torch.nn.Parameter()], lr=0)
-        return optimizer
diff --git a/zetta_utils/training/lightning/regimes/segmentation/__init__.py b/zetta_utils/training/lightning/regimes/segmentation/__init__.py
deleted file mode 100644
index e6b8e8f8d..000000000
--- a/zetta_utils/training/lightning/regimes/segmentation/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from . import base_affinity, base_embedding
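Every regime deleted in this patch follows the same construction pattern: an attrs class that is also a pl.LightningModule, registered with the zetta_utils builder under a string key so it can be instantiated from a spec. A minimal sketch of the pattern; `MyRegime` is a hypothetical name, not part of the codebase:

import attrs
import pytorch_lightning as pl
import torch

from zetta_utils import builder

@builder.register("MyRegime")  # hypothetical key, for illustration only
@attrs.mutable(eq=False)
class MyRegime(pl.LightningModule):
    model: torch.nn.Module
    lr: float

    def __attrs_pre_init__(self):
        # The attrs-generated __init__ does not call the superclass constructor,
        # so LightningModule.__init__ must be invoked explicitly here.
        super().__init__()

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)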
diff --git a/zetta_utils/training/lightning/regimes/segmentation/base_affinity.py b/zetta_utils/training/lightning/regimes/segmentation/base_affinity.py
deleted file mode 100644
index e80ba32f6..000000000
--- a/zetta_utils/training/lightning/regimes/segmentation/base_affinity.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# pragma: no cover
-# pylint: disable=arguments-differ,too-many-ancestors
-
-from __future__ import annotations
-
-from typing import NamedTuple
-
-import attrs
-import pytorch_lightning as pl
-import torch
-
-from zetta_utils import builder, tensor_ops
-from zetta_utils.training.lightning.train import distributed_available
-
-from ..common import log_3d_results
-
-
-@builder.register("BaseAffinityRegime")
-@attrs.mutable(eq=False)
-class BaseAffinityRegime(pl.LightningModule):
-    model: torch.nn.Module
-    lr: float
-    criteria: dict[str, torch.nn.Module]
-    loss_weights: dict[str, float]
-    amsgrad: bool = True
-    logits: bool = True
-    group: int = 3
-
-    train_log_row_interval: int = 200
-    val_log_row_interval: int = 25
-
-    # DDP
-    sync_dist: bool = True
-
-    def __attrs_pre_init__(self):
-        super().__init__()
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(
-            self.parameters(),
-            lr=self.lr,
-            amsgrad=self.amsgrad,
-        )
-        return optimizer
-
-    def training_step(self, batch, batch_idx):
-        log_row = batch_idx % self.train_log_row_interval == 0
-        loss = self.compute_loss(batch=batch, mode="train", log_row=log_row)
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        log_row = batch_idx % self.val_log_row_interval == 0
-        loss = self.compute_loss(batch=batch, mode="val", log_row=log_row)
-        return loss
-
-    def compute_loss(
-        self, batch: dict[str, torch.Tensor], mode: str, log_row: bool, sample_name: str = ""
-    ):
-        data_in = batch["data_in"]
-        results = self.model(data_in)
-
-        # Compute loss
-        losses = []
-        for key, criterion in self.criteria.items():
-            pred = getattr(results, key)
-            trgt = batch[key]
-            mask = batch[key + "_mask"]
-            loss = criterion(pred, trgt, mask)
-            if loss is None:
-                continue
-            loss_w = self.loss_weights[key]
-            losses += [loss_w * loss]
-            self.log(
-                f"loss/{key}/{mode}",
-                loss.item(),
-                on_step=True,
-                on_epoch=True,
-                sync_dist=(distributed_available() and self.sync_dist),
-                rank_zero_only=True,
-            )
-
-        if len(losses) == 0:
-            return None
-
-        loss = sum(losses)
-        self.log(
-            f"loss/{mode}",
-            loss.item(),
-            on_step=True,
-            on_epoch=True,
-            sync_dist=(distributed_available() and self.sync_dist),
-            rank_zero_only=True,
-        )
-
-        if log_row:
-            log_3d_results(
-                self.logger,
-                mode,
-                title_suffix=sample_name,
-                **self.create_row(batch, results),
-            )
-
-        return loss
-
-    def create_row(
-        self, batch: dict[str, torch.Tensor], results: NamedTuple
-    ) -> dict[str, torch.Tensor]:
-        row = {
-            "data_in": batch["data_in"],
-            "target": tensor_ops.seg_to_rgb(batch["target"]),
-        }
-        for key in self.criteria.keys():
-            trgt = batch[key]
-            mask = batch[key + "_mask"]
-            pred = getattr(results, key)
-            pred = torch.sigmoid(pred) if self.logits else pred
-
-            # Chop prediction into groups for visualization purpose
-            num_channels = pred.shape[-4]
-            group = self.group if self.group > 0 else num_channels
-            if num_channels > group:
-                for i in range(0, num_channels, group):
-                    start, end = i, min(i + group, num_channels)
-                    idx = f"[{start}:{end}]"
-                    row[f"{key}{idx}"] = pred[..., start:end, :, :, :]
-            else:
-                row[f"{key}"] = pred
-            # Chop target into groups for visualization purpose
-            num_channels = trgt.shape[-4]
-            group = self.group if self.group > 0 else num_channels
-            if num_channels > group:
-                for i in range(0, num_channels, group):
-                    start, end = i, min(i + group, num_channels)
-                    idx = f"[{start}:{end}]"
-                    row[f"{key}_target{idx}"] = trgt[..., start:end, :, :, :]
-                    # Optional mask
-                    mask_ = mask[..., start:end, :, :, :]
-                    if torch.count_nonzero(mask_) < torch.numel(mask_):
-                        row[f"{key}_mask{idx}"] = mask_
-            else:
-                if not torch.equal(trgt, batch["target"]):
-                    row[f"{key}_target"] = trgt
-
-                # Optional mask
-                if torch.count_nonzero(mask) < torch.numel(mask):
-                    row[f"{key}_mask"] = mask
-        return row
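`create_row` above splits multi-channel predictions and targets into groups of at most `group` channels so that each slice can be rendered as an image. The same chopping as a standalone helper (a sketch; the helper name is illustrative):

from __future__ import annotations

import torch

def chop_channels(name: str, tensor: torch.Tensor, group: int = 3) -> dict[str, torch.Tensor]:
    # Split a (..., C, X, Y, Z) tensor along its channel axis into renderable groups.
    num_channels = tensor.shape[-4]
    group = group if group > 0 else num_channels
    if num_channels <= group:
        return {name: tensor}
    out = {}
    for start in range(0, num_channels, group):
        end = min(start + group, num_channels)
        out[f"{name}[{start}:{end}]"] = tensor[..., start:end, :, :, :]
    return out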
diff --git a/zetta_utils/training/lightning/regimes/segmentation/base_embedding.py b/zetta_utils/training/lightning/regimes/segmentation/base_embedding.py
deleted file mode 100644
index 68d950118..000000000
--- a/zetta_utils/training/lightning/regimes/segmentation/base_embedding.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# pragma: no cover
-# pylint: disable=arguments-differ,too-many-ancestors
-
-from __future__ import annotations
-
-import attrs
-import pytorch_lightning as pl
-import torch
-
-from zetta_utils import builder, tensor_ops
-from zetta_utils.segmentation import vec_to_pca, vec_to_rgb
-from zetta_utils.training.lightning.train import distributed_available
-
-from ..common import log_3d_results
-
-
-@builder.register("BaseEmbeddingRegime")
-@attrs.mutable(eq=False)
-class BaseEmbeddingRegime(pl.LightningModule):
-    model: torch.nn.Module
-    lr: float
-    criteria: dict[str, torch.nn.Module]
-    loss_weights: dict[str, float]
-    amsgrad: bool = True
-
-    train_log_row_interval: int = 200
-    val_log_row_interval: int = 25
-
-    # DDP
-    sync_dist: bool = True
-
-    def __attrs_pre_init__(self):
-        super().__init__()
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(
-            self.parameters(),
-            lr=self.lr,
-            amsgrad=self.amsgrad,
-        )
-        return optimizer
-
-    def training_step(self, batch, batch_idx):
-        log_row = batch_idx % self.train_log_row_interval == 0
-        loss = self.compute_loss(batch=batch, mode="train", log_row=log_row)
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        log_row = batch_idx % self.val_log_row_interval == 0
-        loss = self.compute_loss(batch=batch, mode="val", log_row=log_row)
-        return loss
-
-    def compute_loss(
-        self, batch: dict[str, torch.Tensor], mode: str, log_row: bool, sample_name: str = ""
-    ):
-        data_in = batch["data_in"]
-        results = self.model(data_in)
-
-        # Create mask if not exist
-        for key, criterion in self.criteria.items():
-            if key + "_mask" in batch:
-                continue
-            batch[key + "_mask"] = torch.ones_like(batch[key])
-
-        # Compute loss
-        losses = []
-        for key, criterion in self.criteria.items():
-            pred = results[key]
-            trgt = batch[key]
-            mask = batch[key + "_mask"]
-            splt = batch.get(key + "_split", None)
-            loss = criterion(pred, trgt, mask, splt)
-            if loss is None:
-                continue
-            loss_w = self.loss_weights[key]
-            losses += [loss_w * loss]
-            self.log(
-                f"loss/{key}/{mode}",
-                loss.item(),
-                on_step=True,
-                on_epoch=True,
-                sync_dist=(distributed_available() and self.sync_dist),
-                rank_zero_only=True,
-            )
-
-        if len(losses) == 0:
-            return None
-
-        loss = sum(losses)
-        self.log(
-            f"loss/{mode}",
-            loss.item(),
-            on_step=True,
-            on_epoch=True,
-            sync_dist=(distributed_available() and self.sync_dist),
-            rank_zero_only=True,
-        )
-
-        if log_row:
-            log_3d_results(
-                self.logger,
-                mode,
-                title_suffix=sample_name,
-                **self.create_row(batch, results),
-            )
-
-        return loss
-
-    def create_row(
-        self, batch: dict[str, torch.Tensor], results: dict[str, torch.Tensor]
-    ) -> dict[str, torch.Tensor]:
-        row = {
-            "data_in": batch["data_in"],
-            "target": tensor_ops.seg_to_rgb(batch["target"]),
-        }
-
-        for key in self.criteria.keys():
-            mask = batch[key + "_mask"]
-            pred = results[key]
-
-            # PCA dimensionality reduction
-            vec = pred[[0], ...]
-            pca = vec_to_pca(vec)
-            row[f"{key}[0:3]"] = vec_to_rgb(vec)
-            row[f"{key}_PCA"] = vec_to_rgb(pca)
-
-            # Optional mask
-            if torch.count_nonzero(mask) < torch.numel(mask):
-                row[f"{key}_mask"] = mask
-
-        return row
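`vec_to_pca` and `vec_to_rgb` come from the `zetta_utils.segmentation` package that this patch also removes. Assuming they project a `(1, C, X, Y, Z)` embedding onto its top three principal components and rescale to `[0, 1]` for display, a plausible sketch using `torch.pca_lowrank` (an assumption, not the deleted implementation):

import torch

def embedding_to_rgb(vec: torch.Tensor) -> torch.Tensor:
    # vec: (1, C, X, Y, Z) with C >= 3; returns (1, 3, X, Y, Z) scaled to [0, 1].
    c = vec.shape[1]
    flat = vec.reshape(c, -1).T                       # (num_voxels, C)
    flat = flat - flat.mean(dim=0)                    # center before PCA
    _, _, v = torch.pca_lowrank(flat, q=3, center=False)
    pca = (flat @ v).T.reshape(1, 3, *vec.shape[2:])  # top-3 principal components
    lo = pca.amin(dim=(2, 3, 4), keepdim=True)
    hi = pca.amax(dim=(2, 3, 4), keepdim=True)
    return (pca - lo) / (hi - lo + 1e-8)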