Revert "Replace most print()s with logging calls (Stability-AI#42)" (S…
Browse files Browse the repository at this point in the history
…tability-AI#65)

This reverts commit 6f6d3f8.
Jonas Müller authored Jul 26, 2023

Verified: this commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 7934245 commit 4a3f0f5
Showing 10 changed files with 91 additions and 117 deletions.
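Every hunk below follows the same pattern: the revert strips the per-module logging setup that Stability-AI#42 had introduced (an `import logging` plus `logger = logging.getLogger(__name__)` at module scope, with `logger.debug/info/warning` calls) and restores the original bare `print()` statements. For reference, a minimal sketch of the two styles; the logger pattern and the "Preparing datasets" message are taken from the diff itself, while the two helper function names are hypothetical and the sketch is not part of the commit:

import logging

# Pattern removed by this revert (introduced in Stability-AI#42): one logger per module.
logger = logging.getLogger(__name__)


def setup_with_logging() -> None:
    # Hypothetical helper. Emits nothing unless the application configures
    # logging, e.g. logging.basicConfig(level=logging.DEBUG).
    logger.debug("Preparing datasets")


def setup_with_print() -> None:
    # Hypothetical helper. Pattern restored by this revert: always writes to stdout.
    print("Preparing datasets")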
26 changes: 13 additions & 13 deletions sgm/data/dataset.py
@@ -1,22 +1,20 @@
-import logging
 from typing import Optional
 
 import torchdata.datapipes.iter
 import webdataset as wds
 from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule
 
-logger = logging.getLogger(__name__)
-
 try:
     from sdata import create_dataset, create_dummy_dataset, create_loader
 except ImportError as e:
-    raise NotImplementedError(
-        "Datasets not yet available. "
-        "To enable, we need to add stable-datasets as a submodule; "
-        "please use ``git submodule update --init --recursive`` "
-        "and do ``pip install -e stable-datasets/`` from the root of this repo"
-    ) from e
+    print("#" * 100)
+    print("Datasets not yet available")
+    print("to enable, we need to add stable-datasets as a submodule")
+    print("please use ``git submodule update --init --recursive``")
+    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
+    print("#" * 100)
+    exit(1)
 
 
 class StableDataModuleFromConfig(LightningDataModule):
@@ -41,8 +39,8 @@ def __init__(
                 "datapipeline" in self.val_config and "loader" in self.val_config
             ), "validation config requires the fields `datapipeline` and `loader`"
         else:
-            logger.warning(
-                "No Validation datapipeline defined, using that one from training"
+            print(
+                "Warning: No Validation datapipeline defined, using that one from training"
             )
             self.val_config = train
 
@@ -54,10 +52,12 @@ def __init__(
 
         self.dummy = dummy
         if self.dummy:
-            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
+            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
 
     def setup(self, stage: str) -> None:
-        logger.debug("Preparing datasets")
+        print("Preparing datasets")
         if self.dummy:
             data_fn = create_dummy_dataset
         else:
31 changes: 15 additions & 16 deletions sgm/lr_scheduler.py
@@ -1,9 +1,5 @@
-import logging
-
 import numpy as np
 
-logger = logging.getLogger(__name__)
-
 
 class LambdaWarmUpCosineScheduler:
     """
@@ -28,8 +24,9 @@ def __init__(
         self.verbosity_interval = verbosity_interval
 
     def schedule(self, n, **kwargs):
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
         if n < self.lr_warm_up_steps:
             lr = (
                 self.lr_max - self.lr_start
@@ -86,11 +83,12 @@ def find_in_interval(self, n):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                 cycle
@@ -116,11 +114,12 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )
 
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
19 changes: 8 additions & 11 deletions sgm/models/autoencoder.py
@@ -1,4 +1,3 @@
-import logging
 import re
 from abc import abstractmethod
 from contextlib import contextmanager
@@ -15,8 +14,6 @@
 from ..modules.ema import LitEma
 from ..util import default, get_obj_from_str, instantiate_from_config
 
-logger = logging.getLogger(__name__)
-
 
 class AbstractAutoencoder(pl.LightningModule):
     """
@@ -41,7 +38,7 @@ def __init__(
 
         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -63,16 +60,16 @@ def init_from_ckpt(
         for k in keys:
             for ik in ignore_keys:
                 if re.match(ik, k):
-                    logger.debug(f"Deleting key {k} from state_dict.")
+                    print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.debug(
+        print(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")
 
     @abstractmethod
     def get_input(self, batch) -> Any:
@@ -89,14 +86,14 @@ def ema_scope(self, context=None):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")
 
     @abstractmethod
     def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -107,7 +104,7 @@ def decode(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("decode()-method of abstract base class called")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
-        logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
+        print(f"loading >>> {cfg['target']} <<< optimizer from config")
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )
17 changes: 7 additions & 10 deletions sgm/models/diffusion.py
@@ -1,4 +1,3 @@
-import logging
 from contextlib import contextmanager
 from typing import Any, Dict, List, Tuple, Union
 
@@ -19,8 +18,6 @@
     log_txt_as_img,
 )
 
-logger = logging.getLogger(__name__)
-
 
 class DiffusionEngine(pl.LightningModule):
     def __init__(
@@ -76,7 +73,7 @@ def __init__(
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model, decay=ema_decay_rate)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         self.scale_factor = scale_factor
         self.disable_first_stage_autocast = disable_first_stage_autocast
@@ -97,13 +94,13 @@ def init_from_ckpt(
             raise NotImplementedError
 
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.info(
+        print(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")
 
     def _init_first_stage(self, config):
         model = instantiate_from_config(config).eval()
@@ -182,14 +179,14 @@ def ema_scope(self, context=None):
             self.model_ema.store(self.model.parameters())
             self.model_ema.copy_to(self.model)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.model.parameters())
                 if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
         return get_obj_from_str(cfg["target"])(
@@ -205,7 +202,7 @@ def configure_optimizers(self):
         opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
         if self.scheduler_config is not None:
             scheduler = instantiate_from_config(self.scheduler_config)
-            logger.debug("Setting up LambdaLR scheduler...")
+            print("Setting up LambdaLR scheduler...")
             scheduler = [
                 {
                     "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
38 changes: 17 additions & 21 deletions sgm/modules/attention.py
@@ -1,4 +1,3 @@
-import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional
@@ -9,10 +8,6 @@
 from packaging import version
 from torch import nn
 
-
-logger = logging.getLogger(__name__)
-
-
 if version.parse(torch.__version__) >= version.parse("2.0.0"):
     SDP_IS_AVAILABLE = True
     from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -41,9 +36,9 @@
     SDP_IS_AVAILABLE = False
     sdp_kernel = nullcontext
     BACKEND_MAP = {}
-    logger.warning(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
-        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+    print(
+        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
     )
 
 try:
@@ -53,7 +48,7 @@
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")
 
 from .diffusionmodules.util import checkpoint
 
@@ -294,7 +289,7 @@ def __init__(
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
     ):
         super().__init__()
-        logger.info(
+        print(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
             f"{heads} heads with a dimension of {dim_head}."
         )
@@ -398,21 +393,22 @@ def __init__(
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            logger.warning(
+            print(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            logger.warning(
+            print(
                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
             )
             if not XFORMERS_IS_AVAILABLE:
-                raise NotImplementedError(
-                    "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-                )
-            logger.info("Falling back to xformers efficient attention.")
-            attn_mode = "softmax-xformers"
+                assert (
+                    False
+                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+            else:
+                print("Falling back to xformers efficient attention.")
+                attn_mode = "softmax-xformers"
         attn_cls = self.ATTENTION_MODES[attn_mode]
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
@@ -441,7 +437,7 @@ def __init__(
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
         if self.checkpoint:
-            logger.info(f"{self.__class__.__name__} is using checkpointing")
+            print(f"{self.__class__.__name__} is using checkpointing")
 
     def forward(
         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -558,7 +554,7 @@ def __init__(
         sdp_backend=None,
     ):
         super().__init__()
-        logger.debug(
+        print(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
         from omegaconf import ListConfig
@@ -567,8 +563,8 @@ def __init__(
             context_dim = [context_dim]
         if exists(context_dim) and isinstance(context_dim, list):
             if depth != len(context_dim):
-                logger.warning(
-                    f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                print(
+                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
                 # depth does not match context dims.
6 changes: 1 addition & 5 deletions sgm/modules/autoencoding/losses/__init__.py
@@ -1,4 +1,3 @@
-import logging
 from typing import Any, Union
 
 import torch
@@ -11,9 +10,6 @@
 from ....util import default, instantiate_from_config
 
 
-logger = logging.getLogger(__name__)
-
-
 def adopt_weight(weight, global_step, threshold=0, value=0.0):
     if global_step < threshold:
         weight = value
@@ -108,7 +104,7 @@ def __init__(
         super().__init__()
         self.dims = dims
         if self.dims > 2:
-            logger.info(
+            print(
                 f"running with dims={dims}. This means that for perceptual loss calculation, "
                 f"the LPIPS loss will be applied to each frame independently. "
             )
17 changes: 7 additions & 10 deletions sgm/modules/diffusionmodules/model.py
@@ -1,5 +1,4 @@
 # pytorch_diffusion + derived encoder decoder
-import logging
 import math
 from typing import Any, Callable, Optional
 
@@ -9,16 +8,14 @@
 from einops import rearrange
 from packaging import version
 
-logger = logging.getLogger(__name__)
-
 try:
     import xformers
     import xformers.ops
 
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")
 
 from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
 
@@ -291,14 +288,12 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
             f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
         )
         attn_type = "vanilla-xformers"
-    logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
         return AttnBlock(in_channels)
     elif attn_type == "vanilla-xformers":
-        logger.debug(
-            f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
-        )
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
         return MemoryEfficientAttnBlock(in_channels)
     elif type == "memory-efficient-cross-attn":
         attn_kwargs["query_dim"] = in_channels
@@ -638,8 +633,10 @@ def __init__(
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        logger.debug(
-            f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
+        print(
+            "Working with z of shape {} = {} dimensions.".format(
+                self.z_shape, np.prod(self.z_shape)
+            )
         )
 
         make_attn_cls = self._make_attn()
19 changes: 8 additions & 11 deletions sgm/modules/diffusionmodules/openaimodel.py
@@ -1,4 +1,3 @@
-import logging
 import math
 from abc import abstractmethod
 from functools import partial
@@ -22,8 +21,6 @@
 )
 from ...util import default, exists
 
-logger = logging.getLogger(__name__)
-
 
 # dummy replace
 def convert_module_to_f16(x):
@@ -180,13 +177,13 @@ def __init__(
         self.dims = dims
         stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
         if use_conv:
-            logger.debug(
-                f"Building a Downsample layer with {dims} dims.\n"
+            print(f"Building a Downsample layer with {dims} dims.")
+            print(
                 f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
                 f"kernel-size: 3, stride: {stride}, padding: {padding}"
             )
             if dims == 3:
-                logger.debug(f" --> Downsampling third axis (time): {third_down}")
+                print(f" --> Downsampling third axis (time): {third_down}")
             self.op = conv_nd(
                 dims,
                 self.channels,
@@ -273,7 +270,7 @@ def __init__(
             2 * self.out_channels if use_scale_shift_norm else self.out_channels
         )
         if self.skip_t_emb:
-            logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}")
+            print(f"Skipping timestep embedding in {self.__class__.__name__}")
             assert not self.use_scale_shift_norm
             self.emb_layers = None
             self.exchange_temb_dims = False
@@ -622,12 +619,12 @@ def __init__(
                     range(len(num_attention_blocks)),
                 )
             )
-            logger.warning(
+            print(
                 f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
-            )
+            ) # todo: convert to warning
 
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
@@ -636,7 +633,7 @@ def __init__(
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
         if use_fp16:
-            logger.warning("use_fp16 was dropped and has no effect anymore.")
+            print("WARNING: use_fp16 was dropped and has no effect anymore.")
         # self.dtype = th.float16 if use_fp16 else th.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
@@ -667,7 +664,7 @@ def __init__(
             if isinstance(self.num_classes, int):
                 self.label_emb = nn.Embedding(num_classes, time_embed_dim)
             elif self.num_classes == "continuous":
-                logger.debug("setting up linear c_adm embedding layer")
+                print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
             elif self.num_classes == "timestep":
                 self.label_emb = checkpoint_wrapper_fn(
11 changes: 4 additions & 7 deletions sgm/modules/encoders/modules.py
@@ -1,4 +1,3 @@
-import logging
 from contextlib import nullcontext
 from functools import partial
 from typing import Dict, List, Optional, Tuple, Union
@@ -33,8 +32,6 @@
     instantiate_from_config,
 )
 
-logger = logging.getLogger(__name__)
-
 
 class AbstractEmbModel(nn.Module):
     def __init__(self):
@@ -99,7 +96,7 @@ def __init__(self, emb_models: Union[List, ListConfig]):
                 for param in embedder.parameters():
                     param.requires_grad = False
                 embedder.eval()
-            logger.debug(
+            print(
                 f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                 f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
             )
@@ -730,7 +727,7 @@ def encode_with_vision_transformer(self, img):
            )
         if tokens is not None:
             tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
-            logger.warning(
+            print(
                 f"You are running very experimental token-concat in {self.__class__.__name__}. "
                 f"Check what you are doing, and then remove this message."
             )
@@ -756,7 +753,7 @@ def __init__(
             clip_version, device, max_length=clip_max_length
         )
         self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
-        logger.debug(
+        print(
             f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
             f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
         )
@@ -798,7 +795,7 @@ def __init__(
         self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
         self.remap_output = out_channels is not None or remap_output
         if self.remap_output:
-            logger.debug(
+            print(
                 f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
             )
             self.channel_mapper = nn.Conv2d(
24 changes: 11 additions & 13 deletions sgm/util.py
@@ -1,6 +1,5 @@
 import functools
 import importlib
-import logging
 import os
 from functools import partial
 from inspect import isfunction
@@ -11,8 +10,6 @@
 from PIL import Image, ImageDraw, ImageFont
 from safetensors.torch import load_file as load_safetensors
 
-logger = logging.getLogger(__name__)
-
 
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
@@ -89,7 +86,7 @@ def log_txt_as_img(wh, xc, size=10):
         try:
             draw.text((0, 0), lines, fill="black", font=font)
         except UnicodeEncodeError:
-            logger.warning("Cant encode string %r for logging. Skipping.", lines)
+            print("Cant encode string for logging. Skipping.")
 
         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
         txts.append(txt)
@@ -164,7 +161,7 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
     return total_params
 
 
@@ -203,11 +200,11 @@ def append_dims(x, target_dims):
 
 
 def load_model_from_config(config, ckpt, verbose=True, freeze=True):
-    logger.info(f"Loading model from {ckpt}")
+    print(f"Loading model from {ckpt}")
     if ckpt.endswith("ckpt"):
         pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
-            logger.debug(f"Global Step: {pl_sd['global_step']}")
+            print(f"Global Step: {pl_sd['global_step']}")
         sd = pl_sd["state_dict"]
     elif ckpt.endswith("safetensors"):
         sd = load_safetensors(ckpt)
@@ -216,13 +213,14 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
 
     model = instantiate_from_config(config.model)
 
-    missing, unexpected = model.load_state_dict(sd, strict=False)
+    m, u = model.load_state_dict(sd, strict=False)
 
-    if verbose:
-        if missing:
-            logger.info("missing keys: %r", missing)
-        if unexpected:
-            logger.info("unexpected keys: %r", unexpected)
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
 
     if freeze:
         for param in model.parameters():
