From f1402ccfaee7595504c59a72d4498b0597c3f6ae Mon Sep 17 00:00:00 2001 From: Joey Ballentine <34788790+joeyballentine@users.noreply.github.com> Date: Fri, 1 Dec 2023 16:02:48 -0500 Subject: [PATCH] Refactor out onnx conversion + use data for ncnn input name (#2360) * Refactor out onnx conversion + use data for ncnn input name * better size requirements for auto inference * throw error --- .../impl/pytorch/convert_to_onnx_impl.py | 48 +++++++++++++++++++ .../pytorch/utility/convert_to_ncnn.py | 13 +++-- .../pytorch/utility/convert_to_onnx.py | 36 ++------------ .../pytorch/utility/interpolate_models.py | 1 + 4 files changed, 62 insertions(+), 36 deletions(-) create mode 100644 backend/src/nodes/impl/pytorch/convert_to_onnx_impl.py diff --git a/backend/src/nodes/impl/pytorch/convert_to_onnx_impl.py b/backend/src/nodes/impl/pytorch/convert_to_onnx_impl.py new file mode 100644 index 000000000..0781da61d --- /dev/null +++ b/backend/src/nodes/impl/pytorch/convert_to_onnx_impl.py @@ -0,0 +1,48 @@ +from io import BytesIO + +import torch +from spandrel import ModelDescriptor + + +def convert_to_onnx_impl( + model: ModelDescriptor, + device: torch.device, + use_half: bool = False, + input_name: str = "input", + output_name: str = "output", +) -> bytes: + # https://github.com/onnx/onnx/issues/654 + dynamic_axes = { + input_name: {0: "batch_size", 2: "height", 3: "width"}, + output_name: {0: "batch_size", 2: "height", 3: "width"}, + } + size = max(model.size_requirements.minimum, 3) + size = size + (size % model.size_requirements.multiple_of) + dummy_input = torch.rand(1, model.input_channels, size, size) + dummy_input = dummy_input.to(device) + + if use_half: + if not model.supports_half: + raise ValueError( + f"Model of arch {model.architecture} does not support half precision." + ) + model.model.half() + dummy_input = dummy_input.half() + else: + model.model.float() + dummy_input = dummy_input.float() + + with BytesIO() as f: + torch.onnx.export( + model.model, + dummy_input, + f, + opset_version=14, + verbose=False, + input_names=[input_name], + output_names=[output_name], + dynamic_axes=dynamic_axes, + do_constant_folding=True, + ) + f.seek(0) + return f.read() diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py index bbcee54dd..58bec21e5 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py @@ -10,20 +10,20 @@ from spandrel.architectures.SwinIR import SwinIR from nodes.impl.ncnn.model import NcnnModelWrapper +from nodes.impl.onnx.model import OnnxGeneric +from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl from nodes.properties.inputs import OnnxFpDropdown, SrModelInput from nodes.properties.outputs import NcnnModelOutput, TextOutput +from ...settings import get_settings from .. import utility_group -from .convert_to_onnx import convert_to_onnx_node try: - from ....chaiNNer_onnx.onnx.utility.convert_to_ncnn import FP_MODE_32 from ....chaiNNer_onnx.onnx.utility.convert_to_ncnn import ( convert_to_ncnn_node as onnx_convert_to_ncnn_node, ) except Exception: onnx_convert_to_ncnn_node = None - FP_MODE_32 = 0 @utility_group.register( @@ -58,8 +58,13 @@ def convert_to_ncnn_node( model.model, (HAT, DAT, OmniSR, SwinIR, Swin2SR, SCUNet, SRFormer) ), f"{model.architecture} is not supported for NCNN conversions at this time." + exec_options = get_settings() + device = exec_options.device + # Intermediate conversion to ONNX is always fp32 - onnx_model = convert_to_onnx_node(model, FP_MODE_32)[0] + onnx_model = OnnxGeneric( + convert_to_onnx_impl(model, device, False, "data", "output") + ) ncnn_model, fp_mode = onnx_convert_to_ncnn_node(onnx_model, is_fp16) return ncnn_model, fp_mode diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py index ad9469795..a49d73826 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py @@ -1,12 +1,10 @@ from __future__ import annotations -from io import BytesIO - -import torch from spandrel import ImageModelDescriptor from spandrel.architectures.SCUNet import SCUNet from nodes.impl.onnx.model import OnnxGeneric +from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl from nodes.properties.inputs import OnnxFpDropdown, SrModelInput from nodes.properties.outputs import OnnxModelOutput, TextOutput @@ -47,37 +45,11 @@ def convert_to_onnx_node( model.model.eval() model = model.to(device) - # https://github.com/onnx/onnx/issues/654 - dynamic_axes = { - "input": {0: "batch_size", 2: "height", 3: "width"}, - "output": {0: "batch_size", 2: "height", 3: "width"}, - } - dummy_input = torch.rand(1, model.input_channels, 64, 64) - dummy_input = dummy_input.to(device) - should_use_fp16 = exec_options.use_fp16 and model.supports_half and fp16 - if should_use_fp16: - model.model.half() - dummy_input = dummy_input.half() - else: - model.model.float() - dummy_input = dummy_input.float() + use_half = fp16 and model.supports_half - with BytesIO() as f: - torch.onnx.export( - model.model, - dummy_input, - f, - opset_version=14, - verbose=False, - input_names=["input"], - output_names=["output"], - dynamic_axes=dynamic_axes, - do_constant_folding=True, - ) - f.seek(0) - onnx_model_bytes = f.read() + onnx_model_bytes = convert_to_onnx_impl(model, device, use_half) - fp_mode = "fp16" if should_use_fp16 else "fp32" + fp_mode = "fp16" if use_half else "fp32" return OnnxGeneric(onnx_model_bytes), fp_mode diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py index e5f3cb193..885128c59 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py @@ -44,6 +44,7 @@ def check_can_interp(model_a: dict, model_b: dict): interp_50 = perform_interp(model_a, model_b, 50) model_descriptor = ModelLoader(torch.device("cpu")).load_from_state_dict(interp_50) size = max(model_descriptor.size_requirements.minimum, 3) + size = size + (size % model_descriptor.size_requirements.multiple_of) assert isinstance(size, int), "min_size_restriction must be an int" fake_img = np.ones((size, size, model_descriptor.input_channels), dtype=np.float32) del interp_50