Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor out onnx conversion + use data for ncnn input name #2360

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions backend/src/nodes/impl/pytorch/convert_to_onnx_impl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from io import BytesIO

import torch
from spandrel import ModelDescriptor


def convert_to_onnx_impl(
model: ModelDescriptor,
device: torch.device,
use_half: bool = False,
input_name: str = "input",
output_name: str = "output",
) -> bytes:
# https://github.com/onnx/onnx/issues/654
dynamic_axes = {
input_name: {0: "batch_size", 2: "height", 3: "width"},
output_name: {0: "batch_size", 2: "height", 3: "width"},
}
dummy_input = torch.rand(1, model.input_channels, 64, 64)
joeyballentine marked this conversation as resolved.
Show resolved Hide resolved
dummy_input = dummy_input.to(device)

if use_half:
model.model.half()
joeyballentine marked this conversation as resolved.
Show resolved Hide resolved
dummy_input = dummy_input.half()
else:
model.model.float()
dummy_input = dummy_input.float()

with BytesIO() as f:
torch.onnx.export(
model.model,
dummy_input,
f,
opset_version=14,
verbose=False,
input_names=[input_name],
output_names=[output_name],
dynamic_axes=dynamic_axes,
do_constant_folding=True,
)
f.seek(0)
return f.read()
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@
from spandrel.architectures.SwinIR import SwinIR

from nodes.impl.ncnn.model import NcnnModelWrapper
from nodes.impl.onnx.model import OnnxGeneric
from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl
from nodes.properties.inputs import OnnxFpDropdown, SrModelInput
from nodes.properties.outputs import NcnnModelOutput, TextOutput

from ...settings import get_settings
from .. import utility_group
from .convert_to_onnx import convert_to_onnx_node

try:
from ....chaiNNer_onnx.onnx.utility.convert_to_ncnn import FP_MODE_32
from ....chaiNNer_onnx.onnx.utility.convert_to_ncnn import (
convert_to_ncnn_node as onnx_convert_to_ncnn_node,
)
except Exception:
onnx_convert_to_ncnn_node = None
FP_MODE_32 = 0


@utility_group.register(
Expand Down Expand Up @@ -58,8 +58,13 @@ def convert_to_ncnn_node(
model.model, (HAT, DAT, OmniSR, SwinIR, Swin2SR, SCUNet, SRFormer)
), f"{model.architecture} is not supported for NCNN conversions at this time."

exec_options = get_settings()
device = exec_options.device

# Intermediate conversion to ONNX is always fp32
onnx_model = convert_to_onnx_node(model, FP_MODE_32)[0]
onnx_model = OnnxGeneric(
convert_to_onnx_impl(model, device, False, "data", "output")
)
ncnn_model, fp_mode = onnx_convert_to_ncnn_node(onnx_model, is_fp16)

return ncnn_model, fp_mode
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from __future__ import annotations

from io import BytesIO

import torch
from spandrel import ImageModelDescriptor
from spandrel.architectures.SCUNet import SCUNet

from nodes.impl.onnx.model import OnnxGeneric
from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl
from nodes.properties.inputs import OnnxFpDropdown, SrModelInput
from nodes.properties.outputs import OnnxModelOutput, TextOutput

Expand Down Expand Up @@ -47,37 +45,11 @@ def convert_to_onnx_node(

model.model.eval()
model = model.to(device)
# https://github.com/onnx/onnx/issues/654
dynamic_axes = {
"input": {0: "batch_size", 2: "height", 3: "width"},
"output": {0: "batch_size", 2: "height", 3: "width"},
}
dummy_input = torch.rand(1, model.input_channels, 64, 64)
dummy_input = dummy_input.to(device)

should_use_fp16 = exec_options.use_fp16 and model.supports_half and fp16
if should_use_fp16:
model.model.half()
dummy_input = dummy_input.half()
else:
model.model.float()
dummy_input = dummy_input.float()
use_half = fp16 and model.supports_half

with BytesIO() as f:
torch.onnx.export(
model.model,
dummy_input,
f,
opset_version=14,
verbose=False,
input_names=["input"],
output_names=["output"],
dynamic_axes=dynamic_axes,
do_constant_folding=True,
)
f.seek(0)
onnx_model_bytes = f.read()
onnx_model_bytes = convert_to_onnx_impl(model, device, use_half)

fp_mode = "fp16" if should_use_fp16 else "fp32"
fp_mode = "fp16" if use_half else "fp32"

return OnnxGeneric(onnx_model_bytes), fp_mode