mlexchange/mlex_scientific_txt2image

scientific_txt2image

A diffusion model pipeline that generates realistic scientific datasets based on iterative human labels and feedback.

Install

The full pipeline contains two parts: a diffusers generator and an ensemble classification process. Because of version conflicts (xformers==0.0.17 is needed for fine-tuning and inference with diffusers on GPU), the two parts are installed in separate environments.

Install the diffusers environment: pip install -r requirements-diffusers.txt
Install the classification environment: pip install -r requirements-classification.txt

How to fine-tune the diffusion model and classifiers?

  1. Fine-tune the diffusion model on a scientific domain dataset (see /data/metadata.jsonl as an example). Go to /src and use a command like the example below to train:
    accelerate launch train_text_to_image_lora.py \
    --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 \
    --train_data_dir="als_data" \
    --resolution=512 --center_crop --random_flip \
    --train_batch_size=32 \
    --num_train_epochs=100 \
    --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
    --enable_xformers_memory_efficient_attention \
    --seed=42 \
    --output_dir="output" \
    --validation_prompt="GISAXS data showing peaks"
    
  2. Training saves the model weights. Once it finishes, run python3 /src/generator.py to generate images for a given prompt (generator.py must be modified for each prompt).
  3. Label the generated images.
  4. Use src/classifier/classification.ipynb to train an assortment of computer vision models (vision transformers, etc.) to classify the generated images, and save their weights. Repeat steps 3 and 4 to improve classification accuracy.
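
The training data for step 1 pairs each image with a text prompt in a metadata.jsonl file. The sketch below shows one way to write such a file. It assumes the training script reads the Hugging Face "imagefolder" layout, where each line is a JSON object with a `file_name` and a `text` field. Those field names, the helper `write_metadata`, and the image names are assumptions for illustration, so check /data/metadata.jsonl for the layout this repository actually uses.

```python
import json
from pathlib import Path


def write_metadata(image_dir, captions):
    """Write metadata.jsonl into image_dir, one JSON object per line.

    `captions` maps an image file name (relative to image_dir) to its prompt.
    The "file_name" and "text" keys are assumed, not confirmed by the repository.
    Returns the path of the written file.
    """
    image_dir = Path(image_dir)
    image_dir.mkdir(parents=True, exist_ok=True)
    metadata_path = image_dir / "metadata.jsonl"
    with metadata_path.open("w", encoding="utf-8") as f:
        # Sort by file name so the output is deterministic.
        for file_name, text in sorted(captions.items()):
            f.write(json.dumps({"file_name": file_name, "text": text}) + "\n")
    return metadata_path


if __name__ == "__main__":
    # Hypothetical image names; the prompt is the validation prompt from step 1.
    path = write_metadata(
        "als_data",
        {
            "image_0001.png": "GISAXS data showing peaks",
            "image_0002.png": "GISAXS data showing peaks",
        },
    )
    print(f"Wrote {path}")
```

The target directory here matches the --train_data_dir value in the example command.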

How to do inference?

  1. Full pipeline: generate realistic images with the diffusers generator and the classifiers by running python3 /src/inferece.py (generator.py must be modified for each prompt).
  2. Or use an interactive interface to generate realistic scientific images from a prompt: txt2image_widgets.ipynb (serial) and txt2image_widgets (parallel).
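
The full pipeline uses the classifiers to screen what the generator produces. This README does not say how the individual classifier outputs are combined, so the sketch below assumes a simple majority vote. The function `keep_realistic`, the `"realistic"` label, and the callable-classifier interface are all hypothetical stand-ins, not the code in this repository.

```python
def keep_realistic(images, classifiers, realistic_label="realistic", min_votes=None):
    """Return the images that enough classifiers label as realistic.

    `classifiers` is a list of callables, each mapping an image to a label.
    By default an image needs a strict majority of votes to be kept.
    Majority voting is an assumed ensemble rule, not one stated by the source.
    """
    if min_votes is None:
        min_votes = len(classifiers) // 2 + 1
    kept = []
    for image in images:
        # Count how many classifiers consider this image realistic.
        votes = sum(1 for clf in classifiers if clf(image) == realistic_label)
        if votes >= min_votes:
            kept.append(image)
    return kept


if __name__ == "__main__":
    # Toy stand-ins: "images" are numbers and each classifier is a threshold.
    toy_classifiers = [
        lambda x: "realistic" if x > 0 else "fake",
        lambda x: "realistic" if x > 1 else "fake",
        lambda x: "realistic" if x > 2 else "fake",
    ]
    print(keep_realistic([0, 1, 2, 3], toy_classifiers))
```

Raising `min_votes` to the number of classifiers requires a unanimous vote, which trades fewer kept images for a stricter filter.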

Paper

Preprint on arXiv

BibTeX

@misc{zhao2024generatingrealisticxrayscattering,
      title={Generating Realistic X-ray Scattering Images Using Stable Diffusion and Human-in-the-loop Annotations}, 
      author={Zhuowen Zhao and Xiaoya Chong and Tanny Chavez and Alexander Hexemer},
      year={2024},
      eprint={2408.12720},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2408.12720}, 
}
