Compatibility with diffusers 0.3.0 (#7)
* first changes to use dataclasses

* first changes to use dataclasses

* first changes to use dataclasses

* first changes to use dataclasses

* first changes to use dataclasses

* first changes to use dataclasses

* update notebook

* update README

* update README

* update README

* update notebooks

* fix LDMTextToImagePipelineExplainer

* bump version
JoaoLages authored Sep 9, 2022
1 parent 6a7cbf4 commit b5989ed
Showing 8 changed files with 2,826 additions and 2,734 deletions.
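The user-visible change in this release: explainer outputs are now dataclasses rather than plain dicts, mirroring the output classes introduced in diffusers 0.3.0, and the `sample` naming gives way to `image`. A minimal migration sketch (the `explainer` object and prompt are illustrative, assumed to be set up as in the README's Usage section below):

```python
# Migration sketch: dict-style keys (0.2.x) -> dataclass attributes (this release).
# `explainer` is assumed to be a pipeline explainer constructed as in the README.
output = explainer("a cute corgi")  # illustrative prompt

image = output.image                          # was: output['sample']
attrs = output.token_attributions             # was: output['token_attributions']
norm = output.normalized_token_attributions   # was: output['normalized_token_attributions']
steps = output.all_images_during_generation   # was: output['all_samples_during_generation']
```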
14 changes: 8 additions & 6 deletions README.md
@@ -11,7 +11,7 @@

Install directly from PyPI:

pip install diffusers-interpret
pip install --upgrade diffusers-interpret

## Usage

@@ -53,7 +53,7 @@ with torch.autocast('cuda') if device == 'cuda' else nullcontext():

To see the final generated image:
```python
output['sample']
output.image
```

![](assets/corgi_eiffel_tower.png)
@@ -63,7 +63,7 @@ You can also check all the images that the diffusion process generated at the en

To analyse how a token in the input `prompt` influenced the generation, you can study the token attribution scores:
```python
>>> output['token_attributions'] # (token, attribution)
>>> output.token_attributions # (token, attribution)
[('a', 1063.0526),
('cute', 415.62888),
('corgi', 6430.694),
@@ -78,7 +78,7 @@ To analyse how a token in the input `prompt` influenced the generation, you can

Or their normalized version, as percentages:
```python
>>> output['normalized_token_attributions'] # (token, attribution_percentage)
>>> output.normalized_token_attributions # (token, attribution_percentage)
[('a', 3.884),
('cute', 1.519),
('corgi', 23.495),
@@ -104,7 +104,7 @@ with torch.autocast('cuda') if device == 'cuda' else nullcontext():
        generator=generator,
        explanation_2d_bounding_box=((70, 180), (400, 435)), # (upper left corner, bottom right corner)
    )
output['sample']
output.image
```
![](assets/corgi_eiffel_tower_box_1.png)

@@ -113,7 +113,7 @@ The generated image now has a <span style="color:red"> **red bounding box** </sp
The token attributions are now computed only for the area specified in the image.

```python
>>> output['normalized_token_attributions'] # (token, attribution_percentage)
>>> output.normalized_token_attributions # (token, attribution_percentage)
[('a', 1.891),
('cute', 1.344),
('corgi', 23.115),
@@ -130,6 +130,8 @@ Check other functionalities and more implementation examples in [here](https://g

## Future Development
- [x] ~~Add interactive display of all the images that were generated in the diffusion process~~
- [ ] Add explainer for StableDiffusionImg2ImgPipeline
- [ ] Add explainer for StableDiffusionInpaintPipeline
- [ ] Add interactive bounding-box and token attributions visualization
- [ ] Add unit tests
- [ ] Add example for `diffusers_interpret.LDMTextToImagePipelineExplainer`
1,745 changes: 877 additions & 868 deletions notebooks/stable-diffusion-example.ipynb

Large diffs are not rendered by default.

3,559 changes: 1,780 additions & 1,779 deletions notebooks/stable_diffusion_example_colab.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
transformers>=4.21.1
setuptools>=49.6.0
torch>=1.9.1
diffusers~=0.2.4
diffusers~=0.3.0
scipy>=1.7.3
ftfy>=6.1.1
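Note that `~=` is a compatible-release pin: `diffusers~=0.3.0` accepts any `0.3.x` release but not `0.4.0`, so the package tracks patch releases of the diffusers line it was tested against.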
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@

setup(
    name='diffusers-interpret',
    version='0.2.4',
    version='0.3.0',
    description='diffusers-interpret: model explainability for 🤗 Diffusers',
    long_description=long_description,
    long_description_content_type='text/markdown',
125 changes: 95 additions & 30 deletions src/diffusers_interpret/explainer.py
@@ -1,8 +1,11 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Union, Dict, Any, Tuple, Set
from dataclasses import dataclass
from typing import List, Optional, Union, Tuple, Set

import torch
import numpy as np
from PIL import ImageDraw
from PIL.Image import Image
from diffusers import DiffusionPipeline
from transformers import BatchEncoding, PreTrainedTokenizerBase

@@ -11,10 +14,55 @@
from diffusers_interpret.utils import clean_token_from_prefixes_and_suffixes


@dataclass
class BaseMimicPipelineCallOutput:
    """
    Output class for BasePipelineExplainer._mimic_pipeline_call
    Args:
        images (`List[Image]`, `np.ndarray` or `torch.Tensor`)
            List of denoised PIL images of length `batch_size`, or an array of shape `(batch_size, height,
            width, num_channels)`, representing the denoised images of the diffusion pipeline.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content.
        all_images_during_generation (`Optional[List[List[Image]]]`)
            A list with all the batch images generated during diffusion.
    """
    images: Union[List[Image], np.ndarray, torch.Tensor]
    nsfw_content_detected: List[bool]
    all_images_during_generation: Optional[List[List[Image]]]

    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, key, value):
        setattr(self, key, value)


@dataclass
class PipelineExplainerOutput:
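    """
    Output class for BasePipelineExplainer.__call__
    Args:
        image (`Image`, `np.ndarray` or `torch.Tensor`)
            The final denoised image.
        nsfw_content_detected (`List[bool]`)
            List of flags denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content.
        all_images_during_generation (`Optional[GeneratedImages]`)
            All the images generated during diffusion.
        token_attributions (`Optional[List[Tuple[str, float]]]`)
            `(token, attribution)` pairs for the input prompt.
        normalized_token_attributions (`Optional[List[Tuple[str, float]]]`)
            `(token, attribution_percentage)` pairs, normalized to sum to 100.
    """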
    image: Union[Image, np.ndarray, torch.Tensor]
    nsfw_content_detected: List[bool]
    all_images_during_generation: Optional[GeneratedImages]
    token_attributions: Optional[List[Tuple[str, float]]] = None
    normalized_token_attributions: Optional[List[Tuple[str, float]]] = None

    def __getitem__(self, item):
        return getattr(self, item)

    def __setitem__(self, key, value):
        setattr(self, key, value)


class BasePipelineExplainer(ABC):
    def __init__(self, pipe: DiffusionPipeline, verbose: bool = True, gradient_checkpointing: bool = False) -> None:
        self.pipe = pipe
        self.verbose = verbose
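        # Merge with any pre-existing progress-bar config; setting 'disable' to
        # `not verbose` silences the pipeline's tqdm bars unless verbose output was requested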
        self.pipe._progress_bar_config = {
            **(getattr(self.pipe, '_progress_bar_config', {}) or {}),
            'disable': not verbose
        }
        self.gradient_checkpointing = gradient_checkpointing
        if self.gradient_checkpointing:
            self.gradient_checkpointing_enable()
@@ -32,18 +80,23 @@ def __call__(
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = 'pil',
        return_dict: bool = True,
        run_safety_checker: bool = False,
        n_last_diffusion_steps_to_consider_for_attributions: Optional[int] = None,
        get_images_for_all_inference_steps: bool = True
    ) -> Dict[str, Any]:
    ) -> PipelineExplainerOutput:
        # TODO: add description

        if attribution_method != 'grad_x_input':
            raise NotImplementedError("Only `attribution_method='grad_x_input'` is implemented for now")

        if isinstance(prompt, str):
            batch_size = 1  # TODO: make compatible with bigger batch sizes
        elif isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], str):
            batch_size = len(prompt)
            raise NotImplementedError("Passing a list of strings in `prompt` is not implemented yet.")
        else:
            raise ValueError(f"`prompt` has to be of type `str` but is {type(prompt)}")

@@ -73,57 +126,64 @@
            guidance_scale=guidance_scale,
            eta=eta,
            generator=generator,
            latents=latents,
            output_type=None,
            return_dict=return_dict,
            run_safety_checker=run_safety_checker,
            n_last_diffusion_steps_to_consider_for_attributions=n_last_diffusion_steps_to_consider_for_attributions,
            get_images_for_all_inference_steps=get_images_for_all_inference_steps
        )

        if output['nsfw_content_detected']:
        output = PipelineExplainerOutput(
            image=output.images[0], nsfw_content_detected=output.nsfw_content_detected,
            all_images_during_generation=output.all_images_during_generation
        )

        if output.nsfw_content_detected:
            raise Exception(
                "NSFW content was detected; it is not possible to provide an explanation. "
                "Try to set `run_safety_checker=False` if you really want to skip the NSFW safety check."
            )

        # Get primary attribution scores
        output['token_attributions'] = None
        output['normalized_token_attributions'] = None
        output.token_attributions = None
        output.normalized_token_attributions = None
        if calculate_attributions and attribution_method == 'grad_x_input':

            if self.verbose:
                print("Calculating token attributions... ", end='')

            token_attributions = gradient_x_inputs_attribution(
                pred_logits=output['sample'][0], input_embeds=text_embeddings,
                pred_logits=output.image, input_embeds=text_embeddings,
                explanation_2d_bounding_box=explanation_2d_bounding_box
            )
            token_attributions = token_attributions.detach().cpu().numpy()

            # remove special tokens
            assert len(token_attributions) == len(tokens)
            output['token_attributions'] = []
            output['normalized_token_attributions'] = []
            for sample_token_attributions, sample_tokens in zip(token_attributions, tokens):
                assert len(sample_token_attributions) == len(sample_tokens)
            output.token_attributions = []
            output.normalized_token_attributions = []
            for image_token_attributions, image_tokens in zip(token_attributions, tokens):
                assert len(image_token_attributions) == len(image_tokens)

                # Add token attributions
                output['token_attributions'].append([])
                for attr, token in zip(sample_token_attributions, sample_tokens):
                output.token_attributions.append([])
                for attr, token in zip(image_token_attributions, image_tokens):
                    if consider_special_tokens or token not in self.special_tokens_attributes:

                        if clean_token_prefixes_and_suffixes:
                            token = clean_token_from_prefixes_and_suffixes(token)

                        output['token_attributions'][-1].append(
                        output.token_attributions[-1].append(
                            (token, attr)
                        )

                # Add normalized
                total = sum([attr for _, attr in output['token_attributions'][-1]])
                output['normalized_token_attributions'].append(
                total = sum([attr for _, attr in output.token_attributions[-1]])
                output.normalized_token_attributions.append(
                    [
                        (token, round(100 * attr / total, 3))
                        for token, attr in output['token_attributions'][-1]
                        for token, attr in output.token_attributions[-1]
                    ]
                )

@@ -135,33 +195,36 @@ def __call__(

        if batch_size == 1:
            # squash batch dimension
            for k in ['sample', 'token_attributions', 'normalized_token_attributions']:
            for k in ['image', 'token_attributions', 'normalized_token_attributions']:
                if output[k] is not None:
                    output[k] = output[k][0]
            if output['all_samples_during_generation']:
                output['all_samples_during_generation'] = [b[0] for b in output['all_samples_during_generation']]
            if output.all_images_during_generation:
                output.all_images_during_generation = [b[0] for b in output.all_images_during_generation]

        else:
            raise NotImplementedError

        # convert to PIL Image if requested
        # also draw bounding box in the last image if requested
        if output['all_samples_during_generation'] or output_type == "pil":
            all_samples = GeneratedImages(
                all_generated_images=output['all_samples_during_generation'] or [output['sample']],
        if output.all_images_during_generation or output_type == "pil":
            all_images = GeneratedImages(
                all_generated_images=output.all_images_during_generation or [output.image],
                pipe=self.pipe,
                remove_batch_dimension=batch_size==1,
                prepare_image_slider=bool(output['all_samples_during_generation'])
                prepare_image_slider=bool(output.all_images_during_generation)
            )
            if output['all_samples_during_generation']:
                output['all_samples_during_generation'] = all_samples
                sample = output['all_samples_during_generation'][-1]
            if output.all_images_during_generation:
                output.all_images_during_generation = all_images
                image = output.all_images_during_generation[-1]
            else:
                sample = all_samples[-1]
                image = all_images[-1]

            if explanation_2d_bounding_box:
                draw = ImageDraw.Draw(sample)
                draw = ImageDraw.Draw(image)
                draw.rectangle(explanation_2d_bounding_box, outline="red")

            if output_type == "pil":
                output['sample'] = sample
                output.image = image

        return output

@@ -214,9 +277,11 @@ def _mimic_pipeline_call(
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = 'pil',
        return_dict: bool = True,
        run_safety_checker: bool = True,
        n_last_diffusion_steps_to_consider_for_attributions: Optional[int] = None,
        get_images_for_all_inference_steps: bool = False
    ) -> Dict[str, Any]:
    ) -> BaseMimicPipelineCallOutput:
        raise NotImplementedError
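Both new output classes keep dict-style access working through `__getitem__`/`__setitem__`, so code written against the old dict outputs largely keeps running. A small, self-contained sketch (dummy field values, purely to show the two access styles):

```python
from diffusers_interpret.explainer import PipelineExplainerOutput

# Dummy values for illustration; in real use these come from an explainer call.
output = PipelineExplainerOutput(
    image=None,
    nsfw_content_detected=[False],
    all_images_during_generation=None,
)

output['token_attributions'] = [('corgi', 1.0)]   # dict-style write -> setattr
assert output.token_attributions == [('corgi', 1.0)]
assert output['image'] is output.image            # dict-style read -> getattr
```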