From 404e896c769aa069f77deb79081e19024e8596a8 Mon Sep 17 00:00:00 2001 From: Marcus Pertlwieser <116986601+Marcus1506@users.noreply.github.com> Date: Thu, 3 Oct 2024 16:20:38 +0200 Subject: [PATCH] Makes batch size dynamic (#2339) Made batch dimension of ONNX export dynamic when specifying input shape. --- src/anomalib/models/components/base/export_mixin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index e0627b462c..5e7e5e9481 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -142,7 +142,9 @@ def to_onnx( export_root = _create_export_root(export_root, ExportType.ONNX) input_shape = torch.zeros((1, 3, *input_size)) if input_size else torch.zeros((1, 3, 1, 1)) dynamic_axes = ( - None if input_size else {"input": {0: "batch_size", 2: "height", 3: "weight"}, "output": {0: "batch_size"}} + {"input": {0: "batch_size"}, "output": {0: "batch_size"}} + if input_size + else {"input": {0: "batch_size", 2: "height", 3: "weight"}, "output": {0: "batch_size"}} ) _write_metadata_to_json(self._get_metadata(task), export_root) onnx_path = export_root / "model.onnx"