Skip to content

Commit

Permalink
Update Spandrel again + fixes (#2359)
Browse files Browse the repository at this point in the history
* Update Spandrel again + fix inference

* use call api for inpainting as well

* fix inpainting thing

* Update backend/src/nodes/impl/pytorch/auto_split.py

Co-authored-by: Michael Schmidt <[email protected]>

---------

Co-authored-by: Michael Schmidt <[email protected]>
  • Loading branch information
joeyballentine and RunDevelopment authored Dec 1, 2023
1 parent 1161de4 commit f13c00f
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 8 deletions.
10 changes: 8 additions & 2 deletions backend/src/nodes/impl/pytorch/auto_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import numpy as np
import torch
from sanic.log import logger
from spandrel import ImageModelDescriptor

from ..upscale.auto_split import Split, Tiler, auto_split
from .utils import np2tensor, safe_cuda_cache_empty, tensor2np
Expand All @@ -12,13 +14,16 @@
@torch.inference_mode()
def pytorch_auto_split(
img: np.ndarray,
model: torch.nn.Module,
model: ImageModelDescriptor,
device: torch.device,
use_fp16: bool,
tiler: Tiler,
) -> np.ndarray:
model = model.to(device)
model = model.half() if use_fp16 else model.float()
if use_fp16:
model.model.half()
else:
model.model.float()

def upscale(img: np.ndarray, _: object):
img_tensor = np2tensor(img, change_range=True)
Expand All @@ -29,6 +34,7 @@ def upscale(img: np.ndarray, _: object):
d_img = d_img.half() if use_fp16 else d_img.float()

result = model(d_img)
logger.info(result)
result = tensor2np(
result.detach().cpu().detach(),
change_range=False,
Expand Down
2 changes: 1 addition & 1 deletion backend/src/nodes/properties/inputs/pytorch_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class InpaintModelInput(ModelInput):
def __init__(
self, label: str = "Model", input_type: navi.ExpressionJson = "PyTorchModel"
):
self.purpose: set[Purpose] = {"Inpaint"}
self.purpose: set[Purpose] = {"Inpainting"}

super().__init__(
label,
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_pytorch():
Dependency(
display_name="Spandrel",
pypi_name="spandrel",
version="0.1.0",
version="0.1.1",
size_estimate=180.7 * KB,
),
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def inpaint(
d_mask = (d_mask > 0.5) * 1
d_mask = d_mask.half() if use_fp16 else d_mask.float()

result = model.model(d_img, d_mask)
result = model(d_img, d_mask)
result = tensor2np(
result.detach().cpu().detach(),
change_range=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def estimate():

img_out = pytorch_auto_split(
img,
model=model.model,
model=model,
device=device,
use_fp16=use_fp16,
tiler=parse_tile_size_input(tile_size, estimate),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import numpy as np
import torch
from sanic.log import logger
from spandrel import ModelDescriptor, ModelLoader
from spandrel import (
ImageModelDescriptor,
MaskedImageModelDescriptor,
ModelDescriptor,
ModelLoader,
)

from nodes.impl.pytorch.utils import np2tensor, tensor2np
from nodes.properties.inputs import ModelInput, SliderInput
Expand Down Expand Up @@ -44,7 +49,17 @@ def check_can_interp(model_a: dict, model_b: dict):
del interp_50
with torch.no_grad():
img_tensor = np2tensor(fake_img, change_range=True).cpu()
t_out = model_descriptor.model(img_tensor)
if isinstance(model_descriptor, MaskedImageModelDescriptor):
np.ones((size, size, 1), dtype=np.float32)
mask_tensor = np2tensor(fake_img, change_range=True).cpu()
t_out = model_descriptor(img_tensor, mask_tensor)
elif isinstance(model_descriptor, ImageModelDescriptor): # type: ignore <- I get that this can technically never happen, but please just let me write exhaustive checks
t_out = model_descriptor(img_tensor)
else:
logger.warning(
"Unknown model type used with interpolation. Since we cannot verify inference works with this model, we will assume the interpolation is valid. Please report."
)
return True
if isinstance(t_out, tuple):
t_out = t_out[0]
result = tensor2np(t_out.detach(), change_range=False, imtype=np.float32)
Expand Down

0 comments on commit f13c00f

Please sign in to comment.