This repository is the official implementation of [Sparse Multimodal Fusion with Modal Channel Attention](https://arxiv.org/abs/2403.20280).
To install requirements:

```bash
pip install -r requirements.txt
```
This repository makes heavy use of the Hugging Face Accelerate and Datasets libraries for managing training and data, and of WandB for collecting performance metrics. To collect train and eval metrics, a WandB account name must be provided in the config files under the setting `wandb_account_name`.
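For example, the relevant entry in a config file might look like the following (a minimal sketch; `my-wandb-entity` is a placeholder for your own account name):

```yaml
# WandB settings in a training config (sketch; replace the value
# with your own WandB account/entity name).
wandb_account_name: my-wandb-entity
```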
To train the model, choose a configuration file from the `configs` directory and run:

```bash
accelerate launch train_accel_gpu.py <config_file_path>
```
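For example, to train on one of the provided configs (the config filename here is illustrative; use whichever file exists in your checkout of `configs`):

```bash
# Example: train the MCA model on the CMU dataset.
accelerate launch train_accel_gpu.py configs/cmu_mca.yaml
```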
Preprocessed datasets are available for download at the following links:
To evaluate the model, first run inference using pretrained model weights, then train a linear probe to fit a target property.
To run a batch inference:

```bash
accelerate launch infer_accel_gpu.py <config_file_path>
```
To train a linear probe or MLP on the embeddings generated by inference:

```bash
accelerate launch lp_accel_gpu.py <config_file_path>
```
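Conceptually, the linear-probe step fits a simple model on the frozen embeddings produced by inference. The sketch below shows the idea with scikit-learn (the actual probe is trained via `lp_accel_gpu.py`; the file names and array layout here are assumptions):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Hypothetical files: embeddings saved by the inference step and
# the target property for each sample.
X = np.load("embeddings.npy")  # shape: (n_samples, embed_dim)
y = np.load("labels.npy")      # shape: (n_samples,)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# A linear probe: a single linear layer (here, logistic regression)
# trained on frozen embeddings.
probe = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print("probe accuracy:", probe.score(X_test, y_test))
```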
The full set of model checkpoints has been made available for models trained with the described dataset modality dropout of 0.4, for the purpose of reproducing experiments.
| Dataset | Model Type | Link |
|---|---|---|
| TCGA | MCA | link (2GB) |
| TCGA | MMA | link (2GB) |
| CMU | MCA | link (2GB) |
| CMU | MMA | link (2GB) |
Data encoders that can be configured to collate any combination of sequence, tabular, and pre-embedded modality tokens are provided in `encoders.py`. The configuration for a given dataset is given in the example YAML config files; an example configuration is shown below.
```yaml
encoder_configs:
  COVAREP: {type: 'EmbeddedSequenceEncoder', input_size: 74, max_tokens: 1500}
  FACET: {type: 'EmbeddedSequenceEncoder', input_size: 35, max_tokens: 450}
  OpenFace: {type: 'EmbeddedSequenceEncoder', input_size: 713, max_tokens: 450}
  glove_vectors: {type: 'EmbeddedSequenceEncoder', input_size: 300, max_tokens: 50}

modality_config:
  COVAREP: {type: 'embedded_sequence', pad_len: 1500, data_col_name: "data", pad_token: -10000}
  FACET: {type: 'embedded_sequence', pad_len: 450, data_col_name: "data", pad_token: -10000}
  OpenFace: {type: 'embedded_sequence', pad_len: 450, data_col_name: "data", pad_token: -10000}
  glove_vectors: {type: 'embedded_sequence', pad_len: 50, data_col_name: "data", pad_token: -10000}
```
Multimodal data is received as a dictionary of tensors whose keys are the names under `encoder_configs`, and the type of data encoder is given by the `type` field. The available encoder types are `SequenceEncoder`, `TabularEncoder`, `SparseTabularEncoder`, `PatchEncoder`, and `EmbeddedSequenceEncoder`.
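For instance, a batch matching the CMU config above would be a dictionary shaped roughly like this (a sketch; the actual data pipeline lives in the Datasets preprocessing, and the batch size and dtypes here are assumptions):

```python
import torch

# One batch as a dict of tensors, keyed by the modality names from
# encoder_configs. Shapes follow (batch, max_tokens, input_size)
# from the example config above.
batch = {
    "COVAREP":       torch.randn(8, 1500, 74),
    "FACET":         torch.randn(8, 450, 35),
    "OpenFace":      torch.randn(8, 450, 713),
    "glove_vectors": torch.randn(8, 50, 300),
}
```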
The `modality_config` field defines the modality collator type. Available types are `sequence`, `embedded_sequence`, and `matrix`. Tabular encoders use the `sequence` collator.
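As an illustration of what an `embedded_sequence` collator does, the sketch below pads each sample to `pad_len` with `pad_token`, matching the config fields above (this is an assumption-laden paraphrase, not the repository's actual collator):

```python
import torch

def pad_embedded_sequence(x: torch.Tensor, pad_len: int,
                          pad_token: float = -10000.0) -> torch.Tensor:
    """Pad or truncate a (tokens, features) tensor to (pad_len, features)."""
    n_tokens, n_features = x.shape
    if n_tokens >= pad_len:
        return x[:pad_len]
    pad = torch.full((pad_len - n_tokens, n_features), pad_token, dtype=x.dtype)
    return torch.cat([x, pad], dim=0)

# Example: pad a 37-token COVAREP sample to the configured pad_len of 1500.
sample = torch.randn(37, 74)
padded = pad_embedded_sequence(sample, pad_len=1500)  # shape: (1500, 74)
```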
These configuration options define the number of modalities and the encoding of the input data before the multimodal fusion transformer, and they can be adjusted for other multimodal datasets.
Depending on the alignment of your multimodal data, you may also want to adjust general transformer encoder hyperparameters; see `utils/config.py` for an extensive list of model hyperparameters.
Josiah Bjorgaard