Skip to content

Commit

Permalink
Draft: new malevich (#91)
Browse files Browse the repository at this point in the history
* new malevich

* add att mask tests

* move import cv2 into the method that requires it

* support hf versions

* fix changes

* rc ruclip

* rc ruclip

* rc ruclip

* fix image latents

* mal v2

* fix setup

* add ppl image

* update ppl scores
  • Loading branch information
shonenkov authored Jan 11, 2022
1 parent 7cce1f6 commit c336514
Show file tree
Hide file tree
Showing 18 changed files with 238 additions and 160 deletions.
2 changes: 1 addition & 1 deletion LICENSE.txt
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [2020] [sberbank-ai]
Copyright [2021] [sberbank-ai]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
27 changes: 10 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)

```
pip install rudalle==0.4.0
pip install rudalle==1.0.0
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) \
Expand All @@ -29,8 +29,9 @@ pip install rudalle==0.4.0

### generation by ruDALLE:
```python
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
import ruclip
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_ruclip
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan
from rudalle.utils import seed_everything

# prepare models:
Expand All @@ -41,25 +42,17 @@ vae = get_vae(dwt=True).to(device)

# pipeline utils:
realesrgan = get_realesrgan('x2', device=device)
ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
ruclip = ruclip.to(device)

text = 'изображение радуги на фоне ночного города'
clip, processor = ruclip.load('ruclip-vit-base-patch32-384', device=device)
clip_predictor = ruclip.Predictor(clip, processor, device, bs=8)
text = 'радуга на фоне ночного города'

seed_everything(42)
pil_images = []
scores = []
for top_k, top_p, images_num in [
(2048, 0.995, 3),
(1536, 0.99, 3),
(1024, 0.99, 3),
(1024, 0.98, 3),
(512, 0.97, 3),
(384, 0.96, 3),
(256, 0.95, 3),
(128, 0.95, 3),
(2048, 0.995, 24),
]:
_pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p)
_pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, bs=8, top_p=top_p)
pil_images += _pil_images
scores += _scores

Expand All @@ -68,7 +61,7 @@ show(pil_images, 6)
![](pics/malevich/rainbow-full.png)
### auto cherry-pick by ruCLIP:
```python
top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device=device, count=6)
top_images, clip_scores = cherry_pick_by_ruclip(pil_images, text, clip_predictor, count=6)
show(top_images, 3)
```
![](pics/malevich/rainbow-cherry-pick.png)
Expand Down
Binary file modified pics/malevich/rainbow-cherry-pick.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified pics/malevich/rainbow-full.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified pics/malevich/rainbow-super-resolution.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
taming-transformers==0.0.1
more_itertools~=8.10.0
more_itertools~=8.12.0
transformers~=4.10.2
youtokentome~=1.0.6
omegaconf>=2.0.0
Expand Down
7 changes: 2 additions & 5 deletions rudalle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,22 @@
from .dalle import get_rudalle_model
from .tokenizer import get_tokenizer
from .realesrgan import get_realesrgan
from .ruclip import get_ruclip
from .emojich_unet import get_emojich_unet
from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip, image_prompts
from . import vae, dalle, tokenizer, realesrgan, pipelines, image_prompts


__all__ = [
'get_vae',
'get_rudalle_model',
'get_tokenizer',
'get_realesrgan',
'get_ruclip',
'get_emojich_unet',
'vae',
'dalle',
'ruclip',
'tokenizer',
'realesrgan',
'pipelines',
'image_prompts',
]

__version__ = '0.4.0'
__version__ = '1.0.0'
28 changes: 27 additions & 1 deletion rudalle/dalle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,30 @@

