Commit
Merge pull request #108 from nakajima-john-shotaro/develop/develop
Develop/develop
urasakikeisuke authored Oct 3, 2021
2 parents 93a07f3 + f58ccf4 commit c48dcf7
Showing 10 changed files with 919 additions and 899 deletions.
2 changes: 1 addition & 1 deletion aicon/backend/constant.py
@@ -49,7 +49,7 @@
BACKBONE_NAME_RN50x4: str = "RN50x4"
BACKBONE_NAME_ViTB32: str = "ViT-B32"

PRETRAINED_BACKBONE_MODEL_PATH: str = "~/.cache"
PRETRAINED_BACKBONE_MODEL_PATH: str = "/home/user/.cache"

ACCEPTABLE_BACKBONE: List[str] = [BACKBONE_NAME_RN50, BACKBONE_NAME_RN101, BACKBONE_NAME_RN50x4, BACKBONE_NAME_ViTB32]

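The old value relied on "~", which Python's file APIs treat as a literal path component unless it is passed through os.path.expanduser; the new value is an absolute path. A minimal sketch of the difference, assuming nothing beyond the standard library:

    import os

    # "~" is only a home-directory shortcut after os.path.expanduser() is applied;
    # open(), os.makedirs(), etc. treat it as a literal directory name.
    old_style = "~/.cache"
    print(os.path.expanduser(old_style))   # e.g. /home/user/.cache
    print("/home/user/.cache")             # absolute path used after this commit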
46 changes: 19 additions & 27 deletions aicon/backend/models/big_sleep/big_sleep.py
@@ -1,3 +1,5 @@
from base64 import b64decode
from io import BytesIO
import os
import sys
from queue import Empty
@@ -20,7 +22,6 @@
from imageio import get_writer
from PIL import Image
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torchvision.transforms.transforms import Compose

@@ -39,17 +40,12 @@ def exists(val: Any) -> bool:
return val is not None


def create_text_path(text: Optional[str] = None, img: Optional[str] = None) -> str:
def create_text_path(text: Optional[str] = None) -> str:
input_name: str = ""

if text is not None:
input_name += text

if img is not None:
img_name: str = "".join(img.split(".")[:-1])
img_name = img_name.split("/")[-1]
input_name += "_" + img_name

return input_name.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:255]
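
With the image argument dropped, the helper now derives the output file name from the text prompt alone. A small illustration of the sanitisation it applies (the prompt string below is made up for the example):

    # Same replacements as the return statement above, applied to a sample prompt.
    text = "a cosy cabin, in the snow|night"
    name = text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:255]
    print(name)  # a_cosy_cabin_in_the_snow--night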

# tensor helpers
@@ -147,7 +143,7 @@ def __init__(

assert image_width in (128, 256, 512), 'image size must be one of 128, 256, or 512'

self.biggan = BigGAN.from_pretrained(f'biggan-deep-{image_width}')
self.biggan = BigGAN.from_pretrained(f'biggan-deep-{image_width}', cache_dir=PRETRAINED_BACKBONE_MODEL_PATH)
self.max_classes = max_classes
self.class_temperature = class_temperature
self.ema_decay\
@@ -236,8 +232,7 @@ def forward(self, text_embeds, stick_embeds=[], return_loss = True):
into = torch.cat(pieces)
into = self.normalize_image(into)

with autocast(enabled=False):
image_embed = self.perceptor.encode_image(into)
image_embed = self.perceptor.encode_image(into)

latents, soft_one_hot_classes = self.model.latents()
num_latents = latents.shape[0]
@@ -306,8 +301,13 @@ def __init__(
iterations: int = int(self.client_data[RECEIVED_DATA][JSON_TOTAL_ITER])
gradient_accumulate_every: int = int(self.client_data[RECEIVED_DATA][JSON_GAE])
model_name: str = self.client_data[RECEIVED_DATA][JSON_BACKBONE]
if self.client_data[RECEIVED_DATA][JSON_SOURCE_IMG] is not None:
source_img: Image = Image.open(BytesIO(b64decode((self.client_data[RECEIVED_DATA][JSON_SOURCE_IMG])))).convert('RGB')
else:
source_img = None

self.c2i_queue: Queue = self.client_data[CORE_C2I_QUEUE]
self.c2i_brake_queue: Queue = self.client_data[CORE_C2I_BREAK_QUEUE]
self.c2i_event: Event_ = self.client_data[CORE_C2I_EVENT]
self.i2c_event: Event_ = self.client_data[CORE_I2C_EVENT]
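
The constructor now expects the optional source image to arrive as a base64-encoded payload in the request JSON and decodes it into a PIL image. A self-contained sketch of that round trip (the red test image is generated here purely for illustration):

    from base64 import b64decode, b64encode
    from io import BytesIO

    from PIL import Image

    # Build a small in-memory PNG and encode it the way a client request would.
    buffer = BytesIO()
    Image.new("RGB", (64, 64), color=(255, 0, 0)).save(buffer, format="PNG")
    payload = b64encode(buffer.getvalue())

    # Decode it back the same way the constructor above does.
    source_img = Image.open(BytesIO(b64decode(payload))).convert("RGB")
    print(source_img.size)  # (64, 64)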

@@ -359,8 +359,6 @@ def __init__(
clip_model=clip_model,
).cuda()

self.scaler: GradScaler = GradScaler()

self.model: BigSleep = model

self.lr: float = lr
@@ -377,7 +375,7 @@ def __init__(
self.clip_transform = create_clip_img_transform(224)

# create starting encoding
self.set_clip_encoding(text=text, img=img, stick=stick)
self.set_clip_encoding(text=text, img=source_img, stick=stick)

def set_text(self, text: Optional[str] = None) -> None:
self.set_clip_encoding(text=text)
@@ -428,7 +426,7 @@ def set_clip_encoding(self, text: Optional[str] = None, img: Optional[str] = Non

if len(stick) > 0:
text = text + "_wout_" + stick[:255] if text is not None else "wout_" + stick[:255]
text_path = create_text_path(text=text, img=img)
text_path = create_text_path(text=text)

self.text_path = text_path
self.filename = Path(f'./{text_path}.png')
@@ -456,22 +454,15 @@ def get_img_sequence_number(self, epoch: int, iteration: int) -> int:
return sequence_number

def train_step(self, epoch: int, iteration: int) -> None:
total_loss: float = 0
total_loss = 0

for _ in range(self.gradient_accumulate_every):
with autocast(enabled=True):
loss: torch.Tensor
_, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])

loss: torch.Tensor = sum(losses) / self.gradient_accumulate_every
self.scaler.scale(loss).backward()

total_loss = total_loss + loss.item()

del loss
_, losses = self.model(self.encoded_texts["max"], self.encoded_texts["min"])
loss = sum(losses) / self.gradient_accumulate_every
total_loss += loss
loss.backward()

self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.step()
self.model.model.latents.update()
self.optimizer.zero_grad(set_to_none=True)
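
train_step no longer wraps the forward pass in autocast or routes the backward pass through GradScaler; the optimizer is stepped directly in full precision. For reference, a minimal sketch of both patterns on a generic model (model, optimizer, and batch are placeholders, not objects from this repository):

    from torch.cuda.amp import GradScaler, autocast

    def fp32_step(model, optimizer, batch):
        # Pattern used after this commit: plain full-precision update.
        loss = model(batch).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        return loss.item()

    def amp_step(model, optimizer, scaler: GradScaler, batch):
        # Pattern removed by this commit: autocast forward pass + scaled backward pass.
        with autocast(enabled=True):
            loss = model(batch).mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        return loss.item()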

@@ -545,6 +536,7 @@ def forward(self) -> None:
self.put_data[JSON_IMG_PATH] = str(self.response_filename)

self.c2i_queue.put_nowait(self.put_data)
self.c2i_brake_queue.put_nowait(self.put_data)
self.c2i_event.set()

torch.cuda.empty_cache()
59 changes: 23 additions & 36 deletions aicon/backend/models/big_sleep/biggan.py
@@ -1,52 +1,51 @@
# this code is a copy from huggingface
# with some minor modifications

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import json
import copy
import json
import logging
import math
import os
import shutil
import sys
import tempfile
from functools import wraps
from hashlib import sha256
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from io import open

import boto3
import requests
from botocore.exceptions import ClientError
from constant import *
from tqdm import tqdm
from urllib.parse import urlparse

try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse

try:
from pathlib import Path
PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
Path.home() / '.pytorch_pretrained_biggan'))
PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(PRETRAINED_BACKBONE_MODEL_PATH)
except (AttributeError, ImportError):
PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan'))
PYTORCH_PRETRAINED_BIGGAN_CACHE = PRETRAINED_BACKBONE_MODEL_PATH

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


PRETRAINED_MODEL_ARCHIVE_MAP = {
'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin",
'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin",
'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
'biggan-deep-128': f"{PRETRAINED_BACKBONE_MODEL_PATH}/biggan-deep-128-pytorch_model.bin",
'biggan-deep-256': f"{PRETRAINED_BACKBONE_MODEL_PATH}/biggan-deep-256-pytorch_model.bin",
'biggan-deep-512': f"{PRETRAINED_BACKBONE_MODEL_PATH}/biggan-deep-512-pytorch_model.bin",
}

PRETRAINED_CONFIG_ARCHIVE_MAP = {
'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
'biggan-deep-128': f"{PRETRAINED_BACKBONE_MODEL_PATH}/biggan-deep-128-config.json",
'biggan-deep-256': f"{PRETRAINED_BACKBONE_MODEL_PATH}/biggan-deep-256-config.json",
'biggan-deep-512': f"{PRETRAINED_BACKBONE_MODEL_PATH}/biggan-deep-512-config.json",
}

WEIGHTS_NAME = 'pytorch_model.bin'
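
With the archive maps now pointing into PRETRAINED_BACKBONE_MODEL_PATH, the weight and config files are expected to already exist on disk rather than being downloaded from S3. Under that assumption, the resolved paths look like this (the cache value mirrors constant.py):

    import os

    PRETRAINED_BACKBONE_MODEL_PATH = "/home/user/.cache"  # value from constant.py

    # Files from the maps above, expected to be present locally before start-up.
    model_file = os.path.join(PRETRAINED_BACKBONE_MODEL_PATH, "biggan-deep-512-pytorch_model.bin")
    config_file = os.path.join(PRETRAINED_BACKBONE_MODEL_PATH, "biggan-deep-512-config.json")
    print(model_file)
    print(config_file)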
@@ -219,7 +218,7 @@ def get_from_cache(url, cache_dir=None):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
logger.warning("%s not found in cache, downloading to %s", url, temp_file.name)

# GET file object
if url.startswith("s3://"):
@@ -536,22 +535,10 @@ class BigGAN(nn.Module):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
resolved_model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
resolved_config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]

try:
resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error("Wrong model name, should be a valid path to a folder containing "
"a {} file and a {} file or a model name in {}".format(
WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
raise

logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))
logger.info("\n\nloading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))

# Load config
config = BigGANConfig.from_json_file(resolved_config_file)
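After this simplification, from_pretrained only resolves names listed in PRETRAINED_MODEL_ARCHIVE_MAP and no longer falls back to an arbitrary folder. A usage sketch mirroring the call site in big_sleep.py (the import path is an assumption based on the repository layout, and the weight and config files must already be in the cache directory):

    from biggan import BigGAN

    model = BigGAN.from_pretrained("biggan-deep-512", cache_dir="/home/user/.cache")
    model.eval()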