Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add image generator to generate images for use with geneval #172

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 180 additions & 0 deletions diffusion/evaluation/generate_geneval_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Image generation for runnning evaluation with geneval."""

import json
import os
from typing import Dict, Optional, Union
from urllib.parse import urlparse

import torch
from composer.core import get_precision_context
from composer.utils import dist
from composer.utils.file_helpers import get_file
from composer.utils.object_store import OCIObjectStore
from diffusers import AutoPipelineForText2Image
from torchvision.transforms.functional import to_pil_image
from tqdm.auto import tqdm


class GenevalImageGenerator:
"""Image generator that generates images from the geneval prompt set and saves them.

Args:
model (torch.nn.Module): The model to evaluate.
geneval_prompts (str): Path to the prompts to use for geneval (ex: `geneval/prompts/evaluation_metadata.json`).
load_path (str, optional): The path to load the model from. Default: ``None``.
local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``.
load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``.
guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``.
height (int): The height of the generated images. Default: ``1024``.
width (int): The width of the generated images. Default: ``1024``.
images_per_prompt (int): The number of images to generate per prompt. Default: ``4``.
load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``.
seed (int): The seed to use for generation. Default: ``17``.
output_bucket (str, Optional): The remote to save images to. Default: ``None``.
output_prefix (str, Optional): The prefix to save images to. Default: ``None``.
local_prefix (str): The local prefix to save images to. Default: ``/tmp``.
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
hf_model: (bool, Optional): whether the model is HF or not. Default: ``False``.
"""

def __init__(self,
model: Union[torch.nn.Module, str],
geneval_prompts: str,
load_path: Optional[str] = None,
local_checkpoint_path: str = '/tmp/model.pt',
load_strict_model_weights: bool = True,
guidance_scale: float = 7.0,
height: int = 1024,
width: int = 1024,
images_per_prompt: int = 4,
seed: int = 17,
output_bucket: Optional[str] = None,
output_prefix: Optional[str] = None,
local_prefix: str = '/tmp',
additional_generate_kwargs: Optional[Dict] = None,
hf_model: Optional[bool] = False):

if isinstance(model, str) and hf_model == False:
raise ValueError('Can only use strings for model with hf models!')
self.hf_model = hf_model
if hf_model or isinstance(model, str):
if dist.get_local_rank() == 0:
self.model = AutoPipelineForText2Image.from_pretrained(
model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
dist.barrier()
self.model = AutoPipelineForText2Image.from_pretrained(
model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
dist.barrier()
else:
self.model = model
# Load the geneval prompts
self.geneval_prompts = geneval_prompts
with open(geneval_prompts) as f:
self.prompt_metadata = [json.loads(line) for line in f]
self.load_path = load_path
self.local_checkpoint_path = local_checkpoint_path
self.load_strict_model_weights = load_strict_model_weights
self.guidance_scale = guidance_scale
self.height = height
self.width = width
self.images_per_prompt = images_per_prompt
self.seed = seed
self.generator = torch.Generator(device='cuda').manual_seed(self.seed)

self.output_bucket = output_bucket
self.output_prefix = output_prefix if output_prefix is not None else ''
self.local_prefix = local_prefix
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}

# Object store for uploading images
if self.output_bucket is not None:
parsed_remote_bucket = urlparse(self.output_bucket)
if parsed_remote_bucket.scheme != 'oci':
raise ValueError(f'Currently only OCI object stores are supported. Got {parsed_remote_bucket.scheme}.')
self.object_store = OCIObjectStore(self.output_bucket.replace('oci://', ''), self.output_prefix)

# Download the model checkpoint if needed
if self.load_path is not None and not isinstance(self.model, str):
if dist.get_local_rank() == 0:
get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path):
# Load the model
state_dict = torch.load(self.local_checkpoint_path, map_location='cpu')
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
del state_dict['state']['model'][key]
self.model.load_state_dict(state_dict['state']['model'], strict=self.load_strict_model_weights)
self.model = self.model.cuda().eval()

def generate(self):
"""Core image generation function. Generates images at a given guidance scale.

Args:
guidance_scale (float): The guidance scale to use for image generation.
"""
os.makedirs(os.path.join(self.local_prefix, self.output_prefix), exist_ok=True)
# Partition the dataset across the ranks. Note this partitions prompts, not repeats.
dataset_len = len(self.prompt_metadata)
samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
end_idx = start_idx + samples_per_rank
if dist.get_global_rank() < remainder:
end_idx += 1
print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
# Iterate over the dataset
for sample_id in tqdm(range(start_idx, end_idx)):
metadata = self.prompt_metadata[sample_id]
# Write the metadata jsonl
output_dir = os.path.join(self.local_prefix, f'{sample_id:0>5}')
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, 'metadata.jsonl'), 'w') as f:
json.dump(metadata, f)
caption = metadata['prompt']
# Create dir for samples to live in
sample_dir = os.path.join(output_dir, 'samples')
os.makedirs(sample_dir, exist_ok=True)
# Generate images from the captions. Take care to use a different seed for each image
for i in range(self.images_per_prompt):
seed = self.seed + i
if self.hf_model:
generated_image = self.model(prompt=caption,
height=self.height,
width=self.width,
guidance_scale=self.guidance_scale,
generator=self.generator,
**self.additional_generate_kwargs).images[0]
img = generated_image
else:
with get_precision_context('amp_fp16'):
generated_image = self.model.generate(prompt=caption,
height=self.height,
width=self.width,
guidance_scale=self.guidance_scale,
seed=seed,
progress_bar=False,
**self.additional_generate_kwargs) # type: ignore
img = to_pil_image(generated_image[0])
# Save the images and metadata locally
image_name = f'{i:05}.png'
data_name = f'{i:05}.json'
img_local_path = os.path.join(sample_dir, image_name)
data_local_path = os.path.join(sample_dir, data_name)
img.save(img_local_path)
metadata = {
'image_name': image_name,
'prompt': caption,
'guidance_scale': self.guidance_scale,
'seed': seed
}
json.dump(metadata, open(f'{data_local_path}', 'w'))
# Upload the image and metadata to cloud storage
output_sample_prefix = os.path.join(self.output_prefix, f'{sample_id:0>5}', 'samples')
if self.output_bucket is not None:
self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, image_name),
filename=img_local_path)
# Upload the metadata
self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, data_name),
filename=data_local_path)
46 changes: 24 additions & 22 deletions diffusion/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Generate images from a model."""

