Decompose AtenNormalFunctionalOp into AtenRandn* and other arithmetic.
godot73 authored Jan 16, 2024
1 parent f85e5c9 commit a8538e1
Showing 6 changed files with 85 additions and 2 deletions.
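The decomposition implements the standard reparameterization of a Gaussian: a draw from N(mean, std^2) equals randn() * std + mean. A minimal sketch of that identity in eager PyTorch (illustration only; the input tensor merely supplies the shape and dtype):

import torch

x = torch.empty(2048, 4096, dtype=torch.float64)
# Decomposed form of aten.normal_functional(x, mean=-5.0, std=2.0):
sample = torch.randn_like(x) * 2.0 + (-5.0)
print(sample.mean())  # close to -5.0
print(sample.std())   # close to 2.0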
17 changes: 17 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7655,6 +7655,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg0 : !torch.float to !torch.union<float, int>\n"
" %1 = torch.derefine %arg1 : !torch.float to !torch.union<float, int>\n"
@@ -11557,6 +11560,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list<int>, %arg1: !torch.any, %arg2: !torch.optional<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<Device>, %arg5: !torch.optional<bool>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int6 = torch.constant.int 6\n"
34 changes: 32 additions & 2 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3669,9 +3669,38 @@ class DecomposeAtenExponentialOp : public OpRewritePattern<AtenExponentialOp> {
return success();
}
};
} // namespace

namespace {
// aten.normal_functional(mean, sigma) = randn() * sigma + mean.
class DecomposeAtenNormalFunctionalOp
: public OpRewritePattern<AtenNormalFunctionalOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNormalFunctionalOp op,
PatternRewriter &rewriter) const override {
if (!op.getGenerator().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to be None because only global default "
"generator is supported");

Location loc = op.getLoc();
Type resultType = op.getType();
Value std = op.getStd();
Value mean = op.getMean();

Value none = rewriter.create<ConstantNoneOp>(loc);
Value one =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value randN = rewriter.create<AtenRandnLikeOp>(
loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
Value stdRandN =
rewriter.create<AtenMulScalarOp>(loc, resultType, randN, std);
rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN,
mean, /*alpha=*/one);
return success();
}
};
} // namespace

namespace {
template <typename OpTy, typename T1T2Op>
class DecomposeAtenAddCLikeOp : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
@@ -6591,6 +6620,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnGeneratorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNormalFunctionalOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
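For reference, a Python-level restatement of the ops the pattern emits, one line per op (a sketch, not the pass itself; the pattern also bails out unless the generator operand is None):

import torch

def decomposed_normal_functional(self, mean=0.0, std=1.0):
    rand_n = torch.randn_like(self)  # AtenRandnLikeOp, all optional operands None
    std_rand_n = rand_n * std        # AtenMulScalarOp
    return std_rand_n + mean         # AtenAddScalarOp with alpha = 1.0

y = decomposed_normal_functional(torch.empty(4, 4), mean=-5.0, std=2.0)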
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -494,6 +494,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenRandnOp>();
target.addIllegalOp<AtenRandnGeneratorOp>();
target.addIllegalOp<AtenRandnLikeOp>();
target.addIllegalOp<AtenNormalFunctionalOp>();
target.addIllegalOp<AtenVarMeanOp>();
target.addIllegalOp<AtenCosineSimilarityOp>();
target.addIllegalOp<AtenNewEmptyStridedOp>();
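Marking AtenNormalFunctionalOp illegal forces the decomposition to run before the backend contract is verified. A hypothetical smoke test against the pt1-era torch_mlir.compile API (the module and arguments here are illustrative, not part of the patch):

import torch
import torch_mlir

class NormalModule(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)

# Lowering succeeds only because aten.normal_functional is decomposed away
# before the backend-contract check.
compiled = torch_mlir.compile(NormalModule(), torch.ones(4, 4),
                              output_type="linalg-on-tensors")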
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1484,6 +1484,7 @@
"VarMeanUnbiasedModule_basic",
"RandnLikeModule_basic",
"RandnLikeDtypeModule_basic",
"NormalFunctionalModule_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliPModule_basic",
13 changes: 13 additions & 0 deletions projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -902,6 +902,9 @@ def aten〇randn〇generator〡shape(size: List[int], generator: Any, dtype: Optional[int] = None, layout: O
def aten〇randn〇generator〡shape(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return size

def aten〇normal_functional〡shape(self: List[int], mean: float = 0., std: float = 1., generator: Any = None) -> List[int]:
return self

def aten〇arange〇start_step〡shape(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)

@@ -3822,6 +3825,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O
assert not is_integer_dtype(dtype)
return dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(
num_of_tensors=1,
error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}))
def aten〇normal_functional〡dtype(self_rank_dtype: Tuple[int, int], mean: float = 0., std: float = 1., generator: Any = None) -> int:
self_rank, self_dtype = self_rank_dtype
if self_dtype is None:
return torch.float32
assert not is_integer_dtype(self_dtype)
return self_dtype

@check_dtype_function([Invocation(size=[1], generator=None),
Invocation(size=[1], generator=None, dtype=torch.float32),
ErrorInvocation(size=[1], generator=None, dtype=torch.int32),
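The shape function simply propagates the input shape, while the dtype function forwards the input dtype and rejects integer inputs, since sampling a normal distribution into an integer tensor is not meaningful. A hypothetical direct check of that contract, assuming the two functions above can be imported from abstract_interp_lib_gen (the framework passes dtypes as torch.dtype values):

import torch

assert aten〇normal_functional〡shape([2, 3]) == [2, 3]
assert aten〇normal_functional〡dtype((2, torch.float64)) == torch.float64
try:
    aten〇normal_functional〡dtype((2, torch.int32))  # integer input dtype
except AssertionError:
    pass  # rejected, as intended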
21 changes: 21 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py
@@ -605,3 +605,24 @@ def forward(self, x):
@register_test_case(module_factory=lambda: RandnLikeDtypeModule())
def RandnLikeDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(256, 1024).double())
# ==============================================================================

class NormalFunctionalModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([-1, -1], torch.float64, True),
])
def forward(self, x):
a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0)
mean = torch.mean(a)
std = torch.std(a)
return mean, std


@register_test_case(module_factory=lambda: NormalFunctionalModule())
def NormalFunctionalModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2048, 4096).double())
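The test returns summary statistics rather than the raw samples because the eager reference and the compiled module draw from different RNG streams, so only the moments are comparable. A back-of-the-envelope check that a 2048 x 4096 sample is large enough for that comparison (the tolerance reasoning is an assumption, not taken from the harness):

import math

n = 2048 * 4096
std_err_of_mean = 2.0 / math.sqrt(n)  # standard error of the sample mean of N(-5, 2^2)
print(std_err_of_mean)                # ~7e-4, well below typical e2e tolerances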
