Welcome to the repository for TabTransformer with Masked Pretraining, a public example demonstrating how to train TabTransformers using a Masked Language Modelling (MLM)-style pretraining objective.
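As a quick orientation, the pretraining objective works roughly like this: a fraction of the categorical tokens in each row is replaced with a reserved `[MASK]` index, and the model is trained to reconstruct the original values. Below is a minimal sketch of that idea in PyTorch; the `MASK_IDX` convention and the 15% masking rate are illustrative assumptions, not this repository's exact implementation (see `tabular_module.py` for that).

```python
import torch

# Hypothetical convention: index 0 of each categorical vocabulary is [MASK].
MASK_IDX = 0

def mask_categoricals(categ: torch.Tensor, mask_prob: float = 0.15):
    """Randomly corrupt categorical tokens, MLM-style.

    categ: integer-encoded categorical features, shape (batch, n_categ_cols).
    Returns (corrupted inputs, original targets, boolean mask of masked positions).
    """
    mask = torch.rand(categ.shape, device=categ.device) < mask_prob
    corrupted = categ.masked_fill(mask, MASK_IDX)
    return corrupted, categ, mask

# Toy batch: 4 rows, 3 categorical columns, vocab indices 1..9.
x = torch.randint(1, 10, (4, 3))
corrupted, targets, mask = mask_categoricals(x)
# The pretraining loss is cross-entropy over masked positions only, e.g.
# loss = F.cross_entropy(logits[mask], targets[mask])
```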
Key features:

- Domain Agnostic: Supports diverse domains with minimal customization.
- Extensible Framework: Easily adapt the framework for new datasets and use cases.
- Weights & Biases Integration: Leverage Weights & Biases (WandB) for experiment tracking, model performance visualization, and hyperparameter optimization.
- PyTorch Lightning Support: Utilizes the PyTorch Lightning framework to simplify training loops, logging, and model checkpointing while ensuring scalability and reproducibility (a minimal wiring sketch follows this list).
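As a hedged sketch of how the Lightning and WandB pieces typically wire together (the real setup lives in `train.py`; the `model`/`dm` placeholders stand in for the classes defined in `tabular_module.py` and `tabular_datamodule.py`):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# Track the run in Weights & Biases; the project name matches the CLI example below.
logger = WandbLogger(project="TabTransformer")

# Lightning owns the training loop, logging, and checkpointing.
trainer = Trainer(max_epochs=10, logger=logger)

# trainer.fit(model, datamodule=dm)  # `model` and `dm` come from
#                                    # tabular_module.py / tabular_datamodule.py
```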
```
├── tabular_datamodule.py   # Data preprocessing and dataloading
├── tabular_module.py       # Implementation of the TabTransformer architecture
├── train.py                # Training script with CLI support
└── pyproject.toml          # Poetry configuration file
```
Overall, the repository is designed to be very hackable, and the code should be pretty self-explanatory.
- Python 3.10+
- Poetry for dependency management
- Clone the repository:

  ```bash
  git clone https://github.com/charitarthchugh/masked-tabtransformers-lightning.git
  cd masked-tabtransformers-lightning
  ```
- Install dependencies with Poetry:

  ```bash
  poetry install
  ```

  If you don't want to use Poetry, setting up PyTorch in a virtual environment and pip-installing the remaining dependencies should work without issues (see the sketch after these steps).
- Activate the virtual environment:

  ```bash
  poetry shell
  ```
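If you skip Poetry, a minimal virtual-environment setup might look like the sketch below; `pyproject.toml` is the authoritative dependency list, and the package names here are assumptions:

```bash
# Hedged alternative to Poetry; check pyproject.toml for the exact dependencies.
python -m venv .venv
source .venv/bin/activate
pip install torch pytorch-lightning wandb  # plus any remaining deps from pyproject.toml
```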
The training script supports multiple configuration options. Use the following command to start training:
```bash
python train.py \
    --train-data-path path/to/train.csv \
    --val-data-path path/to/val.csv \
    --test-data-path path/to/test.csv \
    --categorical-columns "col1,col2" \
    --numerical-columns "col3,col4" \
    --batch-size 128 \
    --num-epochs 10 \
    --learning-rate 0.001 \
    --output-dir ./outputs \
    --logger wandb \
    --wandb-project-name TabTransformer
```
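To make the column flags concrete, consider a hypothetical `train.csv` like the one below; for it you would pass `--categorical-columns "color,size"` and `--numerical-columns "price,weight"`:

```
color,size,price,weight
red,large,9.99,1.2
blue,small,4.50,0.4
```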
Look at `train.py` to easily modify these options for your needs.
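After training, Lightning checkpoints can be reloaded for inference. A minimal sketch, assuming the LightningModule in `tabular_module.py` is named `TabTransformerModule` (a hypothetical name) and a checkpoint was saved under `--output-dir`:

```python
from tabular_module import TabTransformerModule  # assumed class name; check tabular_module.py

# `load_from_checkpoint` is standard PyTorch Lightning API; the path is an example.
model = TabTransformerModule.load_from_checkpoint("outputs/checkpoints/last.ckpt")
model.eval()  # switch to inference mode
```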
I welcome contributions to enhance this project! Feel free to submit issues or pull requests.
If you find this project useful, please cite it as:
```bibtex
@misc{tabtransformer-mlm,
  author       = {},
  title        = {TabTransformer with Masked Language Modeling: Unlocking Tabular Data Potential},
  year         = {2024},
  howpublished = {\url{https://github.com/charitarthchugh/masked-tabtransformer-lightning}},
}
```
I would like to thank Phil Wang (lucidrains) for his Tab-Transformer implementation ([tab-transformer-pytorch](https://github.com/lucidrains/tab-transformer-pytorch)).