diff --git a/fx2ait/fx2ait/tensor_spec.py b/fx2ait/fx2ait/tensor_spec.py
index 1db415571..6ac2f3a32 100644
--- a/fx2ait/fx2ait/tensor_spec.py
+++ b/fx2ait/fx2ait/tensor_spec.py
@@ -481,6 +481,9 @@ def find_batch_size_dim(cls, inputs: Any) -> []:
             if len(shape) < 2:
                 # By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info
                 continue
+            if shape[0] == 0:
+                # We expect that 1 is the minimum batch size value anyway.
+                continue
             # Dedup shape value for single tensor
             first_dims.add(shape[0])
             seen_dims = set()
diff --git a/fx2ait/fx2ait/test/test_tensor_spec.py b/fx2ait/fx2ait/test/test_tensor_spec.py
index 4ab33f218..4c7c59118 100644
--- a/fx2ait/fx2ait/test/test_tensor_spec.py
+++ b/fx2ait/fx2ait/test/test_tensor_spec.py
@@ -178,3 +178,21 @@ def test_input_with_no_bs_tensor(self):
             ),
             specs[3],
         )
+
+    def test_input_with_first_dim_zero(self):
+        inputs = [
+            torch.empty([10, 8643], dtype=torch.float16),
+            torch.empty([0, 8643], dtype=torch.float16),
+        ]
+
+        specs = TensorSpec.from_input_list_with_batch_size(inputs, 32)
+
+        self.assertEqual(
+            [
+                TensorSpec(
+                    [IntVar([1, 32], "batch_size"), IntImm(8643)], torch.float16
+                ),
+                TensorSpec([IntImm(0), IntImm(8643)], torch.float16),
+            ],
+            specs,
+        )