
Commit

Refactor and generalize image attributes tokenizer as time series embedding module

nathanpainchaud committed Oct 24, 2023
1 parent 9ebe4aa commit 0a88530
Showing 4 changed files with 63 additions and 72 deletions.
@@ -1,6 +1,7 @@
 # @package _global_

 defaults:
+  - /task/img_tokenizer/model: linear-embedding
   - override /task/model: cardinal-ft-transformer
   - override /task/optim: null
   - override /data: cardinal
@@ -110,10 +111,8 @@ task:
     d_token: ${task.embed_dim}

   img_tokenizer:
-    _target_: didactic.tasks.cardiac_multimodal_representation.CardiacSequenceAttributesTokenizer
-    resample_dim: 128
-    embed_dim: ${task.embed_dim}
-    num_attrs: ${op.mul:${builtin.len:${task.views}},${builtin.len:${task.img_attrs}}}
+    _target_: didactic.models.time_series.TimeSeriesEmbedding
+    resample_dim: 64

   optim:
     optimizer:
@@ -141,7 +140,7 @@ callbacks:
 #  attention_rollout_kwargs:
 #    includes_cls_token: ${task.latent_token}

-experiment_dirname: encoder=${hydra:runtime.choices.task/model}/n_clinical_attrs=${builtin.len:${task.clinical_attrs}},n_img_attrs=${op.mul:${builtin.len:${task.views}},${builtin.len:${task.img_attrs}}}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr},attrs_dropout=${task.attrs_dropout}
+experiment_dirname: encoder=${hydra:runtime.choices.task/model}/img_tokenizer=${hydra:runtime.choices.task/img_tokenizer/model}/n_clinical_attrs=${n_clinical_attrs},n_img_attrs=${n_img_attrs}/contrastive=${oc.select:task.contrastive_loss_weight,0}/embed_dim=${task.embed_dim},depth=${task.model.encoder.num_layers},nhead=${task.model.encoder.encoder_layer.nhead},dropout=${task.model.encoder.encoder_layer.dropout}/mtr_p=${task.mtr_p},mt_by_attr=${task.mt_by_attr},attrs_dropout=${task.attrs_dropout}
 hydra:
   job:
     config:
@@ -179,6 +178,8 @@ hydra:
           - task.attrs_dropout

           - task.embed_dim
+          - task/img_tokenizer/model
+
           - task/model
           - task.model.encoder.num_layers
           - task.model.encoder.encoder_layer.nhead
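The experiment_dirname interpolation above leans on custom OmegaConf resolvers (op.mul, builtin.len) alongside the built-in oc.select and Hydra's runtime.choices. The project's actual registration code is not part of this diff; a minimal sketch of how such resolvers could be registered, with illustrative view/attribute values, might look like:

from omegaconf import OmegaConf

# Hypothetical registration matching the resolver names used in the config above;
# the real registration presumably lives in the project's entry point.
OmegaConf.register_new_resolver("op.mul", lambda a, b: a * b)
OmegaConf.register_new_resolver("builtin.len", lambda seq: len(seq))

cfg = OmegaConf.create(
    {
        "views": ["A4C", "A2C"],  # illustrative view names
        "img_attrs": ["gls", "lv_area"],  # illustrative attribute names
        "n_img_attrs": "${op.mul:${builtin.len:${views}},${builtin.len:${img_attrs}}}",
    }
)
print(cfg.n_img_attrs)  # 4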
@@ -0,0 +1,3 @@
+_target_: torch.nn.Linear
+in_features: ${task.img_tokenizer.resample_dim}
+out_features: ${task.embed_dim}
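This new three-line file is evidently the linear-embedding option for the task/img_tokenizer/model group added to the defaults above. Composed with the experiment config, it becomes the `model` argument of TimeSeriesEmbedding. Roughly, assuming task.embed_dim resolves to 8, Hydra would build the equivalent of:

from torch import nn

from didactic.models.time_series import TimeSeriesEmbedding

# Sketch of the composed tokenizer; embed_dim=8 is assumed for illustration,
# resample_dim=64 comes from the experiment config above.
img_tokenizer = TimeSeriesEmbedding(resample_dim=64, model=nn.Linear(64, 8))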
54 changes: 54 additions & 0 deletions didactic/models/time_series.py
@@ -0,0 +1,54 @@
+from typing import Any, Dict, Sequence
+
+import torch
+from torch import Tensor, nn
+from torch.nn import functional as F
+
+
+class TimeSeriesEmbedding(nn.Module):
+    """Embedding for time series which resamples the time dim and/or passes through an arbitrary learnable model."""
+
+    def __init__(self, resample_dim: int, model: nn.Module = None):
+        """Initializes class instance.
+
+        Args:
+            resample_dim: Target size for an interpolation resampling of the time series.
+            model: Model that learns to embed the time series. If not provided, no projection is learned and the
+                embedding is simply the resampled time series. The model should take as input a tensor of shape
+                (N, `resample_dim`) and output a tensor of shape (N, E), where E is the embedding size.
+        """
+        super().__init__()
+        self.model = model
+        self.resample_dim = resample_dim
+
+    def forward(self, time_series: Dict[Any, Tensor] | Sequence[Tensor]) -> Tensor:
+        """Stacks the time series, optionally 1) resampling them and/or 2) projecting them to a target embedding.
+
+        Args:
+            time_series: (K: S, V: (N, ?)) or S * (N, ?): Time series batches to embed, where the dimensionality of
+                each time series can vary.
+
+        Returns:
+            (N, S, E), Embedding of the time series.
+        """
+        if not isinstance(time_series, dict):
+            time_series = {idx: t for idx, t in enumerate(time_series)}
+
+        # Resample the time series to make sure all of them are of length `resample_dim`
+        for t_id, t in time_series.items():
+            if t.shape[-1] != self.resample_dim:
+                # Temporarily reshape the time series batch tensor to be 3D to be able to use torch's interpolation
+                # (N, ?) -> (N, `resample_dim`)
+                time_series[t_id] = F.interpolate(t.unsqueeze(1), size=self.resample_dim, mode="linear").squeeze(dim=1)
+
+        # Extract the time series from the dictionary into a list
+        x = list(time_series.values())  # S * (N, `resample_dim`)
+
+        if self.model:
+            # If provided with a learnable model, use it to predict the embedding of each time series separately
+            x = [self.model(t) for t in x]  # S * (N, `resample_dim`) -> S * (N, E)
+
+        # Stack the embeddings of all the time series along a new dimension after the batch dimension
+        x = torch.stack(x, dim=1)  # S * (N, E) -> (N, S, E)
+
+        return x
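A quick usage sketch of the new module, with all shapes and sizes chosen purely for illustration: two batches of time series with different lengths are resampled to a common 64 samples, then each is embedded into an 8-dim token.

import torch
from torch import nn

from didactic.models.time_series import TimeSeriesEmbedding

# Illustrative values: S=2 time series batches of lengths 50 and 100,
# batch size N=4, resampled to 64 samples and embedded into E=8 dims.
embedding = TimeSeriesEmbedding(resample_dim=64, model=nn.Linear(64, 8))
batch = [torch.randn(4, 50), torch.randn(4, 100)]
tokens = embedding(batch)
print(tokens.shape)  # torch.Size([4, 2, 8]), i.e. (N, S, E)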
67 changes: 0 additions & 67 deletions didactic/tasks/cardiac_multimodal_representation.py
Expand Up @@ -10,7 +10,6 @@
import torch.nn.functional as F
from omegaconf import DictConfig
from rtdl import FeatureTokenizer
from rtdl.modules import _TokenInitialization
from torch import Tensor, nn
from torch.nn import Parameter, ParameterDict, init
from torchmetrics.functional import accuracy, mean_absolute_error
@@ -28,72 +27,6 @@
 CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute]


-class CardiacSequenceAttributesTokenizer(nn.Module):
-    """Tokenizer that pre-processes attributes extracted from cardiac sequences for a transformer model."""
-
-    def __init__(self, resample_dim: int, embed_dim: int = None, num_attrs: int = None):
-        """Initializes class instance.
-
-        Args:
-            resample_dim: Target size for a simple interpolation resampling of the attributes. Mutually exclusive
-                parameter with `cardiac_sequence_attrs_model`.
-            embed_dim: Size of the embedding in which to project the resampled attributes. If not specified, no
-                projection is learned and the embedding is directly the resampled attributes. Only used when
-                `resample_dim` is provided.
-            num_attrs: Number of attributes to tokenize. Only required when `embed_dim` is not None, to initialize
-                the weights and bias parameters of the learnable embeddings.
-        """
-        if embed_dim is not None and num_attrs is None:
-            raise ValueError(
-                "When opting for the resample+project method of tokenizing image attributes, you must indicate the "
-                "expected number of attributes to initialize the weights and biases for the projection."
-            )
-
-        super().__init__()
-
-        self.resample_dim = resample_dim
-
-        self.weight = None
-        if embed_dim:
-            initialization_ = _TokenInitialization.from_str("uniform")
-            self.weight = nn.Parameter(Tensor(num_attrs, resample_dim, embed_dim))
-            self.bias = nn.Parameter(Tensor(num_attrs, embed_dim))
-            for parameter in [self.weight, self.bias]:
-                initialization_.apply(parameter, embed_dim)
-
-    @torch.inference_mode()
-    def forward(self, attrs: Dict[Any, Tensor] | Sequence[Tensor]) -> Tensor:
-        """Embeds image attributes by resampling them, and optionally projecting them to the target embedding.
-
-        Args:
-            attrs: (K: S, V: (N, ?)) or S * (N, ?): Attributes to tokenize, where the dimensionality of each
-                attribute can vary.
-
-        Returns:
-            (N, S, E), Tokenized version of the attributes.
-        """
-        if not isinstance(attrs, dict):
-            attrs = {idx: attr for idx, attr in enumerate(attrs)}
-
-        # Resample attributes to make sure all of them are of `resample_dim`
-        for attr_id, attr in attrs.items():
-            if attr.shape[-1] != self.resample_dim:
-                # Temporarily reshape attribute batch tensor to be 3D to be able to use torch's interpolation
-                # (N, ?) -> (N, `resample_dim`)
-                attrs[attr_id] = F.interpolate(attr.unsqueeze(1), size=self.resample_dim, mode="linear").squeeze(dim=1)
-
-        # Now that all attributes are of the same shape, merge them into one single tensor
-        x = torch.stack(list(attrs.values()), dim=1)  # (N, S, L)
-
-        if self.weight is not None:
-            # Broadcast along all but the last two dimensions, which perform the matrix multiply
-            # (N, S, 1, L) @ (S, L, E) -> (N, S, E)
-            x = (x[..., None, :] @ self.weight).squeeze(dim=-2)
-            x = x + self.bias[None]
-
-        return x
-
-
 class CardiacMultimodalRepresentationTask(SharedStepsTask):
     """Multi-modal transformer to learn a representation from cardiac imaging and patient records data."""

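For contrast with the new module: the removed tokenizer projected each attribute with its own weight matrix via a single broadcasted matmul. A small self-contained sketch of that projection and its einsum equivalent, with dimensions chosen arbitrarily:

import torch

# The removed forward() computes, for each attribute s, x[:, s] @ weight[s] + bias[s]
# in one broadcasted matmul; an einsum spells out the same contraction.
N, S, L, E = 4, 3, 128, 8  # batch, attributes, resampled length, embedding size
x = torch.randn(N, S, L)
weight = torch.randn(S, L, E)
bias = torch.randn(S, E)

out = (x[..., None, :] @ weight).squeeze(dim=-2) + bias[None]  # as in the removed class
ref = torch.einsum("nsl,sle->nse", x, weight) + bias
assert torch.allclose(out, ref, atol=1e-5)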