MODELS = {
'Malevich': dict(
hf_version='v3',
description='◼️ Malevich is 1.3 billion params model from the family GPT3-like, '
'that uses Russian language and text+image multi-modality.',
model_params=dict(
num_layers=24,
hidden_size=2048,
num_attention_heads=16,
embedding_dropout_prob=0.1,
output_dropout_prob=0.1,
attention_dropout_prob=0.1,
image_tokens_per_dim=32,
text_seq_length=128,
cogview_sandwich_layernorm=True,
cogview_pb_relax=True,
vocab_size=16384 + 128,
image_vocab_size=8192,
),
repo_id='sberbank-ai/rudalle-Malevich',
filename='pytorch_model_v3.bin',
authors='SberAI, SberDevices',
full_description='', # TODO
),
'Malevich_v2': dict(
hf_version='v2',
description='◼️ Malevich is 1.3 billion params model from the family GPT3-like, '
'that uses Russian language and text+image multi-modality.',
model_params=dict(
Expand All @@ -32,6 +56,7 @@
full_description='', # TODO
),
'Emojich': dict(
hf_version='v2',
description='😋 Emojich is a 1.3 billion params model from the family GPT3-like, '
'it generates emoji-style images with the brain of ◾ Malevich.',
model_params=dict(
Expand All @@ -54,6 +79,7 @@
full_description='', # TODO
),
'Kandinsky': dict(
hf_version='v3',
description='Kandinsky is large 12 billion params model from the family GPT3-like, '
'that uses Russian language and text+image multi-modality.',
model_params=dict(
Expand All @@ -77,7 +103,7 @@
authors='SberAI, SberDevices',
full_description='', # TODO
),
'small': dict(
'dummy': dict(
description='',
model_params=dict(
num_layers=12,
Expand Down
14 changes: 9 additions & 5 deletions rudalle/dalle/image_attention.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-

import torch


Expand All @@ -26,7 +25,7 @@ def get_col_mask(text_tokens=256, image_tokens_per_dim=32, is_bool_mask=False):
return mask


def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11, is_bool_mask=False):
def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11, is_bool_mask=False, hf_version='v3'):
mask = _init_mask(text_tokens, image_tokens_per_dim, is_bool_mask=is_bool_mask)
shift = kernel // 2
for pos in range(text_tokens, mask.size(1)):
Expand All @@ -37,12 +36,17 @@ def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11, is_bool_m
col = pixel_id % image_tokens_per_dim
for r in range(-shift, shift+1):
for c in range(-shift, shift+1):
c_abs = (c + col) % image_tokens_per_dim
r_abs = (r + row) % image_tokens_per_dim
if hf_version == 'v2':
c_abs = (c + col) % image_tokens_per_dim
r_abs = (r + row) % image_tokens_per_dim
elif hf_version == 'v3':
c_abs = max(min(c + col, image_tokens_per_dim - 1), 0)
r_abs = max(min(r + row, image_tokens_per_dim - 1), 0)
else:
raise ValueError(f'Unknown hf_version: {hf_version}')
img[r_abs, c_abs] = 0.2
cell_id = r_abs * image_tokens_per_dim + c_abs
if text_tokens + cell_id > pos:
mask[text_tokens + cell_id, pos] = True if is_bool_mask else 1.0

img[row, col] = 1.0
return mask
26 changes: 19 additions & 7 deletions rudalle/dalle/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __init__(self,
cogview_layernorm_prescale=False,
custom_relax=False,
is_bool_mask=True,
mlp_activation='gelu_jit'):
mlp_activation='gelu_jit',
hf_version='v3'):
super(DalleModel, self).__init__()
self.device = device
self.image_tokens_per_dim = image_tokens_per_dim
Expand All @@ -38,6 +39,8 @@ def __init__(self,
self.vocab_size = vocab_size
self.loss_img_weight = loss_img_weight

self.hf_version = hf_version

init_method = init_method_normal(std=0.02)

self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
Expand Down Expand Up @@ -74,6 +77,7 @@ def __init__(self,
custom_relax=custom_relax,
mlp_activation=mlp_activation,
is_bool_mask=is_bool_mask,
hf_version=self.hf_version,
)

def get_param(self, item):
Expand Down Expand Up @@ -103,8 +107,9 @@ def forward(
text_range += (self.vocab_size - self.text_seq_length)
text_range = text_range.to(self.device)
text = torch.where(text == 0, text_range, text)
# some hardcode :)
text = F.pad(text, (1, 0), value=2)
if self.hf_version == 'v2':
# some hardcode :) -- for v2, left-pad the text with one extra token of id 2
text = F.pad(text, (1, 0), value=2)
text_pos = self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))
text_embeddings = self.text_embeddings(text) + text_pos
image_input_ids = input_ids[:, self.text_seq_length:]
Expand All @@ -115,9 +120,11 @@ def forward(
embeddings = torch.cat((text_embeddings, image_embeddings), dim=1)
else:
embeddings = text_embeddings
# some hardcode :)
if embeddings.shape[1] > self.total_seq_length:
embeddings = embeddings[:, :-1]

if self.hf_version == 'v2':
# some hardcode :) -- for v2, drop the last position if the sequence exceeds total_seq_length
if embeddings.shape[1] > self.total_seq_length:
embeddings = embeddings[:, :-1]

alpha = 0.1
embeddings = embeddings * alpha + embeddings.detach() * (1-alpha)
Expand All @@ -136,7 +143,12 @@ def forward(
logits = rearrange(logits, 'b n c -> b c n')

text_logits = logits[:, :self.vocab_size, :self.text_seq_length].contiguous().float()
image_logits = logits[:, self.vocab_size:, self.text_seq_length:].contiguous().float()
if self.hf_version == 'v3':
image_logits = logits[:, self.vocab_size:, self.text_seq_length:-1].contiguous().float()
elif self.hf_version == 'v2':
image_logits = logits[:, self.vocab_size:, self.text_seq_length:].contiguous().float()
else:
raise ValueError(f'Unknown hf_version: {self.hf_version}')

loss_text = F.cross_entropy(
text_logits,
Expand Down
8 changes: 5 additions & 3 deletions rudalle/dalle/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,16 @@ def __init__(self,
cogview_layernorm_prescale=False,
custom_relax=False,
mlp_activation='gelu_jit',
is_bool_mask=False):
is_bool_mask=False,
hf_version='v3'):
super(DalleTransformer, self).__init__()

self.num_layers = num_layers
# CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
self.cogview_pb_relax = cogview_pb_relax
# Additional stabilization tweak for large models
self.custom_relax = custom_relax

self.hf_version = hf_version
# Transformer layers.
self.layers = torch.nn.ModuleList([
DalleTransformerLayer(
Expand All @@ -112,7 +113,8 @@ def __init__(self,

row_mask = get_row_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask)
col_mask = get_col_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask)
conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask)
conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim, is_bool_mask=is_bool_mask,
hf_version=self.hf_version)
self.register_buffer('row_mask', row_mask)
self.register_buffer('col_mask', col_mask)
self.register_buffer('conv_mask', conv_mask)
Expand Down
Loading

0 comments on commit c336514

Please sign in to comment.