An implementation of a diffusion model sampler using a UNet transformer to generate handwritten digit samples.
Diffusion models have shown great promise in generating high-quality samples in various domains. In this project, we utilize a UNet transformer-based diffusion model to generate samples of handwritten digits. The process involves:
- Setting up the model and loading pre-trained weights.
- Generating samples for each digit.
- Creating a GIF to visualize the generated samples (a minimal sketch of this sampling-to-GIF flow follows this list).
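As a rough illustration of the flow above, here is a minimal sketch of a DDPM-style reverse sampling loop that collects intermediate frames for a GIF. The model interface, noise schedule, step count, and frame cadence shown are illustrative assumptions, not the repository's actual code.

```python
# Minimal sketch of the sampling-to-GIF flow (illustrative; not the repo's actual API).
import torch
from PIL import Image

T = 500                                   # number of diffusion steps (assumed)
betas = torch.linspace(1e-4, 0.02, T)     # linear noise schedule (assumed)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def sample(model, shape=(1, 1, 28, 28), device="cpu"):
    """Run the reverse diffusion process, keeping intermediate frames for a GIF."""
    x = torch.randn(shape, device=device)
    frames = []
    for t in reversed(range(T)):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, t_batch)           # model predicts the added noise (assumed signature)
        coef = betas[t] / torch.sqrt(1.0 - alpha_bars[t])
        mean = (x - coef * eps) / torch.sqrt(alphas[t])
        noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + torch.sqrt(betas[t]) * noise
        if t % 25 == 0:                   # keep every 25th step as a GIF frame
            img = ((x[0, 0].clamp(-1, 1) + 1) * 127.5).byte().cpu().numpy()
            frames.append(Image.fromarray(img).resize((112, 112), Image.NEAREST))
    return x, frames

# The collected frames can then be written out as an animated GIF with Pillow, e.g.:
# frames[0].save("sample.gif", save_all=True, append_images=frames[1:], duration=80, loop=0)
```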
To get a local copy up and running, follow these simple steps.
Ensure you have the following prerequisites installed:
- Python 3.8 or higher
- CUDA-enabled GPU (optional but recommended)
- The following Python libraries:
  - torch
  - torchvision
  - numpy
  - Pillow
  - matplotlib
- Clone the repository:
git clone https://github.com/Yavuzhan-Baykara/Stable-Diffusion.git
cd Stable-Diffusion
- Install the required Python libraries:
pip install torch torchvision numpy Pillow matplotlib
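After installing, you can optionally confirm that PyTorch is importable and whether a CUDA GPU is visible. This quick check is not part of the repository; it only verifies the environment:

```python
# Quick environment check: prints the PyTorch version and reports CUDA availability.
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```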
To train the UNet transformer with different datasets and samplers, use the following command:
python train.py <dataset> <sampler> <epoch> <batch_size>
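For example, an invocation might look like the line below; the concrete values for `<dataset>` and `<sampler>` are assumptions here and should be matched to the options train.py actually accepts:

```sh
# Hypothetical example: train on MNIST with a DDPM sampler for 50 epochs at batch size 128.
python train.py mnist ddpm 50 128
```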