From 486a8d2bb726e4a8ab9fb2a7d9504afdb382220f Mon Sep 17 00:00:00 2001
From: felix
Date: Wed, 18 Dec 2024 11:30:00 +0100
Subject: [PATCH] [docs] Tiny documentation export page fix

---
 docs/source/using_doctr/using_model_export.rst | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/docs/source/using_doctr/using_model_export.rst b/docs/source/using_doctr/using_model_export.rst
index 073172efbc..4ab94faf94 100644
--- a/docs/source/using_doctr/using_model_export.rst
+++ b/docs/source/using_doctr/using_model_export.rst
@@ -119,10 +119,10 @@ It defines a common format for representing models, including the network struct
             from doctr.models import vitstr_small
             from doctr.models.utils import export_model_to_onnx
 
-            batch_size = 16
+            batch_size = 1
             input_shape = (3, 32, 128)
             model = vitstr_small(pretrained=True, exportable=True)
-            dummy_input = torch.rand((batch_size, input_shape), dtype=torch.float32)
+            dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
             model_path = export_model_to_onnx(
                 model,
                 model_name="vitstr.onnx",
@@ -137,10 +137,10 @@ It defines a common format for representing models, including the network struct
             from doctr.models import vitstr_small
             from doctr.models.utils import export_model_to_onnx
 
-            batch_size = 16
+            batch_size = 1
             input_shape = (32, 128, 3)
             model = vitstr_small(pretrained=True, exportable=True)
-            dummy_input = [tf.TensorSpec([batch_size, input_shape], tf.float32, name="input")]
+            dummy_input = [tf.TensorSpec([batch_size, *input_shape], tf.float32, name="input")]
             model_path, output = export_model_to_onnx(
                 model,
                 model_name="vitstr.onnx",
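
For reference, below is a minimal sketch of how the exported vitstr.onnx file could be
sanity-checked after applying this patch. It is not part of the diff above; the use of
onnxruntime with the CPU execution provider, and looking up the input tensor name via
get_inputs(), are assumptions for illustration rather than something the patched page prescribes.

    # Hypothetical verification snippet (not part of the patch).
    import numpy as np
    import onnxruntime

    batch_size = 1
    input_shape = (3, 32, 128)  # PyTorch layout from the patched example

    # Load the exported model on CPU (assumed provider choice).
    session = onnxruntime.InferenceSession("vitstr.onnx", providers=["CPUExecutionProvider"])

    # Query the input name from the graph instead of hard-coding it.
    input_name = session.get_inputs()[0].name

    # Run one dummy batch with the same shape used for export.
    dummy = np.random.rand(batch_size, *input_shape).astype(np.float32)
    outputs = session.run(None, {input_name: dummy})
    print([o.shape for o in outputs])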