Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unet v2 #75

Merged
merged 1 commit into from
Dec 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Emojich.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ from rudalle import get_emojich_unet

device = 'cuda'
emojich_unet = get_emojich_unet('unet_effnetb7').to(device)
rgba_images = convert_emoji_to_rgba(sr_images, emojich_unet, device=device)
rgba_images, _ = convert_emoji_to_rgba(sr_images, emojich_unet, device=device)
for rgba_image in rgba_images:
show_rgba(rgba_image);
```
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)

```
pip install rudalle==0.3.0
pip install rudalle==0.4.0
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) \
Expand Down
2 changes: 1 addition & 1 deletion rudalle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@
'image_prompts',
]

__version__ = '0.3.0'
__version__ = '0.4.0'
10 changes: 3 additions & 7 deletions rudalle/emojich_unet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@
'unet_effnetb5': dict(
encoder_name='efficientnet-b5',
repo_id='sberbank-ai/rudalle-Emojich',
filename='pytorch_model.bin',
),
'unet_effnetb7': dict(
encoder_name='efficientnet-b7',
repo_id='sberbank-ai/rudalle-Emojich',
filename='pytorch_model.bin',
filename='pytorch_model_v2.bin',
classes=2,
),
}

Expand All @@ -33,7 +29,7 @@ def get_emojich_unet(name, cache_dir='/tmp/rudalle'):
encoder_name=config['encoder_name'],
encoder_weights=None,
in_channels=3,
classes=1,
classes=config['classes'],
)
cache_dir = os.path.join(cache_dir, name)
filename = config['filename']
Expand Down
44 changes: 33 additions & 11 deletions rudalle/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,25 @@ def show(pil_images, nrow=4, size=14, save_dir=None, show=True):
plt.show()


def convert_emoji_to_rgba(pil_images, emojich_unet, device='cpu', bs=1):
final_images = []
def classic_convert_emoji_to_rgba(np_image, lower_thr=240, upper_thr=255, width=2):
    """Build an RGBA image whose alpha channel masks out near-white background.

    Pixels whose three colour channels all fall inside
    ``[lower_thr, upper_thr]`` are treated as background candidates. The
    contours of that background mask are extracted, the seven contours with
    the most points are kept, and the regions they enclose (plus an outline
    of thickness ``width``) are made fully transparent. Everything else stays
    fully opaque.

    Args:
        np_image: H x W x C array with at least three channels; only the
            first three are used. Assumed to be ``uint8`` RGB, since the
            threshold bounds are built as ``uint8`` (confirm against callers).
        lower_thr: Lower per-channel bound of the background colour range.
        upper_thr: Upper per-channel bound of the background colour range.
        width: Thickness passed to ``cv2.drawContours`` for the transparent
            outline drawn around each kept contour.

    Returns:
        H x W x 4 array: the RGB input converted with ``COLOR_RGB2RGBA`` and
        its fourth channel replaced by the computed alpha mask (255 = opaque,
        0 = transparent).
    """
    img = np_image[:, :, :3].copy()
    lower = np.array([lower_thr] * 3, dtype='uint8')
    upper = np.array([upper_thr] * 3, dtype='uint8')
    mask = cv2.inRange(img, lower, upper)
    _, thresh = cv2.threshold(mask, 0, 255, cv2.THRESH_BINARY)
    # findContours returns (contours, hierarchy) on OpenCV 4.x but
    # (image, contours, hierarchy) on 3.x; [-2] selects contours on both.
    contours = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)[-2]

    # Size the alpha plane from the input rather than assuming 512x512, so the
    # channel assignment below works for any image size.
    a_channel = np.full(img.shape[:2], 255, dtype=np.uint8)
    if len(contours) != 0:
        # Keep only the seven contours with the most points.
        contours = sorted(contours, key=lambda contour: contour.shape[0])[-7:]
        transparent = (0, 0, 0)
        cv2.fillPoly(a_channel, contours, transparent)
        cv2.drawContours(a_channel, contours, -1, transparent, width)

    img = cv2.cvtColor(img, cv2.COLOR_RGB2RGBA)
    img[:, :, 3] = a_channel
    return img


def convert_emoji_to_rgba(pil_images, emojich_unet, device='cpu', bs=1, score_thr=0.99):
final_images, runs = [], []
with torch.no_grad():
for chunk in more_itertools.chunked(pil_images, bs):
images = []
Expand All @@ -129,20 +146,25 @@ def convert_emoji_to_rgba(pil_images, emojich_unet, device='cpu', bs=1):
image = torch.from_numpy(image).permute(2, 0, 1)
images.append(image)
images = torch.nn.utils.rnn.pad_sequence(images, batch_first=True)
pred_masks = emojich_unet(images.to(device))[:, 0, :, :]
pred_masks = torch.sigmoid(pred_masks)
pred_masks = (pred_masks > 0.5).int().cpu().numpy()
pred_masks = emojich_unet(images.to(device))
pred_masks = torch.softmax(pred_masks, 1)
scores, pred_masks = torch.max(pred_masks, 1)
pred_masks = pred_masks.int().cpu().numpy()
pred_masks = (pred_masks * 255).astype(np.uint8)
for pil_image, pred_mask in zip(chunk, pred_masks):
ret, thresh = cv2.threshold(pred_mask, 0, 255, 0)
contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
cv2.drawContours(pred_mask, contours, -1, (0, 0, 0), 1)
for pil_image, pred_mask, score in zip(chunk, pred_masks, scores):
score = score.mean().item()
final_image = np.zeros((512, 512, 4), np.uint8)
final_image[:, :, :3] = np.array(pil_image.resize((512, 512)))[:, :, :3]
final_image[:, :, -1] = pred_mask
if score > score_thr:
run = 'unet'
final_image[:, :, -1] = pred_mask
else:
run = 'classic'
final_image = classic_convert_emoji_to_rgba(final_image)
final_image = Image.fromarray(final_image)
final_images.append(final_image)
return final_images
runs.append(run)
return final_images, runs


def show_rgba(rgba_pil_image):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_emojich_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
def test_convert_emoji_to_rgba(sample_image, emojich_unet):
img = sample_image.copy()
img = img.resize((512, 512))
rgba_img = convert_emoji_to_rgba([img], emojich_unet)[0]
rgba_images, runs = convert_emoji_to_rgba([img], emojich_unet, score_thr=0.99)
assert len(runs) == len(rgba_images)
rgba_img = rgba_images[0]
assert rgba_img.size[0] == 512
assert rgba_img.size[1] == 512
assert np.array(rgba_img).shape[-1] == 4
assert runs[0] in ['unet', 'classic']