Skip to content

Commit

Permalink
transformations: (csl-stencil-bufferize) Inject iter_arg into linalg …
Browse files Browse the repository at this point in the history
…compute (#3033)

Co-authored-by: n-io <[email protected]>
  • Loading branch information
n-io and n-io authored Aug 14, 2024
1 parent a12143e commit 3dbbb80
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 29 deletions.
42 changes: 21 additions & 21 deletions tests/filecheck/transforms/csl_stencil_bufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ builtin.module {
%5 = csl_stencil.access %1[-1, 0] : tensor<4x255xf32>
%6 = csl_stencil.access %1[0, 1] : tensor<4x255xf32>
%7 = csl_stencil.access %1[0, -1] : tensor<4x255xf32>
%8 = arith.addf %7, %6 : tensor<255xf32>
%9 = arith.addf %8, %5 : tensor<255xf32>
%10 = arith.addf %9, %4 : tensor<255xf32>
%8 = linalg.add ins(%7, %6 : tensor<255xf32>, tensor<255xf32>) outs(%7 : tensor<255xf32>) -> tensor<255xf32>
%9 = linalg.add ins(%8, %5 : tensor<255xf32>, tensor<255xf32>) outs(%8 : tensor<255xf32>) -> tensor<255xf32>
%10 = linalg.add ins(%9, %4 : tensor<255xf32>, tensor<255xf32>) outs(%9 : tensor<255xf32>) -> tensor<255xf32>
%11 = "tensor.insert_slice"(%10, %3, %2) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
csl_stencil.yield %11 : tensor<510xf32>
}, {
Expand All @@ -20,9 +20,9 @@ builtin.module {
%15 = arith.constant dense<1.666600e-01> : tensor<510xf32>
%16 = "tensor.extract_slice"(%14) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%17 = "tensor.extract_slice"(%14) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%18 = arith.addf %13, %17 : tensor<510xf32>
%19 = arith.addf %18, %16 : tensor<510xf32>
%20 = arith.mulf %19, %15 : tensor<510xf32>
%18 = linalg.add ins(%13, %17 : tensor<510xf32>, tensor<510xf32>) outs(%13 : tensor<510xf32>) -> tensor<510xf32>
%19 = linalg.add ins(%18, %16 : tensor<510xf32>, tensor<510xf32>) outs(%18 : tensor<510xf32>) -> tensor<510xf32>
%20 = linalg.mul ins(%19, %15 : tensor<510xf32>, tensor<510xf32>) outs(%19 : tensor<510xf32>) -> tensor<510xf32>
csl_stencil.yield %20 : tensor<510xf32>
}) to <[0, 0], [1, 1]>
func.return
Expand All @@ -37,18 +37,18 @@ builtin.module {
// CHECK-NEXT: csl_stencil.apply(%a : memref<512xf32>, %1 : memref<510xf32>) outs (%b : memref<512xf32>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 1>}> ({
// CHECK-NEXT: ^0(%2 : memref<4x255xf32>, %3 : index, %4 : memref<510xf32>):
// CHECK-NEXT: %5 = bufferization.to_tensor %4 restrict writable : memref<510xf32>
// CHECK-NEXT: %6 = "tensor.extract_slice"(%5, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0>}> : (tensor<510xf32>, index) -> tensor<255xf32>
// CHECK-NEXT: %7 = csl_stencil.access %2[1, 0] : memref<4x255xf32>
// CHECK-NEXT: %8 = bufferization.to_tensor %7 restrict : memref<255xf32>
// CHECK-NEXT: %9 = csl_stencil.access %2[-1, 0] : memref<4x255xf32>
// CHECK-NEXT: %10 = bufferization.to_tensor %9 restrict : memref<255xf32>
// CHECK-NEXT: %11 = csl_stencil.access %2[0, 1] : memref<4x255xf32>
// CHECK-NEXT: %12 = bufferization.to_tensor %11 restrict : memref<255xf32>
// CHECK-NEXT: %13 = csl_stencil.access %2[0, -1] : memref<4x255xf32>
// CHECK-NEXT: %14 = bufferization.to_tensor %13 restrict : memref<255xf32>
// CHECK-NEXT: %15 = arith.addf %14, %12 : tensor<255xf32>
// CHECK-NEXT: %16 = arith.addf %15, %10 : tensor<255xf32>
// CHECK-NEXT: %17 = arith.addf %16, %8 : tensor<255xf32>
// CHECK-NEXT: %6 = csl_stencil.access %2[1, 0] : memref<4x255xf32>
// CHECK-NEXT: %7 = bufferization.to_tensor %6 restrict : memref<255xf32>
// CHECK-NEXT: %8 = csl_stencil.access %2[-1, 0] : memref<4x255xf32>
// CHECK-NEXT: %9 = bufferization.to_tensor %8 restrict : memref<255xf32>
// CHECK-NEXT: %10 = csl_stencil.access %2[0, 1] : memref<4x255xf32>
// CHECK-NEXT: %11 = bufferization.to_tensor %10 restrict : memref<255xf32>
// CHECK-NEXT: %12 = csl_stencil.access %2[0, -1] : memref<4x255xf32>
// CHECK-NEXT: %13 = bufferization.to_tensor %12 restrict : memref<255xf32>
// CHECK-NEXT: %14 = "tensor.extract_slice"(%5, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0>}> : (tensor<510xf32>, index) -> tensor<255xf32>
// CHECK-NEXT: %15 = linalg.add ins(%13, %11 : tensor<255xf32>, tensor<255xf32>) outs(%14 : tensor<255xf32>) -> tensor<255xf32>
// CHECK-NEXT: %16 = linalg.add ins(%15, %9 : tensor<255xf32>, tensor<255xf32>) outs(%15 : tensor<255xf32>) -> tensor<255xf32>
// CHECK-NEXT: %17 = linalg.add ins(%16, %7 : tensor<255xf32>, tensor<255xf32>) outs(%16 : tensor<255xf32>) -> tensor<255xf32>
// CHECK-NEXT: %18 = "tensor.insert_slice"(%17, %5, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: %19 = bufferization.to_memref %18 : memref<510xf32>
// CHECK-NEXT: csl_stencil.yield %19 : memref<510xf32>
Expand All @@ -61,9 +61,9 @@ builtin.module {
// CHECK-NEXT: %26 = bufferization.to_tensor %25 restrict : memref<510xf32>
// CHECK-NEXT: %27 = "tensor.extract_slice"(%24) <{"static_offsets" = array<i64: 2>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %28 = "tensor.extract_slice"(%24) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %29 = arith.addf %22, %28 : tensor<510xf32>
// CHECK-NEXT: %30 = arith.addf %29, %27 : tensor<510xf32>
// CHECK-NEXT: %31 = arith.mulf %30, %26 : tensor<510xf32>
// CHECK-NEXT: %29 = linalg.add ins(%22, %28 : tensor<510xf32>, tensor<510xf32>) outs(%22 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %30 = linalg.add ins(%29, %27 : tensor<510xf32>, tensor<510xf32>) outs(%29 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %31 = linalg.mul ins(%30, %26 : tensor<510xf32>, tensor<510xf32>) outs(%30 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %32 = bufferization.to_memref %31 : memref<510xf32>
// CHECK-NEXT: csl_stencil.yield %32 : memref<510xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
Expand Down
70 changes: 62 additions & 8 deletions xdsl/transforms/csl_stencil_bufferize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass

from xdsl.context import MLContext
from xdsl.dialects import arith, bufferization, func, memref, stencil, tensor
from xdsl.dialects import arith, bufferization, func, linalg, memref, stencil, tensor
from xdsl.dialects.builtin import (
DenseArrayBase,
DenseIntOrFPElementsAttr,
Expand Down Expand Up @@ -114,12 +114,6 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
t := to_tensor_op(arg, writable=idx == 2),
InsertPoint.at_end(buf_apply_op.chunk_reduce.block),
)
if idx == 2:
offset_arg = buf_apply_op.chunk_reduce.block.args[1]
rewriter.insert_op(
self._build_extract_slice(op, t, offset_arg),
InsertPoint.at_end(buf_apply_op.chunk_reduce.block),
)
chunk_reduce_arg_mapping.append(t.tensor)
else:
chunk_reduce_arg_mapping.append(arg)
Expand All @@ -138,6 +132,9 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
else:
post_process_arg_mapping.append(arg)

assert isa(typ := op.chunk_reduce.block.args[0].type, TensorType[Attribute])
chunk_type = TensorType(typ.get_element_type(), typ.get_shape()[1:])

# inline blocks from old into new regions
rewriter.inline_block(
op.chunk_reduce.block,
Expand All @@ -151,6 +148,10 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
post_process_arg_mapping,
)

self._inject_iter_arg_into_linalg_outs(
buf_apply_op, rewriter, chunk_type, chunk_reduce_arg_mapping[2]
)

# insert new op
rewriter.replace_matched_op(new_ops=[*to_memrefs, buf_apply_op])

Expand All @@ -170,6 +171,58 @@ def _get_empty_bufferized_region(args: Sequence[BlockArgument]) -> Region:
)
)

@staticmethod
def _inject_iter_arg_into_linalg_outs(
    op: csl_stencil.ApplyOp,
    rewriter: PatternRewriter,
    chunk_type: TensorType[Attribute],
    iter_arg: SSAValue,
):
    """
    Redirects the `outs` of one linalg op in `chunk_reduce` to a slice of `iter_arg`.

    Walks the ops of `op.chunk_reduce` in block order and picks the first
    named linalg op whose first output has type `chunk_type`. That op is
    replaced by a `tensor.extract_slice` of `iter_arg` (with the region's
    second block argument passed as the dynamic offset operand), followed by
    a rebuilt op of the same linalg kind that keeps the original inputs,
    result types, properties, attributes and regions, but uses the extracted
    slice as its `outs`. If no op matches, nothing is changed.

    This is a work-around for the way bufferization works: it makes
    bufferization use `iter_arg` as an accumulator and avoids an extra
    alloc + memref.copy.
    """
    chunk_block = op.chunk_reduce.block

    # First linalg op (if any) that writes a chunk-shaped tensor.
    target = next(
        (
            candidate
            for candidate in chunk_block.ops
            if isinstance(candidate, linalg.NamedOpBase)
            and len(candidate.outputs) > 0
            and candidate.outputs.types[0] == chunk_type
        ),
        None,
    )
    if target is None:
        return

    # Slice of the accumulator; the offset is only known at runtime, so the
    # static offset is the dynamic marker and the block argument supplies it.
    slice_of_iter_arg = tensor.ExtractSliceOp(
        operands=[iter_arg, [chunk_block.args[1]], [], []],
        result_types=[chunk_type],
        properties={
            "static_offsets": DenseArrayBase.from_list(
                i64, (memref.Subview.DYNAMIC_INDEX,)
            ),
            "static_sizes": DenseArrayBase.from_list(i64, chunk_type.get_shape()),
            "static_strides": DenseArrayBase.from_list(i64, (1,)),
        },
    )

    # Same linalg op kind and inputs, but writing into the extracted slice.
    # Regions are detached from the old op so they can be re-attached here.
    rebuilt_linalg_op = type(target).build(
        operands=[target.inputs, slice_of_iter_arg.results],
        result_types=target.result_types,
        properties=target.properties,
        attributes=target.attributes,
        regions=[target.detach_region(region) for region in target.regions],
    )

    rewriter.replace_op(target, [slice_of_iter_arg, rebuilt_linalg_op])

@staticmethod
def _build_extract_slice(
op: csl_stencil.ApplyOp, to_tensor: bufferization.ToTensorOp, offset: SSAValue
Expand Down Expand Up @@ -303,7 +356,8 @@ class CslStencilBufferize(ModulePass):
"""
Bufferizes the csl_stencil dialect.
Creates a `tensor.extract_slice` op needed by `lift-arith-to-linalg` and should be run without `cse` in between.
Attempts to inject `csl_stencil.apply.chunk_reduce.iter_arg` into linalg compute ops `outs` within that region
for improved bufferization. Ideally be run after `--lift-arith-to-linalg`.
"""

name = "csl-stencil-bufferize"
Expand Down

0 comments on commit 3dbbb80

Please sign in to comment.