diff --git a/docs/Dialects/onnx.md b/docs/Dialects/onnx.md index 42a060a6ac..43339eb658 100644 --- a/docs/Dialects/onnx.md +++ b/docs/Dialects/onnx.md @@ -1114,13 +1114,13 @@ Effects: `MemoryEffects::Effect{}` | Operand | Description | | :-----: | ----------- | -| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values +| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values #### Results: | Result | Description | | :----: | ----------- | -| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values +| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values ### `onnx.CategoryMapper` (ONNXCategoryMapperOp) diff --git a/src/Builder/OpBuildTable.inc b/src/Builder/OpBuildTable.inc index e46c353805..0e3abe9f49 100644 --- a/src/Builder/OpBuildTable.inc +++ b/src/Builder/OpBuildTable.inc @@ -28,7 +28,7 @@ op_dialect_version_map_["BitwiseNot"] = {18}; op_dialect_version_map_["BitwiseOr"] = {18}; op_dialect_version_map_["BitwiseXor"] = {18}; op_dialect_version_map_["BlackmanWindow"] = {17}; -op_dialect_version_map_["Cast"] = {19}; +op_dialect_version_map_["Cast"] = {21}; op_dialect_version_map_["CastLike"] = {19}; op_dialect_version_map_["CastMap"] = {1}; op_dialect_version_map_["CategoryMapper"] = {1}; diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 1c14bb4822..ba627d21a8 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -863,10 +863,10 @@ def ONNXCastOp:ONNX_Op<"Cast", | [x] < -FLT_MAX | NaN | NaN | -Inf | NaN | | else | RNE | RNE | RNE | RNE | }]; - let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$input, + let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$input, DefaultValuedAttr:$saturate, TypeAttr:$to); - let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$output); + let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$output); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 1; diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index bdbe2694f6..7f260f2e99 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -643,11 +643,13 @@ Type convertONNXTypeToMLIRType( return builder.getI1Type(); case onnx::TensorProto_DataType::TensorProto_DataType_STRING: return ONNXStringType::get(builder.getContext()); + case onnx::TensorProto_DataType::TensorProto_DataType_INT4: + return builder.getIntegerType(/*width=*/4); + case onnx::TensorProto_DataType::TensorProto_DataType_UINT4: + return builder.getIntegerType(/*width=*/4, false); case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64: case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128: - case onnx::TensorProto_DataType::TensorProto_DataType_INT4: - case onnx::TensorProto_DataType::TensorProto_DataType_UINT4: case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED: llvm_unreachable("Unsupported data type encountered."); return nullptr; @@ -697,6 +699,10 @@ int64_t mlirTypeToOnnxType(Type elemType) { ? onnx::TensorProto::UNDEFINED : onnx::TensorProto::BOOL; break; + case 4: + onnxType = type.isUnsigned() ? onnx::TensorProto::UINT4 + : onnx::TensorProto::INT4; + break; case 8: onnxType = type.isUnsigned() ? onnx::TensorProto::UINT8 : onnx::TensorProto::INT8; diff --git a/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext b/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext new file mode 100644 index 0000000000..c5005ca136 --- /dev/null +++ b/test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext @@ -0,0 +1,19 @@ +// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s +< + ir_version: 10, + opset_import: ["" : 22] +> +test_int4_casting (int4[1] input, uint4[1] input2) => (int4[1] int4_cast_output, uint4[1] uint4_cast_output) { + int8_cast_output = Cast (input) + int4_cast_output = Cast (int8_cast_output) + uint8_cast_output = Cast (input2) + uint4_cast_output = Cast (uint8_cast_output) +} +// CHECK-LABEL: func.func @main_graph +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4> {onnx.name = "input"}, [[PARAM_1_:%.+]]: tensor<1xui4> {onnx.name = "input2"}) -> (tensor<1xi4> {onnx.name = "int4_cast_output"}, tensor<1xui4> {onnx.name = "uint4_cast_output"}) { +// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor<1xi4>) -> tensor<1xi8> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = i4} : (tensor<1xi8>) -> tensor<1xi4> +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Cast"([[VAR_2_]]) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4> +// CHECK: onnx.Return [[VAR_1_]], [[VAR_3_]] : tensor<1xi4>, tensor<1xui4> +// CHECK: } diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py index 534a9ab500..a2b4af02c7 100755 --- a/utils/gen_onnx_mlir.py +++ b/utils/gen_onnx_mlir.py @@ -109,7 +109,7 @@ "BitwiseOr": [18], "BitwiseXor": [18], "BlackmanWindow": [17], - "Cast": [19], + "Cast": [21], "CastLike": [19], "CastMap": [1], "CategoryMapper": [1], @@ -611,6 +611,10 @@ # FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients # FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero # +# // 4-bit integer data types +# UINT4 = 21; // Unsigned integer in range [0, 15] +# INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation +# # // Future extensions go here. # } onnx_types = ( @@ -635,6 +639,8 @@ "float8e4m3fnuz", "float8e5m2", "float8e5m2fnuz", + "uint4", + "int4", ) tblgen_types = ( "BF16", @@ -658,6 +664,8 @@ "F8E4M3FNUZ", "F8E5M2", "F8E5M2FNUZ", + "AnyUI4", + "AnyI4", ) # Maximum count for actual type. Number more than MAX_NUM_TYPES will be used to encode @@ -1048,10 +1056,12 @@ def parse_type_str(allowedType): "seq": "SeqOf", "map": "TupleOf", "bool": "I1", + "uint4": "UI<4>", "uint8": "UI8", "uint16": "UI16", "uint32": "UI32", "uint64": "UI64", + "int4": "I<4>", "int8": "I8", "int16": "I16", "int32": "I32",