Refactor out onnx conversion + use data for ncnn input name (#2360)
* Refactor out onnx conversion + use data for ncnn input name

* better size requirements for auto inference

* throw error
joeyballentine authored Dec 1, 2023
1 parent f13c00f commit f1402cc
Showing 4 changed files with 62 additions and 36 deletions.
48 changes: 48 additions & 0 deletions backend/src/nodes/impl/pytorch/convert_to_onnx_impl.py
@@ -0,0 +1,48 @@
from io import BytesIO

import torch
from spandrel import ModelDescriptor


def convert_to_onnx_impl(
    model: ModelDescriptor,
    device: torch.device,
    use_half: bool = False,
    input_name: str = "input",
    output_name: str = "output",
) -> bytes:
    # https://github.com/onnx/onnx/issues/654
    dynamic_axes = {
        input_name: {0: "batch_size", 2: "height", 3: "width"},
        output_name: {0: "batch_size", 2: "height", 3: "width"},
    }
    size = max(model.size_requirements.minimum, 3)
    size = size + (size % model.size_requirements.multiple_of)
    dummy_input = torch.rand(1, model.input_channels, size, size)
    dummy_input = dummy_input.to(device)

    if use_half:
        if not model.supports_half:
            raise ValueError(
                f"Model of arch {model.architecture} does not support half precision."
            )
        model.model.half()
        dummy_input = dummy_input.half()
    else:
        model.model.float()
        dummy_input = dummy_input.float()

    with BytesIO() as f:
        torch.onnx.export(
            model.model,
            dummy_input,
            f,
            opset_version=14,
            verbose=False,
            input_names=[input_name],
            output_names=[output_name],
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
        )
        f.seek(0)
        return f.read()
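
A minimal usage sketch of the new helper, for context. The checkpoint path is hypothetical and the CPU device is an assumption; the ModelLoader call mirrors the interpolation check further down in this commit.

import torch
from spandrel import ModelLoader

from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl

device = torch.device("cpu")  # assumption: CPU suffices for a one-off export
model = ModelLoader(device).load_from_file("4x_model.pth")  # hypothetical path
model.model.eval()

# Default tensor names ("input"/"output"); the NCNN path below overrides
# input_name with "data".
onnx_bytes = convert_to_onnx_impl(model, device, use_half=False)
with open("4x_model.onnx", "wb") as out:
    out.write(onnx_bytes)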
convert_to_ncnn.py
@@ -10,20 +10,20 @@
from spandrel.architectures.SwinIR import SwinIR

from nodes.impl.ncnn.model import NcnnModelWrapper
from nodes.impl.onnx.model import OnnxGeneric
from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl
from nodes.properties.inputs import OnnxFpDropdown, SrModelInput
from nodes.properties.outputs import NcnnModelOutput, TextOutput

from ...settings import get_settings
from .. import utility_group
from .convert_to_onnx import convert_to_onnx_node

try:
    from ....chaiNNer_onnx.onnx.utility.convert_to_ncnn import FP_MODE_32
    from ....chaiNNer_onnx.onnx.utility.convert_to_ncnn import (
        convert_to_ncnn_node as onnx_convert_to_ncnn_node,
    )
except Exception:
    onnx_convert_to_ncnn_node = None
    FP_MODE_32 = 0


@utility_group.register(
@@ -58,8 +58,13 @@ def convert_to_ncnn_node(
        model.model, (HAT, DAT, OmniSR, SwinIR, Swin2SR, SCUNet, SRFormer)
    ), f"{model.architecture} is not supported for NCNN conversions at this time."

    exec_options = get_settings()
    device = exec_options.device

    # Intermediate conversion to ONNX is always fp32
    onnx_model = convert_to_onnx_node(model, FP_MODE_32)[0]
    onnx_model = OnnxGeneric(
        convert_to_onnx_impl(model, device, False, "data", "output")
    )
    ncnn_model, fp_mode = onnx_convert_to_ncnn_node(onnx_model, is_fp16)

    return ncnn_model, fp_mode
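
Since the positional call in convert_to_ncnn_node is easy to misread, here is the same invocation with keywords spelled out; this restates the diff above rather than adding any new API. The input blob is renamed to "data", presumably because chaiNNer's NCNN inference path looks the input up by that name.

onnx_model = OnnxGeneric(
    convert_to_onnx_impl(
        model,
        device,
        use_half=False,       # the intermediate ONNX model is always fp32
        input_name="data",    # NCNN input blob name (the point of this commit)
        output_name="output",
    )
)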
convert_to_onnx.py
@@ -1,12 +1,10 @@
from __future__ import annotations

from io import BytesIO

import torch
from spandrel import ImageModelDescriptor
from spandrel.architectures.SCUNet import SCUNet

from nodes.impl.onnx.model import OnnxGeneric
from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl
from nodes.properties.inputs import OnnxFpDropdown, SrModelInput
from nodes.properties.outputs import OnnxModelOutput, TextOutput

@@ -47,37 +45,11 @@ def convert_to_onnx_node(

    model.model.eval()
    model = model.to(device)
    # https://github.com/onnx/onnx/issues/654
    dynamic_axes = {
        "input": {0: "batch_size", 2: "height", 3: "width"},
        "output": {0: "batch_size", 2: "height", 3: "width"},
    }
    dummy_input = torch.rand(1, model.input_channels, 64, 64)
    dummy_input = dummy_input.to(device)

    should_use_fp16 = exec_options.use_fp16 and model.supports_half and fp16
    if should_use_fp16:
        model.model.half()
        dummy_input = dummy_input.half()
    else:
        model.model.float()
        dummy_input = dummy_input.float()
    use_half = fp16 and model.supports_half

    with BytesIO() as f:
        torch.onnx.export(
            model.model,
            dummy_input,
            f,
            opset_version=14,
            verbose=False,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
        )
        f.seek(0)
        onnx_model_bytes = f.read()
    onnx_model_bytes = convert_to_onnx_impl(model, device, use_half)

    fp_mode = "fp16" if should_use_fp16 else "fp32"
    fp_mode = "fp16" if use_half else "fp32"

    return OnnxGeneric(onnx_model_bytes), fp_mode
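
One behavioral note on the refactor: convert_to_onnx_node guards the flag itself (use_half = fp16 and model.supports_half), silently falling back to fp32, so the ValueError newly raised in convert_to_onnx_impl only surfaces for callers that request half precision unconditionally. A minimal sketch of that failure path, assuming a model whose architecture lacks fp16 support:

try:
    onnx_bytes = convert_to_onnx_impl(model, device, use_half=True)
except ValueError as e:
    # e.g. "Model of arch <architecture> does not support half precision."
    print(e)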
@@ -44,6 +44,7 @@ def check_can_interp(model_a: dict, model_b: dict):
    interp_50 = perform_interp(model_a, model_b, 50)
    model_descriptor = ModelLoader(torch.device("cpu")).load_from_state_dict(interp_50)
    size = max(model_descriptor.size_requirements.minimum, 3)
    size = size + (size % model_descriptor.size_requirements.multiple_of)
    assert isinstance(size, int), "min_size_restriction must be an int"
    fake_img = np.ones((size, size, model_descriptor.input_channels), dtype=np.float32)
    del interp_50
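
A worked example of the size padding used both here and in convert_to_onnx_impl, with illustrative values:

minimum, multiple_of = 3, 2           # hypothetical size requirements
size = max(minimum, 3)                # 3
size = size + (size % multiple_of)    # 3 + (3 % 2) = 4

With multiple_of of 1 or 2 this always yields a valid size; for larger multiples, padding by the remainder only reaches an exact multiple when multiple_of divides twice the remainder (for example, minimum = 3 with multiple_of = 4 gives 6, not 8).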
