diff --git a/python/hidet/graph/ops/quant/symmetric.py b/python/hidet/graph/ops/quant/symmetric.py index fc461820b..f377e7853 100644 --- a/python/hidet/graph/ops/quant/symmetric.py +++ b/python/hidet/graph/ops/quant/symmetric.py @@ -11,7 +11,7 @@ # limitations under the License. from typing import Union, List from hidet import ir -from hidet.ir.type import DataType +from hidet.ir.type import DataType, int32 from hidet.ir.expr import cast, if_then_else from hidet.ir.compute.primitives import TensorNode, compute from hidet.ir import primitives as prim @@ -36,7 +36,9 @@ def __init__(self, w: TensorNode, quant_type: DataType, dims: Union[int, List[in def scale_weight(*indices): scale_indices = [indices[i] for i in range(len(indices)) if not i in dims] - return cast(prim.round(w[indices] / scale[scale_indices]), quant_type) + # Have to cast to int32 first because there are several ways to convert bf16 to int8 + cast_to_int = cast(prim.round(w[indices] / scale[scale_indices]), int32) + return cast(cast_to_int, quant_type) wq = compute(name='quantize', shape=w.shape, fcompute=scale_weight) super().__init__( diff --git a/python/hidet/version.py b/python/hidet/version.py index 2f7c9cff0..087572bb4 100644 --- a/python/hidet/version.py +++ b/python/hidet/version.py @@ -10,4 +10,3 @@ # See the License for the specific language governing permissions and # limitations under the License. __version__ = "0.5.0" -