import operator
from typing import List
from typing import Any, List, Optional

import hydra
from composer import Algorithm, ComposerModel
Expand All @@ -16,7 +16,20 @@
from omegaconf import DictConfig
from torch.utils.data import Dataset

from diffusion.evaluation.generate_images import ImageGenerator

def _make_dataset(config: DictConfig, tokenizer: Optional[Any] = None) -> Dataset:
if config.hf_dataset:
if dist.get_local_rank() == 0:
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
elif tokenizer:
dataset = hydra.utils.instantiate(config.dataset)

else:
dataset: Dataset = hydra.utils.instantiate(config.dataset)
return dataset


def generate(config: DictConfig) -> None:
Expand All @@ -37,20 +50,6 @@ def generate(config: DictConfig) -> None:

tokenizer = model.tokenizer if hasattr(model, 'tokenizer') else None

# The dataset to use for evaluation

if config.hf_dataset:
if dist.get_local_rank() == 0:
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
elif tokenizer:
dataset = hydra.utils.instantiate(config.dataset)

else:
dataset: Dataset = hydra.utils.instantiate(config.dataset)

# Build list of algorithms.
algorithms: List[Algorithm] = []

Expand Down Expand Up @@ -78,12 +77,15 @@ def generate(config: DictConfig) -> None:
precision=Precision(ag_conf['precision']),
optimizers=None,
)

image_generator: ImageGenerator = hydra.utils.instantiate(config.generator,
model=model,
dataset=dataset,
hf_model=config.hf_model,
hf_dataset=config.hf_dataset)
if 'dataset' in config:
dataset = _make_dataset(config, tokenizer)
image_generator = hydra.utils.instantiate(config.generator,
model=model,
dataset=dataset,
hf_model=config.hf_model,
hf_dataset=config.hf_dataset)
else:
image_generator = hydra.utils.instantiate(config.generator, model=model, hf_model=config.hf_model)

def generate_from_model():
image_generator.generate()
Expand Down
77 changes: 77 additions & 0 deletions yamls/mosaic-yamls/geneval-flux-1-schnell.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Example yaml for running geneval on FLUX.1-schnell model
name: geneval-flux-1-schnell
compute:
cluster: # your cluster name
instance: # your instance name
gpus: # number of gpus
env_variables:
HYDRA_FULL_ERROR: '1'
image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04
scheduling:
resumable: false
priority: medium
max_retries: 0
integrations:
- integration_type: git_repo
git_repo: mosaicml/diffusion
git_branch: main
pip_install: .[all] --no-deps # We install with no deps to use only specific deps needed for geneval
- integration_type: pip_packages
packages:
- huggingface-hub[hf_transfer]>=0.23.2
- numpy==1.26.4
- pandas
- open_clip_torch
- clip-benchmark
- openmim
- sentencepiece
- mosaicml
- mosaicml-streaming
- hydra-core
- hydra-colorlog
- diffusers[torch]==0.30.3
- transformers[torch]==4.44.2
- torchmetrics[image]
- lpips
- clean-fid
- gradio
- datasets
- peft
command: 'cd diffusion

pip install clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33

mim install mmengine mmcv-full==1.7.2

apt-get update && apt-get install libgl1-mesa-glx -y

git clone https://github.com/djghosh13/geneval.git

git clone https://github.com/open-mmlab/mmdetection.git

cd mmdetection; git checkout 2.x; pip install -v -e .; cd ..

composer run_generation.py --config-path /mnt/config --config-name parameters

cd geneval

./evaluation/download_models.sh eval_models

python evaluation/evaluate_images.py /tmp/geneval-images --outfile outputs.jsonl --model-path eval_models

python evaluation/summary_scores.py outputs.jsonl
'
parameters:
seed: 18
dist_timeout: 300
hf_model: true # We will use a model from huggingface
model:
name: black-forest-labs/FLUX.1-schnell # Model name from huggingface
generator:
_target_: diffusion.evaluation.generate_geneval_images.GenevalImageGenerator
geneval_prompts: geneval/prompts/evaluation_metadata.jsonl # Path to geneval prompts json
height: 1024 # Generated image height
width: 1024 # Generated image width
local_prefix: /tmp/geneval-images # Local path to save images to. Needed for geneval to read images from.
output_bucket: # Your output oci bucket name (optional)
output_prefix: # Your output prefix (optional)
Loading