diff --git a/Emojich.md b/Emojich.md
index 1b4a39f..abd8e7d 100644
--- a/Emojich.md
+++ b/Emojich.md
@@ -68,7 +68,7 @@ from rudalle import get_emojich_unet
 
 device = 'cuda'
 emojich_unet = get_emojich_unet('unet_effnetb7').to(device)
-rgba_images = convert_emoji_to_rgba(sr_images, emojich_unet, device=device)
+rgba_images, _ = convert_emoji_to_rgba(sr_images, emojich_unet, device=device)
 for rgba_image in rgba_images:
     show_rgba(rgba_image);
 ```
diff --git a/README.md b/README.md
index fafc20d..70272ad 100644
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
 [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)
 
 ```
-pip install rudalle==0.3.0
+pip install rudalle==0.4.0
 ```
 
 ### 🤗 HF Models: [ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich) \
diff --git a/rudalle/__init__.py b/rudalle/__init__.py
index 6149751..c706022 100644
--- a/rudalle/__init__.py
+++ b/rudalle/__init__.py
@@ -24,4 +24,4 @@
     'image_prompts',
 ]
 
-__version__ = '0.3.0'
+__version__ = '0.4.0'
diff --git a/rudalle/emojich_unet/__init__.py b/rudalle/emojich_unet/__init__.py
index cbd4818..6e79361 100644
--- a/rudalle/emojich_unet/__init__.py
+++ b/rudalle/emojich_unet/__init__.py
@@ -9,12 +9,8 @@
     'unet_effnetb5': dict(
         encoder_name='efficientnet-b5',
         repo_id='sberbank-ai/rudalle-Emojich',
-        filename='pytorch_model.bin',
-    ),
-    'unet_effnetb7': dict(
-        encoder_name='efficientnet-b7',
-        repo_id='sberbank-ai/rudalle-Emojich',
-        filename='pytorch_model.bin',
+        filename='pytorch_model_v2.bin',
+        classes=2,
     ),
 }
 
@@ -33,7 +29,7 @@ def get_emojich_unet(name, cache_dir='/tmp/rudalle'):
         encoder_name=config['encoder_name'],
         encoder_weights=None,
         in_channels=3,
-        classes=1,
+        classes=config['classes'],
     )
     cache_dir = os.path.join(cache_dir, name)
     filename = config['filename']
diff --git a/rudalle/pipelines.py b/rudalle/pipelines.py
index c54266d..d86544e 100644
--- a/rudalle/pipelines.py
+++ b/rudalle/pipelines.py
@@ -118,8 +118,25 @@ def show(pil_images, nrow=4, size=14, save_dir=None, show=True):
         plt.show()
 
 
-def convert_emoji_to_rgba(pil_images, emojich_unet, device='cpu', bs=1):
-    final_images = []
+def classic_convert_emoji_to_rgba(np_image, lower_thr=240, upper_thr=255, width=2):
+    img = np_image[:, :, :3].copy()
+    lower = np.array([lower_thr, lower_thr, lower_thr], dtype='uint8')
+    upper = np.array([upper_thr, upper_thr, upper_thr], dtype='uint8')
+    mask = cv2.inRange(img, lower, upper)
+    ret, thresh = cv2.threshold(mask, 0, 255, 0)
+    contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+    a_channel = np.ones((512, 512), dtype=np.uint8)*255
+    if len(contours) != 0:
+        contours = sorted(contours, key=lambda x: x.shape[0])[-7:]
+        cv2.fillPoly(a_channel, contours, (0, 0, 0))
+        cv2.drawContours(a_channel, contours, -1, (0, 0, 0), width)
+    img = cv2.cvtColor(img, cv2.COLOR_RGB2RGBA)
+    img[:, :, 3] = a_channel
+    return img
+
+
+def convert_emoji_to_rgba(pil_images, emojich_unet, device='cpu', bs=1, score_thr=0.99):
+    final_images, runs = [], []
     with torch.no_grad():
         for chunk in more_itertools.chunked(pil_images, bs):
             images = []
@@ -129,20 +146,25 @@ def convert_emoji_to_rgba(pil_images, emojich_unet, device='cpu', bs=1):
                 image = torch.from_numpy(image).permute(2, 0, 1)
                 images.append(image)
             images = torch.nn.utils.rnn.pad_sequence(images, batch_first=True)
-            pred_masks = emojich_unet(images.to(device))[:, 0, :, :]
-            pred_masks = torch.sigmoid(pred_masks)
-            pred_masks = (pred_masks > 0.5).int().cpu().numpy()
+            pred_masks = emojich_unet(images.to(device))
+            pred_masks = torch.softmax(pred_masks, 1)
+            scores, pred_masks = torch.max(pred_masks, 1)
+            pred_masks = pred_masks.int().cpu().numpy()
             pred_masks = (pred_masks * 255).astype(np.uint8)
-            for pil_image, pred_mask in zip(chunk, pred_masks):
-                ret, thresh = cv2.threshold(pred_mask, 0, 255, 0)
-                contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
-                cv2.drawContours(pred_mask, contours, -1, (0, 0, 0), 1)
+            for pil_image, pred_mask, score in zip(chunk, pred_masks, scores):
+                score = score.mean().item()
                 final_image = np.zeros((512, 512, 4), np.uint8)
                 final_image[:, :, :3] = np.array(pil_image.resize((512, 512)))[:, :, :3]
-                final_image[:, :, -1] = pred_mask
+                if score > score_thr:
+                    run = 'unet'
+                    final_image[:, :, -1] = pred_mask
+                else:
+                    run = 'classic'
+                    final_image = classic_convert_emoji_to_rgba(final_image)
                 final_image = Image.fromarray(final_image)
                 final_images.append(final_image)
-    return final_images
+                runs.append(run)
+    return final_images, runs
 
 
 def show_rgba(rgba_pil_image):
diff --git a/tests/test_emojich_unet.py b/tests/test_emojich_unet.py
index 310860d..b649df7 100644
--- a/tests/test_emojich_unet.py
+++ b/tests/test_emojich_unet.py
@@ -7,7 +7,10 @@
 def test_convert_emoji_to_rgba(sample_image, emojich_unet):
     img = sample_image.copy()
     img = img.resize((512, 512))
-    rgba_img = convert_emoji_to_rgba([img], emojich_unet)[0]
+    rgba_images, runs = convert_emoji_to_rgba([img], emojich_unet, score_thr=0.99)
+    assert len(runs) == len(rgba_images)
+    rgba_img = rgba_images[0]
     assert rgba_img.size[0] == 512
     assert rgba_img.size[1] == 512
     assert np.array(rgba_img).shape[-1] == 4
+    assert runs[0] in ['unet', 'classic']
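A minimal usage sketch of the API after this patch (not part of the diff above): it assumes `sr_images` is a list of PIL images produced by the super-resolution step shown in Emojich.md, and uses `unet_effnetb5`, the only checkpoint left in the Emojich-UNet config after this change.

```python
# Sketch only: `sr_images` (a list of PIL.Image emoji candidates) is assumed to
# come from the earlier super-resolution step described in Emojich.md.
from rudalle import get_emojich_unet
from rudalle.pipelines import convert_emoji_to_rgba, show_rgba

device = 'cuda'
emojich_unet = get_emojich_unet('unet_effnetb5').to(device)

# convert_emoji_to_rgba now returns a tuple: the RGBA images plus a `runs` list,
# where runs[i] is 'unet' if the softmax confidence cleared score_thr and
# 'classic' if the contour-based fallback was applied instead.
rgba_images, runs = convert_emoji_to_rgba(sr_images, emojich_unet, device=device, score_thr=0.99)
for rgba_image, run in zip(rgba_images, runs):
    print(run)
    show_rgba(rgba_image)
```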