diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 04a5f985643e..d36e802d06c9 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -262,6 +262,8 @@ def Tensor( vdevice: Optional[str] = None, ndim: int = -1, ) -> TensorProxy: + if shape is not None and isinstance(shape, int): + shape = [shape,] # scalar tensor case if shape is not None and not isinstance(shape, Var) and len(shape) == 0: shape = []