Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump onnx.Cast to opset 21, adding int/uint4 support #3057

Merged
merged 3 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/Dialects/onnx.md
Original file line number Diff line number Diff line change
Expand Up @@ -1114,13 +1114,13 @@ Effects: `MemoryEffects::Effect{}`

| Operand | Description |
| :-----: | ----------- |
| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values
| `input` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values

#### Results:

| Result | Description |
| :----: | ----------- |
| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values
| `output` | tensor of 16-bit float values or tensor of 32-bit float values or tensor of 64-bit float values or tensor of 8-bit signless integer values or tensor of 16-bit signless integer values or tensor of 32-bit signless integer values or tensor of 64-bit signless integer values or tensor of 8-bit unsigned integer values or tensor of 16-bit unsigned integer values or tensor of 32-bit unsigned integer values or tensor of 64-bit unsigned integer values or tensor of 1-bit signless integer values or tensor of string type values or tensor of bfloat16 type values or tensor of f8E4M3FN type values or tensor of f8E4M3FNUZ type values or tensor of f8E5M2 type values or tensor of f8E5M2FNUZ type values or tensor of 4-bit unsigned integer values or tensor of 4-bit signless integer values

### `onnx.CategoryMapper` (ONNXCategoryMapperOp)

Expand Down
2 changes: 1 addition & 1 deletion src/Builder/OpBuildTable.inc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ op_dialect_version_map_["BitwiseNot"] = {18};
op_dialect_version_map_["BitwiseOr"] = {18};
op_dialect_version_map_["BitwiseXor"] = {18};
op_dialect_version_map_["BlackmanWindow"] = {17};
op_dialect_version_map_["Cast"] = {19};
op_dialect_version_map_["Cast"] = {21};
op_dialect_version_map_["CastLike"] = {19};
op_dialect_version_map_["CastMap"] = {1};
op_dialect_version_map_["CategoryMapper"] = {1};
Expand Down
4 changes: 2 additions & 2 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -863,10 +863,10 @@ def ONNXCastOp:ONNX_Op<"Cast",
| [x] < -FLT_MAX | NaN | NaN | -Inf | NaN |
| else | RNE | RNE | RNE | RNE |
}];
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$input,
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$input,
DefaultValuedAttr<SI64Attr, "1">:$saturate,
TypeAttr:$to);
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$output);
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, TensorOf<[UI<4>]>, TensorOf<[I<4>]>]>:$output);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 1;
Expand Down
10 changes: 8 additions & 2 deletions src/Dialect/ONNX/ONNXOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,11 +643,13 @@ Type convertONNXTypeToMLIRType(
return builder.getI1Type();
case onnx::TensorProto_DataType::TensorProto_DataType_STRING:
return ONNXStringType::get(builder.getContext());
case onnx::TensorProto_DataType::TensorProto_DataType_INT4:
return builder.getIntegerType(/*width=*/4);
case onnx::TensorProto_DataType::TensorProto_DataType_UINT4:
return builder.getIntegerType(/*width=*/4, false);

case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX64:
case onnx::TensorProto_DataType::TensorProto_DataType_COMPLEX128:
case onnx::TensorProto_DataType::TensorProto_DataType_INT4:
case onnx::TensorProto_DataType::TensorProto_DataType_UINT4:
case onnx::TensorProto_DataType::TensorProto_DataType_UNDEFINED:
llvm_unreachable("Unsupported data type encountered.");
return nullptr;
Expand Down Expand Up @@ -697,6 +699,10 @@ int64_t mlirTypeToOnnxType(Type elemType) {
? onnx::TensorProto::UNDEFINED
: onnx::TensorProto::BOOL;
break;
case 4:
onnxType = type.isUnsigned() ? onnx::TensorProto::UINT4
: onnx::TensorProto::INT4;
break;
case 8:
onnxType = type.isUnsigned() ? onnx::TensorProto::UINT8
: onnx::TensorProto::INT8;
Expand Down
19 changes: 19 additions & 0 deletions test/mlir/onnx/parse/cast_to_int_4_and_back.onnxtext
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s
//
// Round-trip import test for 4-bit integer tensors through onnx.Cast:
//   int4  -> int8  -> int4
//   uint4 -> uint8 -> uint4
// The numeric `to` attribute of each Cast is an ONNX data type code. As the
// expected IR at the bottom of this file shows, the importer renders
// 3 as i8, 22 as i4, 2 as ui8 and 21 as ui4.
<
ir_version: 10,
opset_import: ["" : 22]
>
test_int4_casting (int4[1] input, uint4[1] input2) => (int4[1] int4_cast_output, uint4[1] uint4_cast_output) {
int8_cast_output = Cast <to: int = 3> (input)
int4_cast_output = Cast <to: int = 22> (int8_cast_output)
uint8_cast_output = Cast <to: int = 2> (input2)
uint4_cast_output = Cast <to: int = 21> (uint8_cast_output)
}
// Expected IR: the graph inputs/outputs keep their 4-bit element types
// (tensor<1xi4> / tensor<1xui4>), every Cast carries saturate = 1 even though
// the graph above never sets it, and the two 4-bit results are returned in
// declaration order.
// CHECK-LABEL: func.func @main_graph
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1xi4> {onnx.name = "input"}, [[PARAM_1_:%.+]]: tensor<1xui4> {onnx.name = "input2"}) -> (tensor<1xi4> {onnx.name = "int4_cast_output"}, tensor<1xui4> {onnx.name = "uint4_cast_output"}) {
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.Cast"([[PARAM_0_]]) {saturate = 1 : si64, to = i8} : (tensor<1xi4>) -> tensor<1xi8>
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Cast"([[VAR_0_]]) {saturate = 1 : si64, to = i4} : (tensor<1xi8>) -> tensor<1xi4>
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[PARAM_1_]]) {saturate = 1 : si64, to = ui8} : (tensor<1xui4>) -> tensor<1xui8>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Cast"([[VAR_2_]]) {saturate = 1 : si64, to = ui4} : (tensor<1xui8>) -> tensor<1xui4>
// CHECK: onnx.Return [[VAR_1_]], [[VAR_3_]] : tensor<1xi4>, tensor<1xui4>
// CHECK: }
12 changes: 11 additions & 1 deletion utils/gen_onnx_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
"BitwiseOr": [18],
"BitwiseXor": [18],
"BlackmanWindow": [17],
"Cast": [19],
"Cast": [21],
"CastLike": [19],
"CastMap": [1],
"CategoryMapper": [1],
Expand Down Expand Up @@ -611,6 +611,10 @@
# FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
# FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
#
# // 4-bit integer data types
# UINT4 = 21; // Unsigned integer in range [0, 15]
# INT4 = 22; // Signed integer in range [-8, 7], using two's-complement representation
#
# // Future extensions go here.
# }
onnx_types = (
Expand All @@ -635,6 +639,8 @@
"float8e4m3fnuz",
"float8e5m2",
"float8e5m2fnuz",
"uint4",
"int4",
)
tblgen_types = (
"BF16",
Expand All @@ -658,6 +664,8 @@
"F8E4M3FNUZ",
"F8E5M2",
"F8E5M2FNUZ",
"AnyUI4",
"AnyI4",
)

# Maximum count for actual type. Number more than MAX_NUM_TYPES will be used to encode
Expand Down Expand Up @@ -1048,10 +1056,12 @@ def parse_type_str(allowedType):
"seq": "SeqOf",
"map": "TupleOf",
"bool": "I1",
"uint4": "UI<4>",
"uint8": "UI8",
"uint16": "UI16",
"uint32": "UI32",
"uint64": "UI64",
"int4": "I<4>",
"int8": "I8",
"int16": "I16",
"int32": "I32",
Expand Down