Embeddings #213

Closed
wants to merge 20 commits

Commits
2a65660
adapt embedding layer to new input format of tuple information
AnFreTh Jan 23, 2025
4d5f94a
adapt basemodel encoding function to tuple input
AnFreTh Jan 23, 2025
adc6d19
batch now returns tuple and *data is passed to forward method
AnFreTh Jan 23, 2025
a02b9dd
first two basemodels adapted to new logic
AnFreTh Jan 23, 2025
10d1c00
major changes in handling embeddings as array/list inputs in addition…
AnFreTh Jan 23, 2025
cbe8dd3
dataset returns tuple of data (cat, num, emb), label
AnFreTh Jan 23, 2025
b84aa50
adjust two first basemodel configs to handle projection for embeddings
AnFreTh Jan 23, 2025
8cc3e83
adapt first only regressor and classifier to handle embeddings
AnFreTh Jan 23, 2025
6c0bc5c
preprocessor does not preprocess embeddings, but takes them as input …
AnFreTh Jan 23, 2025
743c214
feature dimensions adapted to new output format of get_feature_info
AnFreTh Jan 23, 2025
4ec70f8
adapting all basemodels to new dataset __getitem__ method
AnFreTh Jan 24, 2025
a2c7845
adapt lightning layer and preprocessor to account for no passed embed…
AnFreTh Jan 24, 2025
b8bc5e9
restructure configs to create parent config-class
AnFreTh Feb 12, 2025
a4c5992
fix minor bugs related to imports and dim identification
AnFreTh Feb 12, 2025
6fc04eb
fix bug related to column names in datamodule - turn int to string
AnFreTh Feb 12, 2025
e60dd80
make box-cox strictly positive
AnFreTh Feb 12, 2025
febf165
include unit tests
AnFreTh Feb 12, 2025
161f6de
remove dependence on rotary embeddings
AnFreTh Feb 12, 2025
bd998d3
include params related to [BUG] Missing Configuration Attributes in …
AnFreTh Feb 12, 2025
44d3b3a
test new unit test for pr-requests
AnFreTh Feb 12, 2025
37 changes: 37 additions & 0 deletions .github/workflows/pr-tests.yml
@@ -0,0 +1,37 @@
+name: PR Unit Tests
+
+on:
+  pull_request:
+    branches:
+      - develop
+      - master # Add any other branches where you want to enforce tests
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout Repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.12" # Change this to match your setup
+
+      - name: Install Dependencies
+        run: |
+          python -m pip install --upgrade pip
+          poetry install
+          pip install pytest
+
+      - name: Run Unit Tests
+        run: pytest tests/
+
+      - name: Verify Tests Passed
+        if: ${{ success() }}
+        run: echo "All tests passed! Pull request is allowed."
+
+      - name: Fail PR on Test Failure
+        if: ${{ failure() }}
+        run: exit 1 # This ensures the PR cannot be merged if tests fail
11 changes: 2 additions & 9 deletions mambular/arch_utils/layer_utils/attention_utils.py
@@ -5,7 +5,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from rotary_embedding_torch import RotaryEmbedding


 class GEGLU(nn.Module):
@@ -25,7 +24,7 @@ def FeedForward(dim, mult=4, dropout=0.0):


 class Attention(nn.Module):
-    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
+    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
         super().__init__()
         inner_dim = dim_head * heads
         self.heads = heads
@@ -34,18 +33,13 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary=False):
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
         self.dropout = nn.Dropout(dropout)
-        self.rotary = rotary
-        dim = np.int64(dim / 2)
-        self.rotary_embedding = RotaryEmbedding(dim=dim)

     def forward(self, x):
         h = self.heads
         x = self.norm(x)
         q, k, v = self.to_qkv(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))  # type: ignore
-        if self.rotary:
-            q = self.rotary_embedding.rotate_queries_or_keys(q)
-            k = self.rotary_embedding.rotate_queries_or_keys(k)
         q = q * self.scale

         sim = torch.einsum("b h i d, b h j d -> b h i j", q, k)
@@ -61,7 +55,7 @@ def forward(self, x):


 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary=False):
+    def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout):
         super().__init__()
         self.layers = nn.ModuleList([])

@@ -74,7 +68,6 @@ def __init__(self, dim, depth, heads, dim_head, attn_dropout, ff_dropout, rotary
                         heads=heads,
                         dim_head=dim_head,
                         dropout=attn_dropout,
-                        rotary=rotary,
                     ),
                     FeedForward(dim, dropout=ff_dropout),
                 ]
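For orientation, a minimal usage sketch of the slimmed-down attention block after this change. The tensor sizes are illustrative, and the only assumption beyond the diff is that Attention.forward keeps its shape-preserving output; Attention (and Transformer) no longer accept a rotary flag, so rotary_embedding_torch is no longer required.

import torch

from mambular.arch_utils.layer_utils.attention_utils import Attention

# Construction now only needs the attention hyperparameters; the rotary
# option (and its extra dependency) is removed by this PR.
attn = Attention(dim=64, heads=8, dim_head=16, dropout=0.1)

x = torch.randn(4, 10, 64)  # (batch, tokens, dim) -- illustrative sizes
out = attn(x)               # multi-head self-attention over the 10 tokens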
57 changes: 38 additions & 19 deletions mambular/arch_utils/layer_utils/embedding_layer.py
@@ -6,7 +6,7 @@


 class EmbeddingLayer(nn.Module):
-    def __init__(self, num_feature_info, cat_feature_info, config):
+    def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config):
         """Embedding layer that handles numerical and categorical embeddings.

         Parameters
@@ -28,6 +28,7 @@ def __init__(self, num_feature_info, cat_feature_info, config):
         self.layer_norm_after_embedding = getattr(
             config, "layer_norm_after_embedding", False
         )
+        self.embedding_projection = getattr(config, "embedding_projection", True)
         self.use_cls = getattr(config, "use_cls", False)
         self.cls_position = getattr(config, "cls_position", 0)
         self.embedding_dropout = (
@@ -100,6 +101,22 @@ def __init__(self, num_feature_info, cat_feature_info, config):
                 ]
             )

+        if len(emb_feature_info) >= 1:
+            if self.embedding_projection:
+                self.emb_embeddings = nn.ModuleList(
+                    [
+                        nn.Sequential(
+                            nn.Linear(
+                                feature_info["dimension"],
+                                self.d_model,
+                                bias=self.embedding_bias,
+                            ),
+                            self.embedding_activation,
+                        )
+                        for feature_name, feature_info in emb_feature_info.items()
+                    ]
+                )
+
         # Class token if required
         if self.use_cls:
             self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_model))
@@ -108,15 +125,12 @@ def __init__(self, num_feature_info, cat_feature_info, config):
         if self.layer_norm_after_embedding:
             self.embedding_norm = nn.LayerNorm(self.d_model)

-    def forward(self, num_features=None, cat_features=None):
+    def forward(self, num_features, cat_features, emb_features):
         """Defines the forward pass of the model.

         Parameters
         ----------
-        num_features : Tensor, optional
-            Tensor containing the numerical features.
-        cat_features : Tensor, optional
-            Tensor containing the categorical features.
+        data: tuple of lists of tensors

         Returns
         -------
@@ -128,6 +142,7 @@ def forward(self, num_features, cat_features, emb_features):
         ValueError
             If no features are provided to the model.
         """
+        num_embeddings, cat_embeddings, emb_embeddings = None, None, None

         # Class token initialization
         if self.use_cls:
@@ -147,8 +162,6 @@ def forward(self, num_features, cat_features, emb_features):
             cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
             if self.layer_norm_after_embedding:
                 cat_embeddings = self.embedding_norm(cat_embeddings)
-        else:
-            cat_embeddings = None

         # Process numerical embeddings based on embedding_type
         if self.embedding_type == "plr":
@@ -161,25 +174,31 @@ def forward(self, num_features, cat_features, emb_features):
                 num_embeddings = self.num_embeddings(num_features)
                 if self.layer_norm_after_embedding:
                     num_embeddings = self.embedding_norm(num_embeddings)
-            else:
-                num_embeddings = None
         else:
             # For linear and ndt embeddings, handle each feature individually
             if self.num_embeddings and num_features is not None:
                 num_embeddings = [emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)]  # type: ignore
                 num_embeddings = torch.stack(num_embeddings, dim=1)
                 if self.layer_norm_after_embedding:
                     num_embeddings = self.embedding_norm(num_embeddings)
-            else:
-                num_embeddings = None

-        # Combine categorical and numerical embeddings
-        if cat_embeddings is not None and num_embeddings is not None:
-            x = torch.cat([cat_embeddings, num_embeddings], dim=1)
-        elif cat_embeddings is not None:
-            x = cat_embeddings
-        elif num_embeddings is not None:
-            x = num_embeddings
+        if emb_features != []:
+            if self.embedding_projection:
+                emb_embeddings = [
+                    emb(emb_features[i]) for i, emb in enumerate(self.emb_embeddings)
+                ]
+                emb_embeddings = torch.stack(emb_embeddings, dim=1)
+            else:
+                emb_embeddings = torch.stack(emb_features, dim=1)
+            if self.layer_norm_after_embedding:
+                emb_embeddings = self.embedding_norm(emb_embeddings)
+
+        embeddings = [
+            e for e in [cat_embeddings, num_embeddings, emb_embeddings] if e is not None
+        ]
+
+        if embeddings:
+            x = torch.cat(embeddings, dim=1) if len(embeddings) > 1 else embeddings[0]
         else:
             raise ValueError("No features provided to the model.")

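To make the new branch concrete, here is a standalone sketch of the projection path for pre-computed embeddings: each embedding source is projected to d_model and stacked as its own token, mirroring the emb_embeddings logic above. The feature names, dimensions, and the ReLU activation are illustrative assumptions, not values taken from the library.

import torch
import torch.nn as nn

d_model = 64
# Hypothetical metadata in the {"name": {"dimension": ...}} shape used above.
emb_feature_info = {"text_emb": {"dimension": 384}, "image_emb": {"dimension": 512}}

# One projection per embedding source, as in the emb_embeddings ModuleList.
emb_projections = nn.ModuleList(
    [
        nn.Sequential(nn.Linear(info["dimension"], d_model), nn.ReLU())
        for info in emb_feature_info.values()
    ]
)

batch = 32
emb_features = [torch.randn(batch, info["dimension"]) for info in emb_feature_info.values()]

# Project each source to d_model and stack along the token dimension.
projected = [proj(feat) for proj, feat in zip(emb_projections, emb_features)]
tokens = torch.stack(projected, dim=1)
print(tokens.shape)  # torch.Size([32, 2, 64])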
4 changes: 2 additions & 2 deletions mambular/base_models/basemodel.py
@@ -223,7 +223,7 @@ def pool_sequence(self, out):
         else:
             raise ValueError(f"Invalid pooling method: {self.hparams.pooling_method}")

-    def encode(self, num_features, cat_features):
+    def encode(self, data):
         if not hasattr(self, "embedding_layer"):
             raise ValueError("The model does not have an embedding layer")

@@ -237,7 +237,7 @@ def encode(self, num_features, cat_features):
             raise ValueError("The model does not generate contextualized embeddings")

         # Get the actual layer and call it
-        x = self.embedding_layer(num_features=num_features, cat_features=cat_features)
+        x = self.embedding_layer(*data)

         if getattr(self.hparams, "shuffle_embeddings", False):
             x = x[:, self.perm, :]
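The call-site change is plain argument unpacking: encode now receives the whole (num, cat, emb) tuple produced by the dataset and spreads it over the embedding layer. A tiny, library-free illustration of the pattern (all names here are placeholders, not the library's API):

# Placeholder standing in for EmbeddingLayer.forward(num_features, cat_features, emb_features).
def embedding_layer(num_features, cat_features, emb_features):
    return len(num_features), len(cat_features), len(emb_features)

# The datamodule now yields the features as a single tuple ...
data = ([0.5, 1.2], [3], [[0.1, 0.2, 0.3]])  # (num, cat, emb)

# ... which encode() forwards unchanged via *data.
print(embedding_layer(*data))  # -> (2, 1, 1)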
23 changes: 9 additions & 14 deletions mambular/base_models/ft_transformer.py
@@ -6,6 +6,7 @@
 from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
 from ..configs.fttransformer_config import DefaultFTTransformerConfig
 from .basemodel import BaseModel
+import numpy as np


 class FTTransformer(BaseModel):
@@ -52,22 +53,18 @@ class FTTransformer(BaseModel):

     def __init__(
         self,
-        cat_feature_info,
-        num_feature_info,
+        feature_information: tuple,  # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
         num_classes=1,
         config: DefaultFTTransformerConfig = DefaultFTTransformerConfig(),  # noqa: B008
         **kwargs,
     ):
         super().__init__(config=config, **kwargs)
-        self.save_hyperparameters(ignore=["cat_feature_info", "num_feature_info"])
+        self.save_hyperparameters(ignore=["feature_information"])
         self.returns_ensemble = False
-        self.cat_feature_info = cat_feature_info
-        self.num_feature_info = num_feature_info

         # embedding layer
         self.embedding_layer = EmbeddingLayer(
-            num_feature_info=num_feature_info,
-            cat_feature_info=cat_feature_info,
+            *feature_information,
             config=config,
         )

@@ -87,25 +84,23 @@ def __init__(
         )

         # pooling
-        n_inputs = len(num_feature_info) + len(cat_feature_info)
+        n_inputs = np.sum([len(info) for info in feature_information])
         self.initialize_pooling_layers(config=config, n_inputs=n_inputs)

-    def forward(self, num_features, cat_features):
+    def forward(self, *data):
         """Defines the forward pass of the model.

         Parameters
         ----------
-        num_features : Tensor
-            Tensor containing the numerical features.
-        cat_features : Tensor
-            Tensor containing the categorical features.
+        data : tuple
+            Input tuple of tensors of num_features, cat_features, embeddings.

         Returns
         -------
         Tensor
             The output predictions of the model.
         """
-        x = self.embedding_layer(num_features, cat_features)
+        x = self.embedding_layer(*data)

         x = self.encoder(x)

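On the construction side, a short sketch of the new convention, assuming the feature-info objects are dicts keyed by feature name (only the "dimension" field of the embedding entries is visible in this diff; everything else below is a placeholder). It shows the single feature_information tuple that replaces the old cat/num argument pair and the token count computed in __init__:

import numpy as np

# Hypothetical feature metadata; the dict contents are placeholders.
num_feature_info = {"age": {}, "income": {}}
cat_feature_info = {"colour": {}}
emb_feature_info = {"text_emb": {"dimension": 384}}

# The tuple that replaces the separate num_feature_info / cat_feature_info arguments.
feature_information = (num_feature_info, cat_feature_info, emb_feature_info)

# One token per feature, summed over all three groups, as in FTTransformer.__init__.
n_inputs = np.sum([len(info) for info in feature_information])
print(n_inputs)  # 4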