
Commit a6179c0

build: manually update PyTorch version (llvm#3919)
This commit sets the PyTorch and TorchVision versions to the nightly release 2024-12-16. It adds support for the `aten.rrelu_with_noise_functional` op by decomposing it, and updates the existing decomposition of `aten.rrelu_with_noise` to lower to the newly added `aten.rrelu_with_noise_functional` op. It also updates the e2e tests for `aten.rrelu_with_noise` to use the functional variant introduced in pytorch/pytorch@f85e238, which captures the noise mutation that previously caused test failures in training mode. The newly passing tests are removed from the xfail sets.

---------

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 02fa411 commit a6179c0
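
For context, a minimal sketch of the semantic difference the functional variant introduces, assuming a PyTorch nightly that ships pytorch/pytorch@f85e238: `aten.rrelu_with_noise` mutates its `noise` argument in place, while the functional op returns the updated noise as a second result, so no tensor mutation has to survive export.

import torch

# Hedged illustration only: `rrelu_with_noise_functional` is assumed to be
# exposed through torch.ops.aten in the pinned nightly. It takes the same
# arguments as aten::rrelu_with_noise but returns (output, updated_noise)
# instead of writing into `noise`.
x = torch.randn(4, 5)
noise = torch.empty_like(x)
out, noise_out = torch.ops.aten.rrelu_with_noise_functional(
    x, noise, 0.125, 1.0 / 3.0, True, None
)
assert out.shape == x.shape and noise_out.shape == noise.shape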

11 files changed: +121 -87 lines


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

+30 -3

@@ -310,9 +310,7 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
 }
 
 def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
+    AllowsTypeRefinement
   ]> {
   let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
   let arguments = (ins
@@ -17519,6 +17517,35 @@ def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backwar
   }];
 }
 
+def Torch_AtenRreluWithNoiseFunctionalOp : Torch_Op<"aten.rrelu_with_noise_functional", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$noise,
+    AnyTorchScalarType:$lower,
+    AnyTorchScalarType:$upper,
+    Torch_BoolType:$training,
+    AnyTorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result0,
+    AnyTorchOptionalTensorType:$noise_out
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenRreluWithNoiseFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 6, 2);
+    }
+    void AtenRreluWithNoiseFunctionalOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 6, 2);
+    }
+  }];
+}
+
 def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [
     AllowsTypeRefinement,
     HasValueSemantics,

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

+17 -39

@@ -7304,6 +7304,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple<list<int>, list<int>> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
+"    %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+"    return %2 : !torch.tuple<list<int>, list<int>>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -12599,17 +12605,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
-"    %true = torch.constant.bool true\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
-"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
-"      torch.prim.If.yield %true : !torch.bool\n"
-"    } else {\n"
-"      %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
-"      torch.prim.If.yield %3 : !torch.bool\n"
-"    }\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
 "    torch.prim.If %2 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
@@ -12618,46 +12622,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple<int, int> {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
-"    %true = torch.constant.bool true\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
-"    %3 = torch.prim.If %2 -> (!torch.bool) {\n"
-"      torch.prim.If.yield %true : !torch.bool\n"
-"    } else {\n"
-"      %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
-"      torch.prim.If.yield %7 : !torch.bool\n"
-"    }\n"
-"    torch.prim.If %3 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
-"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
-"    %5 = torch.prim.If %4 -> (!torch.bool) {\n"
-"      torch.prim.If.yield %true : !torch.bool\n"
-"    } else {\n"
-"      %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
-"      torch.prim.If.yield %7 : !torch.bool\n"
-"    }\n"
-"    torch.prim.If %5 -> () {\n"
-"      torch.prim.If.yield\n"
-"    } else {\n"
-"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
-"      torch.prim.If.yield\n"
-"    }\n"
-"    %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
-"    torch.prim.If %6 -> () {\n"
+"    %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %2 -> () {\n"
 "      torch.prim.If.yield\n"
 "    } else {\n"
 "      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
 "      torch.prim.If.yield\n"
 "    }\n"
-"    return %0#1 : !torch.int\n"
+"    %3 = torch.prim.TupleConstruct %0#1, %1#1 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
+"    return %3 : !torch.tuple<int, int>\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %none = torch.constant.none\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

+32 -7

@@ -3791,11 +3791,7 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
       // Create a uniform random op with low and high set to `lower` and
       // `upper`, respectively.
      Value none = rewriter.create<ConstantNoneOp>(loc);
-      Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-          loc, resType, self, constantZeroFloat, /*dtype=*/none,
-          /*layout=*/none,
-          /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
-      alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
+      alpha = rewriter.create<AtenUniformOp>(loc, resType, self,
                                              /*from=*/lower, /*to=*/upper,
                                              /*generator=*/none);
     } else {
@@ -3840,6 +3836,33 @@ class DecomposeAtenRreluWithNoiseOp
     Value lower = op.getLower();
     Value upper = op.getUpper();
     auto resType = cast<BaseTensorType>(op.getType());
+    Value cstNone = rewriter.create<ConstantNoneOp>(loc);
+    Value cstFalse =
+        rewriter.create<ConstantBoolOp>(loc, rewriter.getBoolAttr(false));
+    Value result =
+        rewriter
+            .create<AtenRreluWithNoiseFunctionalOp>(
+                loc, resType, self, noise, lower, upper, cstFalse, cstNone)
+            ->getResult(0);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAtenRreluWithNoiseFunctionalOp
+    : public OpRewritePattern<AtenRreluWithNoiseFunctionalOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenRreluWithNoiseFunctionalOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    Value noise = op.getNoise();
+    Value lower = op.getLower();
+    Value upper = op.getUpper();
+    auto resType = cast<BaseTensorType>(op.getResultTypes()[0]);
     if (!resType.hasDtype()) {
       return rewriter.notifyMatchFailure(op, "result should have dtype");
     }
@@ -3885,7 +3908,7 @@ class DecomposeAtenRreluWithNoiseOp
                                        rewriter.getI1Type());
       Value oneTensor =
           createRank0Tensor(rewriter, loc, resType, constantOneFloat);
-      Value not_positive = rewriter.create<AtenLtScalarOp>(
+      Value not_positive = rewriter.create<AtenLeScalarOp>(
          loc, boolResType, self, constantZeroFloat);
       noise = rewriter.create<AtenWhereSelfOp>(loc, resType, not_positive,
                                                alpha, oneTensor);
@@ -3897,7 +3920,7 @@ class DecomposeAtenRreluWithNoiseOp
         rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
     Value rreluOutput = rewriter.create<AtenAddTensorOp>(
         loc, resType, positiveOutput, negativeOutput, constantOneFloat);
-    rewriter.replaceOp(op, rreluOutput);
+    rewriter.replaceOp(op, {rreluOutput, noise});
     return success();
   }
 };
@@ -11568,6 +11591,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseFunctionalOp>(
+        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRreluWithNoiseBackwardOp>(
        patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
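
Roughly restated in eager PyTorch terms (a sketch of the training-mode math the new `DecomposeAtenRreluWithNoiseFunctionalOp` pattern emits, not the pattern itself; the helper name is illustrative):

import torch

# Sample a per-element negative slope in [lower, upper], record it in the
# noise tensor for non-positive inputs, and combine the positive part with
# the scaled negative part. Mirrors the AtenUniformOp / AtenLeScalarOp /
# AtenWhereSelfOp / AtenMinimumOp sequence shown in the diff above.
def rrelu_with_noise_functional_ref(x, lower=0.125, upper=1.0 / 3.0):
    alpha = torch.empty_like(x).uniform_(lower, upper)
    noise = torch.where(x <= 0, alpha, torch.ones_like(x))
    positive = torch.clamp(x, min=0)          # max(0, x)
    negative = torch.clamp(x * noise, max=0)  # min(0, x * noise)
    return positive + negative, noise         # (result0, noise_out)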

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

+1

@@ -501,6 +501,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenPreluOp>();
   target.addIllegalOp<AtenRreluOp>();
   target.addIllegalOp<AtenRreluWithNoiseOp>();
+  target.addIllegalOp<AtenRreluWithNoiseFunctionalOp>();
   target.addIllegalOp<AtenRreluWithNoiseBackwardOp>();
   target.addIllegalOp<AtenCeluOp>();
   target.addIllegalOp<AtenToDtypeLayoutOp>();

projects/pt1/e2e_testing/xfail_sets.py

-21

@@ -398,7 +398,6 @@
     "AtenIntBoolOpConstTrueModule_basic",
     "AtenIntBoolOpModule_basic",
     "AtenIntMM_basic",
-    "AtenItemFpOpModule_basic",
     "AtenNonzero1DDynamicModule_basic",  # no lowering for torch.aten.sym_constrain_range_for_size
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
@@ -425,7 +424,6 @@
     "CumsumModule_basic",
     "CumprodModule_basic",
     "DeformConv2D_basic",
-    "DivFloatModule_basic",
     "DivIntModule_basic",
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
@@ -439,7 +437,6 @@
     "IntFloatModule_basic",
     "IntImplicitModule_basic",
     "LenStrModule_basic",
-    "MulFloatModule_basic",
     "NativeGroupNormBackwardModule_basic",
     "NeFloatIntModule_basic",
     "NllLossModuleBackward1DMeanWeight_basic",
@@ -464,15 +461,11 @@
     "QuantizedSingleLayer_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
-    "ScalarImplicitFloatModule_basic",
     "SplitDimDynamicModule_basic",
     "SplitDimStaticModule_basic",
     "SqrtIntModule_basic",
-    "SubFloatModule_basic",
     "TensorToBoolZeroRank_basic",
     "TensorToBool_basic",
-    "TensorToFloatZeroRank_basic",
-    "TensorToFloat_basic",
     "ThresholdBackward2dMixedModule_basic",
     "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
     "UpSampleNearest2dDynamicFactor_basic",
@@ -507,9 +500,6 @@
     "MeshgridIndexingIJ_basic",
     "MeshgridIndexingXY_basic",
     "Meshgrid_basic",
-    # RuntimeError: cannot mutate tensors with frozen storage
-    "ElementwiseRreluWithNoiseTrainModule_basic",
-    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
     "ElementwiseSignbitModule_basic",
     "ElementwiseCopysignModule_basic",
     "BernoulliFloatModule_basic",
@@ -527,9 +517,6 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleSumdims_basic",
-    # torch export: RuntimeError: cannot mutate tensors with frozen storage
-    "ElementwiseRreluWithNoiseTrainModule_basic",
-    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_XFAIL_SET = {
@@ -934,9 +921,6 @@
     "UpSampleNearest2dStaticFactor_basic",
     "UpSampleNearest2dStaticSize_basic",
     "UpSampleNearest2d_basic",
-    # RuntimeError: cannot mutate tensors with frozen storage
-    "ElementwiseRreluWithNoiseTrainModule_basic",
-    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
     "BernoulliFloatModule_basic",
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",
@@ -961,9 +945,6 @@
     "Aten_TrilinearModuleSumdims_basic",
     "Aten_TrilinearModuleSumAllDims_basic",
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
-    # torch export: RuntimeError: cannot mutate tensors with frozen storage
-    "ElementwiseRreluWithNoiseTrainModule_basic",
-    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
     "CrossEntropyLossModule_basic",
     "CrossEntropyLossNoReductionModule_basic",
 }
@@ -3459,8 +3440,6 @@
     "ElementwiseSignbitModule_basic",
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
-    "ElementwiseRreluWithNoiseTrainModule_basic",
-    "ElementwiseRreluWithNoiseTrainStaticModule_basic",
     "MaxPool3dEmptyStrideStaticModule_basic",
     "MaxPool3dLargeDatadModule_basic",
     "MaxPool3dModuleRandomSimple_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

+12 -5

@@ -649,6 +649,9 @@ def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0
 def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇rrelu_with_noise_functional〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[List[int], List[int]]:
+    return upstream_shape_functions.unary(self), upstream_shape_functions.unary(noise)
+
 def aten〇selu〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -3472,21 +3475,25 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
-@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, *all_integer_dtypes()}))
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int:
     self_rank, self_dtype = self_rank_dtype
-    assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
     return self_dtype
 
-@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()}))
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2))
 def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int:
     self_rank, self_dtype = self_rank_dtype
     noise_rank, noise_dtype = noise_rank_dtype
-    assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
-    assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype)
     assert self_rank == noise_rank
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2))
+def aten〇rrelu_with_noise_functional〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[int, int]:
+    self_rank, self_dtype = self_rank_dtype
+    noise_rank, noise_dtype = noise_rank_dtype
+    assert self_rank == noise_rank
+    return self_dtype, noise_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
 def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
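
The edits to AbstractInterpLibrary.cpp above are the generated mirror of these Python rules. As a quick sanity check of the new shape rule in plain Python (a standalone restatement with a hypothetical helper name, not the repo code):

from typing import List, Tuple

# Each result keeps the shape of the corresponding input, which is what the
# generated tuple-of-unary shape function encodes.
def rrelu_with_noise_functional_shape(
    self_shape: List[int], noise_shape: List[int]
) -> Tuple[List[int], List[int]]:
    return list(self_shape), list(noise_shape)

assert rrelu_with_noise_functional_shape([2, 3], [2, 3]) == ([2, 3], [2, 3])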

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

+3

@@ -1212,6 +1212,9 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)"
     )
+    emit(
+        "aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)"
+    )
 
     # quantized ops
     emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)")
