diff --git a/backend/src/nodes/impl/pytorch/auto_split.py b/backend/src/nodes/impl/pytorch/auto_split.py index 30316a637..73ad44833 100644 --- a/backend/src/nodes/impl/pytorch/auto_split.py +++ b/backend/src/nodes/impl/pytorch/auto_split.py @@ -4,6 +4,8 @@ import numpy as np import torch +from sanic.log import logger +from spandrel import ImageModelDescriptor from ..upscale.auto_split import Split, Tiler, auto_split from .utils import np2tensor, safe_cuda_cache_empty, tensor2np @@ -12,13 +14,16 @@ @torch.inference_mode() def pytorch_auto_split( img: np.ndarray, - model: torch.nn.Module, + model: ImageModelDescriptor, device: torch.device, use_fp16: bool, tiler: Tiler, ) -> np.ndarray: model = model.to(device) - model = model.half() if use_fp16 else model.float() + if use_fp16: + model.model.half() + else: + model.model.float() def upscale(img: np.ndarray, _: object): img_tensor = np2tensor(img, change_range=True) @@ -29,6 +34,7 @@ def upscale(img: np.ndarray, _: object): d_img = d_img.half() if use_fp16 else d_img.float() result = model(d_img) + logger.info(result) result = tensor2np( result.detach().cpu().detach(), change_range=False, diff --git a/backend/src/nodes/properties/inputs/pytorch_inputs.py b/backend/src/nodes/properties/inputs/pytorch_inputs.py index f49bffb05..93d855491 100644 --- a/backend/src/nodes/properties/inputs/pytorch_inputs.py +++ b/backend/src/nodes/properties/inputs/pytorch_inputs.py @@ -90,7 +90,7 @@ class InpaintModelInput(ModelInput): def __init__( self, label: str = "Model", input_type: navi.ExpressionJson = "PyTorchModel" ): - self.purpose: set[Purpose] = {"Inpaint"} + self.purpose: set[Purpose] = {"Inpainting"} super().__init__( label, diff --git a/backend/src/packages/chaiNNer_pytorch/__init__.py b/backend/src/packages/chaiNNer_pytorch/__init__.py index 9a5c4e0c1..c373384ae 100644 --- a/backend/src/packages/chaiNNer_pytorch/__init__.py +++ b/backend/src/packages/chaiNNer_pytorch/__init__.py @@ -95,7 +95,7 @@ def get_pytorch(): 
Dependency( display_name="Spandrel", pypi_name="spandrel", - version="0.1.0", + version="0.1.1", size_estimate=180.7 * KB, ), ], diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py index 029f76523..6f65dfdbe 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py @@ -90,7 +90,7 @@ def inpaint( d_mask = (d_mask > 0.5) * 1 d_mask = d_mask.half() if use_fp16 else d_mask.float() - result = model.model(d_img, d_mask) + result = model(d_img, d_mask) result = tensor2np( result.detach().cpu().detach(), change_range=False, diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py index b6cbaa82e..359722de9 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py @@ -63,7 +63,7 @@ def estimate(): img_out = pytorch_auto_split( img, - model=model.model, + model=model, device=device, use_fp16=use_fp16, tiler=parse_tile_size_input(tile_size, estimate), diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py index 2d88ea2e7..e5f3cb193 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py @@ -5,7 +5,12 @@ import numpy as np import torch from sanic.log import logger -from spandrel import ModelDescriptor, ModelLoader +from spandrel import ( + ImageModelDescriptor, + MaskedImageModelDescriptor, + ModelDescriptor, + ModelLoader, +) from nodes.impl.pytorch.utils import np2tensor, tensor2np from nodes.properties.inputs import ModelInput, SliderInput @@ -44,7 
+49,17 @@ def check_can_interp(model_a: dict, model_b: dict): del interp_50 with torch.no_grad(): img_tensor = np2tensor(fake_img, change_range=True).cpu() - t_out = model_descriptor.model(img_tensor) + if isinstance(model_descriptor, MaskedImageModelDescriptor): + fake_mask = np.ones((size, size, 1), dtype=np.float32) + mask_tensor = np2tensor(fake_mask, change_range=True).cpu() + t_out = model_descriptor(img_tensor, mask_tensor) + elif isinstance(model_descriptor, ImageModelDescriptor): # type: ignore <- I get that this can technically never happen, but please just let me write exhaustive checks + t_out = model_descriptor(img_tensor) + else: + logger.warning( + "Unknown model type used with interpolation. Since we cannot verify inference works with this model, we will assume the interpolation is valid. Please report." + ) + return True if isinstance(t_out, tuple): t_out = t_out[0] result = tensor2np(t_out.detach(), change_range=False, imtype=np.float32)