From 0a88530202db2a63010d3661f5d34e49b880a03b Mon Sep 17 00:00:00 2001 From: Nathan Painchaud Date: Tue, 24 Oct 2023 23:49:17 +0200 Subject: [PATCH] Refactor and generalize image attributes tokenizer as time series embedding module --- .../cardiac-multimodal-representation.yaml | 11 +-- .../img_tokenizer/model/linear-embedding.yaml | 3 + didactic/models/time_series.py | 54 +++++++++++++++ .../cardiac_multimodal_representation.py | 67 ------------------- 4 files changed, 63 insertions(+), 72 deletions(-) create mode 100644 didactic/config/task/img_tokenizer/model/linear-embedding.yaml create mode 100644 didactic/models/time_series.py diff --git a/didactic/config/experiment/cardinal/cardiac-multimodal-representation.yaml b/didactic/config/experiment/cardinal/cardiac-multimodal-representation.yaml index fc485899..0f3dcbfb 100644 --- a/didactic/config/experiment/cardinal/cardiac-multimodal-representation.yaml +++ b/didactic/config/experiment/cardinal/cardiac-multimodal-representation.yaml @@ -1,6 +1,7 @@ # @package _global_ defaults: + - /task/img_tokenizer/model: linear-embedding - override /task/model: cardinal-ft-transformer - override /task/optim: null - override /data: cardinal @@ -110,10 +111,8 @@ task: d_token: ${task.embed_dim} img_tokenizer: - _target_: didactic.tasks.cardiac_multimodal_representation.CardiacSequenceAttributesTokenizer - resample_dim: 128 - embed_dim: ${task.embed_dim} - num_attrs: ${op.mul:${builtin.len:${task.views}},${builtin.len:${task.img_attrs}}} + _target_: didactic.models.time_series.TimeSeriesEmbedding + resample_dim: 64 optim: optimizer: @@ -141,7 +140,7 @@ callbacks: # attention_rollout_kwargs: # includes_cls_token: ${task.latent_token} -experiment_dirname: encoder=${hydra:runtime.choices.task/model}/n_clinical_attrs=${builtin.len:${task.clinical_attrs}},n_img_attrs=${op.mul:${builtin.len:${task.views}},${builtin.len:${task.img_attrs}}}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr},attrs_dropout=${task.attrs_dropout} +experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr},attrs_dropout=${task.attrs_dropout} hydra: job: config: @@ -179,6 +178,8 @@ hydra: - task.attrs_dropout - task.embed_dim + - task/img_tokenizer/model + - task/model - task.model.encoder.num_layers - task.model.encoder.encoder_layer.nhead diff --git a/didactic/config/task/img_tokenizer/model/linear-embedding.yaml b/didactic/config/task/img_tokenizer/model/linear-embedding.yaml new file mode 100644 index 00000000..a60cbf99 --- /dev/null +++ b/didactic/config/task/img_tokenizer/model/linear-embedding.yaml @@ -0,0 +1,3 @@ +_target_: torch.nn.Linear +in_features: ${task.img_tokenizer.resample_dim} +out_features: ${task.embed_dim} diff --git a/didactic/models/time_series.py b/didactic/models/time_series.py new file mode 100644 index 00000000..d9ab36ee --- /dev/null +++ b/didactic/models/time_series.py @@ -0,0 +1,54 @@ +from typing import Any, Dict, Sequence + +import torch +from torch import Tensor, nn +from torch.nn import functional as F + + +class TimeSeriesEmbedding(nn.Module): + """Embedding for time series which resamples the time dim and/or passes through an arbitrary learnable model.""" + + def __init__(self, resample_dim: int, model: nn.Module = None): + """Initializes class instance. + + Args: + resample_dim: Target size for an interpolation resampling of the time series. + model: Model that learns to embed the time series. If not provided, no projection is learned and the + embedding is simply the resampled time series. The model should take as input a tensor of shape + (N, `resample_dim`) and output a tensor of shape (N, E), where E is the embedding size. + """ + super().__init__() + self.model = model + self.resample_dim = resample_dim + + def forward(self, time_series: Dict[Any, Tensor] | Sequence[Tensor]) -> Tensor: + """Stacks the time series, optionally 1) resampling them and/or 2) projecting them to a target embedding. + + Args: + time_series: (K: S, V: (N, ?)) or S * (N, ?): Time series batches to embed, where the dimensionality of each + time series can vary. + + Returns: + (N, S, E), Embedding of the time series. + """ + if not isinstance(time_series, dict): + time_series = {idx: t for idx, t in enumerate(time_series)} + + # Resample time series to make sure all of them are of `resample_dim` + for t_id, t in time_series.items(): + if t.shape[-1] != self.resample_dim: + # Temporarily reshape time series batch tensor to be 3D to be able to use torch's interpolation + # (N, ?) -> (N, `resample_dim`) + time_series[t_id] = F.interpolate(t.unsqueeze(1), size=self.resample_dim, mode="linear").squeeze(dim=1) + + # Extract the time series from the dictionary and stack them along the batch dimension + x = list(time_series.values()) # (S, N, `resample_dim`) + + if self.model: + # If provided with a learnable model, use it to predict the embedding of each time series separately + x = [self.model(attr) for attr in x] # (S, N, `resample_dim`) -> (S, N, E) + + # Stack the embeddings of all the time series (along the batch dimension) to make only one tensor + x = torch.stack(x, dim=1) # (S, N, E) -> (N, S, E) + + return x diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py index a19c85d0..fe0c7fd1 100644 --- a/didactic/tasks/cardiac_multimodal_representation.py +++ b/didactic/tasks/cardiac_multimodal_representation.py @@ -10,7 +10,6 @@ import torch.nn.functional as F from omegaconf import DictConfig from rtdl import FeatureTokenizer -from rtdl.modules import _TokenInitialization from torch import Tensor, nn from torch.nn import Parameter, ParameterDict, init from torchmetrics.functional import accuracy, mean_absolute_error @@ -28,72 +27,6 @@ CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute] -class CardiacSequenceAttributesTokenizer(nn.Module): - """Tokenizer that pre-processes attributes extracted from cardiac sequences for a transformer model.""" - - def __init__(self, resample_dim: int, embed_dim: int = None, num_attrs: int = None): - """Initializes class instance. - - Args: - resample_dim: Target size for a simple interpolation resampling of the attributes. Mutually exclusive - parameter with `cardiac_sequence_attrs_model`. - embed_dim: Size of the embedding in which to project the resampled attributes. If not specified, no - projection is learned and the embedding is directly the resampled attributes. Only used when - `resample_dim` is provided. - num_attrs: Number of attributes to tokenize. Only required when `embed_dim` is not None to initialize the - weights and bias parameters of the learnable embeddings. - """ - if embed_dim is not None and num_attrs is None: - raise ValueError( - "When opting for the resample+project method of tokenizing image attributes, you must indicate the " - "expected attributes to initialize the weights and biases for the projection." - ) - - super().__init__() - - self.resample_dim = resample_dim - - self.weight = None - if embed_dim: - initialization_ = _TokenInitialization.from_str("uniform") - self.weight = nn.Parameter(Tensor(num_attrs, resample_dim, embed_dim)) - self.bias = nn.Parameter(Tensor(num_attrs, embed_dim)) - for parameter in [self.weight, self.bias]: - initialization_.apply(parameter, embed_dim) - - @torch.inference_mode() - def forward(self, attrs: Dict[Any, Tensor] | Sequence[Tensor]) -> Tensor: - """Embeds image attributes by resampling them, and optionally projecting them to the target embedding. - - Args: - attrs: (K: S, V: (N, ?)) or S * (N, ?): Attributes to tokenize, where the dimensionality of each attribute - can vary. - - Returns: - (N, S, E), Tokenized version of the attributes. - """ - if not isinstance(attrs, dict): - attrs = {idx: attr for idx, attr in enumerate(attrs)} - - # Resample attributes to make sure all of them are of `resample_dim` - for attr_id, attr in attrs.items(): - if attr.shape[-1] != self.resample_dim: - # Temporarily reshape attribute batch tensor to be 3D to be able to use torch's interpolation - # (N, ?) -> (N, `resample_dim`) - attrs[attr_id] = F.interpolate(attr.unsqueeze(1), size=self.resample_dim, mode="linear").squeeze(dim=1) - - # Now that all attributes are of the same shape, merge them into one single tensor - x = torch.stack(list(attrs.values()), dim=1) # (N, S, L) - - if self.weight is not None: - # Broadcast along all but the last two dimensions, which perform the matrix multiply - # (N, S, 1, L) @ (S, L, E) -> (N, S, E) - x = (x[..., None, :] @ self.weight).squeeze(dim=-2) - x = x + self.bias[None] - - return x - - class CardiacMultimodalRepresentationTask(SharedStepsTask): """Multi-modal transformer to learn a representation from cardiac imaging and patient records data."""