Migrate from rtdl to rtdl_revisiting_models dependency
Some APIs (i.e. `CLSToken` and `FeatureTokenizer`) were not ported to the new dependency. Their original implementations were therefore copied and patched for the project.
nathanpainchaud committed Jan 15, 2024
1 parent b9113da commit 7daeee9
Showing 6 changed files with 181 additions and 26 deletions.
@@ -116,7 +116,7 @@ task:
 mt_by_attr: False
 
 tabular_tokenizer:
-  _target_: rtdl.FeatureTokenizer
+  _target_: didactic.models.tabular.TabularEmbedding
   d_token: ${task.embed_dim}
 
 time_series_tokenizer:
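For context, the `_target_` key above is resolved by Hydra at instantiation time. The following is a minimal standalone sketch of what that amounts to; the feature counts and cardinalities are illustrative placeholders (the config above only pins `d_token`, and the remaining constructor arguments are presumably filled in elsewhere in the project):

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical values standing in for the dataset-dependent arguments
cfg = OmegaConf.create(
    {
        "_target_": "didactic.models.tabular.TabularEmbedding",
        "n_num_features": 3,
        "cat_cardinalities": [2, 3],
        "d_token": 8,  # in the project this would come from ${task.embed_dim}
    }
)
tokenizer = instantiate(cfg)
tokens = tokenizer(torch.randn(4, 3), torch.tensor([[0, 1], [1, 0], [0, 2], [1, 1]]))
assert tokens.shape == (4, 5, 8)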
65 changes: 65 additions & 0 deletions didactic/models/layers.py
@@ -34,6 +34,71 @@ def forward(self, x: Tensor) -> Tensor:
        return x + self.positional_encoding[None, ...]


class CLSToken(nn.Module):
    """[CLS]-token for BERT-like inference.

    When used as a module, the [CLS]-token is appended **to the end** of each item in the batch.

    Notes:
        - This is a port of the `CLSToken` class from v0.0.13 of the `rtdl` package. It mixes the original
          implementation with the simpler code of `_CLSEmbedding` from v0.0.2 of the `rtdl_revisiting_models` package.

    References:
        - Original implementation is here: https://github.com/yandex-research/rtdl/blob/f395a2db37bac74f3a209e90511e2cb84e218973/rtdl/modules.py#L380-L446

    Examples:
        .. testcode::

            batch_size = 2
            n_tokens = 3
            d_token = 4
            cls_token = CLSToken(d_token)
            x = torch.randn(batch_size, n_tokens, d_token)
            x = cls_token(x)
            assert x.shape == (batch_size, n_tokens + 1, d_token)
            assert (x[:, -1, :] == cls_token.expand(len(x))).all()
    """

    def __init__(self, d_token: int) -> None:
        """Initializes class instance.

        Args:
            d_token: the size of one token
        """
        super().__init__()
        self.weight = nn.Parameter(torch.empty(d_token))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initializes the weights using a uniform distribution."""
        d_rsqrt = self.weight.shape[-1] ** -0.5
        nn.init.uniform_(self.weight, -d_rsqrt, d_rsqrt)

    def expand(self, *leading_dimensions: int) -> Tensor:
        """Expand (repeat) the underlying [CLS]-token to a tensor with the given leading dimensions.

        A possible use case is building a batch of [CLS]-tokens.

        Note:
            Under the hood, the `torch.Tensor.expand` method is applied to the underlying :code:`weight` parameter,
            so gradients will be propagated as expected.

        Args:
            leading_dimensions: the additional new dimensions

        Returns:
            tensor of the shape :code:`(*leading_dimensions, len(self.weight))`
        """
        if not leading_dimensions:
            return self.weight
        new_dims = (1,) * (len(leading_dimensions) - 1)
        return self.weight.view(*new_dims, -1).expand(*leading_dimensions, -1)

    def forward(self, x: Tensor) -> Tensor:
        """Append self **to the end** of each item in the batch (see `CLSToken`)."""
        return torch.cat([x, self.expand(len(x), 1)], dim=1)


class SequencePooling(nn.Module):
    """Sequence pooling layer."""

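As a standalone illustration of the [CLS] readout pattern this module enables (assuming the `CLSToken` defined above; this is not the project's actual forward pass, which lives in the task module further down):

import torch
from torch import nn

d_token, n_tokens, batch_size = 8, 5, 2
cls_token = CLSToken(d_token)
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d_token, nhead=2, batch_first=True), num_layers=1
)

x = torch.randn(batch_size, n_tokens, d_token)
x = cls_token(x)          # (batch_size, n_tokens + 1, d_token), [CLS] appended at the end
out = encoder(x)
cls_readout = out[:, -1]  # the [CLS] position is the last token, per this class's convention
assert cls_readout.shape == (batch_size, d_token)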
100 changes: 100 additions & 0 deletions didactic/models/tabular.py
@@ -0,0 +1,100 @@
from typing import List, Optional

import torch
from rtdl_revisiting_models import CategoricalEmbeddings, LinearEmbeddings
from torch import Tensor, nn


def _all_or_none(values):
    return all(x is None for x in values) or all(x is not None for x in values)


class TabularEmbedding(nn.Module):
    """Combines `LinearEmbeddings` and `CategoricalEmbeddings`.

    The "Feature Tokenizer" module from "Revisiting Deep Learning Models for Tabular Data" by Gorishniy et al. (2021).
    The module transforms continuous and categorical features to tokens (embeddings).

    Notes:
        - This is a port of the `FeatureTokenizer` class from v0.0.13 of the `rtdl` package using the updated
          underlying `CategoricalEmbeddings` and `LinearEmbeddings` from v0.0.2 of the `rtdl_revisiting_models`
          package, instead of the original `CategoricalFeatureTokenizer` and `NumericalFeatureTokenizer` from the
          `rtdl` package.

    References:
        - Original implementation is here: https://github.com/yandex-research/rtdl/blob/f395a2db37bac74f3a209e90511e2cb84e218973/rtdl/modules.py#L260-L377

    Examples:
        .. testcode::

            n_objects = 4
            n_num_features = 3
            n_cat_features = 2
            d_token = 7
            x_num = torch.randn(n_objects, n_num_features)
            x_cat = torch.tensor([[0, 1], [1, 0], [0, 2], [1, 1]])
            # [2, 3] reflects the cardinalities of the categorical features
            tokenizer = TabularEmbedding(n_num_features, [2, 3], d_token)
            tokens = tokenizer(x_num, x_cat)
            assert tokens.shape == (n_objects, n_num_features + n_cat_features, d_token)
    """

    def __init__(
        self,
        n_num_features: int,
        cat_cardinalities: List[int],
        d_token: int,
    ) -> None:
        """Initializes class instance.

        Args:
            n_num_features: the number of continuous features. Pass :code:`0` if there are no numerical features.
            cat_cardinalities: the number of unique values for each feature. Pass an empty list if there are no
                categorical features.
            d_token: the size of one token.
        """
        super().__init__()
        assert n_num_features >= 0, "n_num_features must be non-negative"
        assert (
            n_num_features or cat_cardinalities
        ), "at least one of n_num_features or cat_cardinalities must be positive/non-empty"
        self.num_tokenizer = LinearEmbeddings(n_num_features, d_token) if n_num_features else None
        self.cat_tokenizer = CategoricalEmbeddings(cat_cardinalities, d_token) if cat_cardinalities else None

    @property
    def n_tokens(self) -> int:
        """The number of tokens."""
        return sum(x.n_tokens for x in [self.num_tokenizer, self.cat_tokenizer] if x is not None)

    @property
    def d_token(self) -> int:
        """The size of one token."""
        return self.cat_tokenizer.d_token if self.num_tokenizer is None else self.num_tokenizer.d_token  # type: ignore

    def forward(self, x_num: Optional[Tensor], x_cat: Optional[Tensor]) -> Tensor:
        """Perform the forward pass.

        Args:
            x_num: continuous features. Must be provided if :code:`n_num_features > 0` was passed to the constructor.
            x_cat: categorical features. Must be provided if a non-empty :code:`cat_cardinalities` was passed to the
                constructor.

        Returns:
            tokens

        Raises:
            AssertionError: if the described requirements for the inputs are not met.
        """
        assert x_num is not None or x_cat is not None, "At least one of x_num and x_cat must be provided"
        assert _all_or_none(
            [self.num_tokenizer, x_num]
        ), "If self.num_tokenizer is (not) None, then x_num must (not) be None"
        assert _all_or_none(
            [self.cat_tokenizer, x_cat]
        ), "If self.cat_tokenizer is (not) None, then x_cat must (not) be None"
        x = []
        if self.num_tokenizer is not None:
            x.append(self.num_tokenizer(x_num))
        if self.cat_tokenizer is not None:
            x.append(self.cat_tokenizer(x_cat))
        return x[0] if len(x) == 1 else torch.cat(x, dim=1)
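Beyond the mixed case in the docstring above, here is a small sketch (arbitrary values) of the single-modality cases that the assertions in `forward` guard:

import torch

# Numerical-only: empty cardinality list at construction, x_cat=None at call time
num_only = TabularEmbedding(n_num_features=3, cat_cardinalities=[], d_token=7)
assert num_only(torch.randn(4, 3), None).shape == (4, 3, 7)

# Categorical-only: n_num_features=0 at construction, x_num=None at call time
cat_only = TabularEmbedding(n_num_features=0, cat_cardinalities=[2, 3], d_token=7)
assert cat_only(None, torch.tensor([[0, 1], [1, 0]])).shape == (2, 2, 7)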
9 changes: 4 additions & 5 deletions didactic/tasks/cardiac_multimodal_representation.py
@@ -6,10 +6,8 @@
 
 import autogluon.multimodal.models.ft_transformer
 import hydra
-import rtdl
 import torch
 from omegaconf import DictConfig
-from rtdl import FeatureTokenizer
 from torch import Tensor, nn
 from torch.nn import Parameter, ParameterDict, init
 from torchmetrics.functional import accuracy, mean_absolute_error
@@ -21,7 +19,8 @@
 from vital.tasks.generic import SharedStepsTask
 from vital.utils.decorators import auto_move_data
 
-from didactic.models.layers import PositionalEncoding, SequencePooling
+from didactic.models.layers import CLSToken, PositionalEncoding, SequencePooling
+from didactic.models.tabular import TabularEmbedding
 from didactic.models.time_series import TimeSeriesEmbedding
 
 logger = logging.getLogger(__name__)
Expand All @@ -41,7 +40,7 @@ def __init__(
ordinal_mode: bool = True,
contrastive_loss: Callable[[Tensor, Tensor], Tensor] | DictConfig = None,
contrastive_loss_weight: float = 0,
tabular_tokenizer: Optional[FeatureTokenizer | DictConfig] = None,
tabular_tokenizer: Optional[TabularEmbedding | DictConfig] = None,
time_series_tokenizer: Optional[TimeSeriesEmbedding | DictConfig] = None,
cls_token: bool = True,
sequence_pooling: bool = False,
@@ -269,7 +268,7 @@ def __init__(
 
         # Initialize parameters of method for reducing the dimensionality of the encoder's output to only one token
         if self.hparams.cls_token:
-            self.cls_token = rtdl.CLSToken(self.hparams.embed_dim, "uniform")
+            self.cls_token = CLSToken(self.hparams.embed_dim)
         if self.hparams.sequence_pooling:
             self.sequence_pooling = SequencePooling(self.hparams.embed_dim)

27 changes: 10 additions & 17 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 1 addition & 3 deletions pyproject.toml
@@ -23,9 +23,7 @@ python = "~3.10.6"
 torch = { version = "~2.0.0", source = "pytorch-cu118" }
 torchvision = { version = "~0.15.1", source = "pytorch-cu118" }
 pytorch-lightning = "~2.0.0"
-# Temporary workaround to install RTDL from my own fork, which relaxes an unnecessarily strict version constraint that
-# prohibits PyTorch's 2-series
-rtdl = { git = "https://github.com/nathanpainchaud/rtdl.git", branch = "torch-2.X-support" }
+rtdl-revisiting-models = "~0.0.2"
 hydra-core = "~1.3.0"
 hydra-joblib-launcher = "*"
 holoviews = { version = "*", extras = ["recommended"] }
