Skip to content

Commit

Permalink
Allow setting ONNX opset (#2361)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment authored Dec 2, 2023
1 parent f1402cc commit 5c7c649
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 4 deletions.
3 changes: 2 additions & 1 deletion backend/src/nodes/impl/pytorch/convert_to_onnx_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def convert_to_onnx_impl(
use_half: bool = False,
input_name: str = "input",
output_name: str = "output",
opset_version: int = 14,
) -> bytes:
# https://github.com/onnx/onnx/issues/654
dynamic_axes = {
Expand Down Expand Up @@ -37,7 +38,7 @@ def convert_to_onnx_impl(
model.model,
dummy_input,
f,
opset_version=14,
opset_version=opset_version,
verbose=False,
input_names=[input_name],
output_names=[output_name],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
from __future__ import annotations

from enum import Enum

from spandrel import ImageModelDescriptor
from spandrel.architectures.SCUNet import SCUNet

from nodes.impl.onnx.model import OnnxGeneric
from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl
from nodes.properties.inputs import OnnxFpDropdown, SrModelInput
from nodes.properties.inputs import EnumInput, OnnxFpDropdown, SrModelInput
from nodes.properties.outputs import OnnxModelOutput, TextOutput

from ...settings import get_settings
from .. import utility_group


class Opset(Enum):
OPSET_14 = 14
OPSET_15 = 15
OPSET_16 = 16
OPSET_17 = 17


OPSET_LABELS: dict[Opset, str] = {
Opset.OPSET_14: "14",
Opset.OPSET_15: "15",
Opset.OPSET_16: "16",
Opset.OPSET_17: "17",
}


@utility_group.register(
schema_id="chainner:pytorch:convert_to_onnx",
name="Convert To ONNX",
Expand All @@ -24,14 +41,20 @@
inputs=[
SrModelInput("PyTorch Model"),
OnnxFpDropdown(),
EnumInput(
Opset,
label="Opset",
default=Opset.OPSET_14,
option_labels=OPSET_LABELS,
),
],
outputs=[
OnnxModelOutput(model_type="OnnxGenericModel", label="ONNX Model"),
TextOutput("FP Mode", "FpMode::toString(Input1)"),
],
)
def convert_to_onnx_node(
model: ImageModelDescriptor, is_fp16: int
model: ImageModelDescriptor, is_fp16: int, opset: Opset
) -> tuple[OnnxGeneric, str]:
assert not isinstance(
model.model, SCUNet
Expand All @@ -48,7 +71,12 @@ def convert_to_onnx_node(

use_half = fp16 and model.supports_half

onnx_model_bytes = convert_to_onnx_impl(model, device, use_half)
onnx_model_bytes = convert_to_onnx_impl(
model,
device,
use_half,
opset_version=opset.value,
)

fp_mode = "fp16" if use_half else "fp32"

Expand Down

0 comments on commit 5c7c649

Please sign in to comment.