Skip to content

Commit

Permalink
Bump onnx.Cast to opset 21 , adding int/uint4 support (onnx#3057)
Browse files Browse the repository at this point in the history
* Add support for TensorProto::UINT4/INT4

Signed-off-by: Rickert, Jonas <[email protected]>

* Upgrade onnx.Cast to opset 21

Signed-off-by: Rickert, Jonas <[email protected]>

---------

Signed-off-by: Rickert, Jonas <[email protected]>
  • Loading branch information
jorickert committed Feb 7, 2025
1 parent d863cb7 commit 6e5d133
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 8 deletions.
4 changes: 2 additions & 2 deletions docs/Dialects/onnx.md
Original file line number Diff line number Diff line change
Expand Up @@ -1163,13 +1163,13 @@ Effects: `MemoryEffects::Effect{}`

| Operand | Description |
| :-----: | ----------- |
| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values
| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values

#### Results:

| Result | Description |
| :----: | ----------- |
| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values
| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values

### `onnx.CategoryMapper` (ONNXCategoryMapperOp)

Expand Down
2 changes: 1 addition & 1 deletion src/Builder/OpBuildTable.inc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ op_dialect_version_map_["BitwiseNot"] = {18};
op_dialect_version_map_["BitwiseOr"] = {18};
op_dialect_version_map_["BitwiseXor"] = {18};
op_dialect_version_map_["BlackmanWindow"] = {17};
op_dialect_version_map_["Cast"] = {19};
op_dialect_version_map_["Cast"] = {21};
op_dialect_version_map_["CastLike"] = {19};
op_dialect_version_map_["CastMap"] = {1};
op_dialect_version_map_["CategoryMapper"] = {1};
Expand Down
4 changes: 2 additions & 2 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -911,10 +911,10 @@ def ONNXCastOp:ONNX_Op<"Cast",
| [x] < -FLT_MAX | NaN | NaN | -Inf | NaN |
| else | RNE | RNE | RNE | RNE |
}];
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$input,
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$input,
DefaultValuedAttr<SI64Attr, "1">:$saturate,
TypeAttr:$to);
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$output);
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$output);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
Expand Down
10 changes: 8 additions & 2 deletions src/Dialect/ONNX/ONNXOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,11 +667,13 @@ Type convertONNXTypeToMLIRType(
return builder.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
return ONNXStringType::get(builder.getContext());
case onnx::TensorProto_DataType::TensorProto_DataType_INT4:
return builder.getIntegerType(/*width=*/4);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT4:
return builder.getIntegerType(/*width=*/4, false);

case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
case onnx::TensorProto_DataType::TensorProto_DataType_INT4:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT4:
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
llvm_unreachable("Unsupported data type encountered.");
return nullptr;
Expand Down Expand Up @@ -721,6 +723,10 @@ int64_t mlirTypeToOnnxType(Type elemType) {
? onnx::TensorProto::UNDEFINED
: onnx::TensorProto::BOOL;
break;
case 4:
onnxType = type.isUnsigned() ? onnx::TensorProto::UINT4
: onnx::TensorProto::INT4;
break;
case 8:
onnxType = type.isUnsigned() ? onnx::TensorProto::UINT8
: onnx::TensorProto::INT8;
Expand Down
19 changes: 19 additions & 0 deletions test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s
<
ir_version: 10,
opset_import: ["" : 22]
>
test_int4_casting (int4[1] input, uint4[1] input2) => (int4[1] int4_cast_output, uint4[1] uint4_cast_output) {
int8_cast_output = Cast <to: int = 3> (input)
int4_cast_output = Cast <to: int = 22> (int8_cast_output)
uint8_cast_output = Cast <to: int = 2> (input2)
uint4_cast_output = Cast <to: int = 21> (uint8_cast_output)
}
// CHECK-LABEL: func.func @main_graph
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4> {onnx.name = "input"}, [[PARAM_1_:%.+]]: tensor<1xui4> {onnx.name = "input2"}) -> (tensor<1xi4> {onnx.name = "int4_cast_output"}, tensor<1xui4> {onnx.name = "uint4_cast_output"}) {
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor<1xi4>) -> tensor<1xi8>
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = i4} : (tensor<1xi8>) -> tensor<1xi4>
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Cast"([[VAR_2_]]) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4>
// CHECK: onnx.Return [[VAR_1_]], [[VAR_3_]] : tensor<1xi4>, tensor<1xui4>
// CHECK: }
12 changes: 11 additions & 1 deletion utils/gen_onnx_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
"BitwiseOr": [18],
"BitwiseXor": [18],
"BlackmanWindow": [17],
"Cast": [19],
"Cast": [21],
"CastLike": [19],
"CastMap": [1],
"CategoryMapper": [1],
Expand Down Expand Up @@ -614,6 +614,10 @@
# FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
# FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
#
# // 4-bit integer data types
# UINT4 = 21; // Unsigned integer in range [0, 15]
# INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation
#
# // Future extensions go here.
# }
onnx_types = (
Expand All @@ -638,6 +642,8 @@
"float8e4m3fnuz",
"float8e5m2",
"float8e5m2fnuz",
"uint4",
"int4",
)
tblgen_types = (
"BF16",
Expand All @@ -661,6 +667,8 @@
"F8E4M3FNUZ",
"F8E5M2",
"F8E5M2FNUZ",
"AnyUI4",
"AnyI4",
)

# Maximum count for actual type. Number more than MAX_NUM_TYPES will be used to encode
Expand Down Expand Up @@ -1051,10 +1059,12 @@ def parse_type_str(allowedType):
"seq": "SeqOf",
"map": "TupleOf",
"bool": "I1",
"uint4": "UI<4>",
"uint8": "UI8",
"uint16": "UI16",
"uint32": "UI32",
"uint64": "UI64",
"int4": "I<4>",
"int8": "I8",
"int16": "I16",
"int32": "I32",
Expand Down

0 comments on commit 6e5d133

Please sign in to comment.