@@ -7304,6 +7304,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
7304
7304
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
7305
7305
" return %0 : !torch.list<int>\n"
7306
7306
" }\n"
7307
+ " func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple<list<int>, list<int>> {\n"
7308
+ " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
7309
+ " %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
7310
+ " %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
7311
+ " return %2 : !torch.tuple<list<int>, list<int>>\n"
7312
+ " }\n"
7307
7313
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
7308
7314
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
7309
7315
" return %0 : !torch.list<int>\n"
@@ -12599,17 +12605,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
12599
12605
" return %0#1 : !torch.int\n"
12600
12606
" }\n"
12601
12607
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n"
12608
+ " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12609
+ " return %0#1 : !torch.int\n"
12610
+ " }\n"
12611
+ " func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
12602
12612
" %none = torch.constant.none\n"
12603
12613
" %str = torch.constant.str \"AssertionError: \"\n"
12604
- " %true = torch.constant.bool true\n"
12605
12614
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12606
- " %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
12607
- " %2 = torch.prim.If %1 -> (!torch.bool) {\n"
12608
- " torch.prim.If.yield %true : !torch.bool\n"
12609
- " } else {\n"
12610
- " %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
12611
- " torch.prim.If.yield %3 : !torch.bool\n"
12612
- " }\n"
12615
+ " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12616
+ " %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
12613
12617
" torch.prim.If %2 -> () {\n"
12614
12618
" torch.prim.If.yield\n"
12615
12619
" } else {\n"
@@ -12618,46 +12622,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
12618
12622
" }\n"
12619
12623
" return %0#1 : !torch.int\n"
12620
12624
" }\n"
12621
- " func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise \"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
12625
+ " func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_functional \"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple< int, int> {\n"
12622
12626
" %none = torch.constant.none\n"
12623
12627
" %str = torch.constant.str \"AssertionError: \"\n"
12624
- " %true = torch.constant.bool true\n"
12625
12628
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12626
12629
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
12627
- " %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
12628
- " %3 = torch.prim.If %2 -> (!torch.bool) {\n"
12629
- " torch.prim.If.yield %true : !torch.bool\n"
12630
- " } else {\n"
12631
- " %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
12632
- " torch.prim.If.yield %7 : !torch.bool\n"
12633
- " }\n"
12634
- " torch.prim.If %3 -> () {\n"
12635
- " torch.prim.If.yield\n"
12636
- " } else {\n"
12637
- " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
12638
- " torch.prim.If.yield\n"
12639
- " }\n"
12640
- " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
12641
- " %5 = torch.prim.If %4 -> (!torch.bool) {\n"
12642
- " torch.prim.If.yield %true : !torch.bool\n"
12643
- " } else {\n"
12644
- " %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
12645
- " torch.prim.If.yield %7 : !torch.bool\n"
12646
- " }\n"
12647
- " torch.prim.If %5 -> () {\n"
12648
- " torch.prim.If.yield\n"
12649
- " } else {\n"
12650
- " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
12651
- " torch.prim.If.yield\n"
12652
- " }\n"
12653
- " %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
12654
- " torch.prim.If %6 -> () {\n"
12630
+ " %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
12631
+ " torch.prim.If %2 -> () {\n"
12655
12632
" torch.prim.If.yield\n"
12656
12633
" } else {\n"
12657
12634
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
12658
12635
" torch.prim.If.yield\n"
12659
12636
" }\n"
12660
- " return %0#1 : !torch.int\n"
12637
+ " %3 = torch.prim.TupleConstruct %0#1, %1#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
12638
+ " return %3 : !torch.tuple<int, int>\n"
12661
12639
" }\n"
12662
12640
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
12663
12641
" %none = torch.constant.none\n"
0 commit comments