Added support for debufferization across nested regions - working for scf.if
arpitj1 committed Jan 31, 2025
1 parent 490f924 commit e20708c
Showing 3 changed files with 485 additions and 79 deletions.
2 changes: 2 additions & 0 deletions lib/polygeist/Ops.cpp
@@ -674,6 +674,8 @@ bool isCaptured(Value v, Operation *potentialUser = nullptr,
for (auto u : v.getUsers()) {
if (seenuse && u == potentialUser)
*seenuse = true;
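// linalg.generic reads and writes through its memref operands but does not
// stash the pointer anywhere, so its uses are not treated as captures.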
if (isa<linalg::GenericOp>(u))
continue;
if (isa<memref::LoadOp, LLVM::LoadOp, affine::AffineLoadOp,
polygeist::CacheLoad>(u))
continue;
346 changes: 329 additions & 17 deletions lib/polygeist/Passes/LinalgDebufferize.cpp
@@ -34,15 +34,214 @@ using namespace bufferization;
bool isCaptured(Value v, Operation *potentialUser = nullptr,
bool *seenuse = nullptr);

bool isAncestor(Operation *potentialAncestor, Operation *op) {
Operation *current = op->getParentOp();
while (current != nullptr) {
if (current == potentialAncestor)
return true;
current = current->getParentOp();
}
return false;
}

// Returns true if `a` comes textually before `b`, comparing across blocks and
// regions as long as the two ops share a common ancestor.
bool comesBefore(Operation *a, Operation *b) {
if (a == b) return false;

if (isAncestor(a, b)) return true;
if (isAncestor(b, a)) return false;

// Hoisted helpers: position of a block within its region, and of a region
// within its parent op.
auto getBlockPos = [](Region *region, Block *block) {
  auto &blocks = region->getBlocks();
  auto it = llvm::find_if(blocks, [block](Block &b) { return &b == block; });
  assert(it != blocks.end() && "Block not found in region");
  return std::distance(blocks.begin(), it);
};
auto getRegionPos = [](Operation *parent, Region *target) {
  auto regions = parent->getRegions();
  auto it = llvm::find_if(regions, [&](Region &r) { return &r == target; });
  return std::distance(regions.begin(), it);
};

// Walk up b's ancestor chain until we find the ancestor that shares a's
// parent op, then compare the two at that common level.
Operation *aParent = a->getParentOp();
Operation *bAncestor = b;
while (Operation *parent = bAncestor->getParentOp()) {
  if (parent == aParent) {
    Region *aRegion = a->getParentRegion();
    Region *bRegion = bAncestor->getParentRegion();
    if (aRegion == bRegion) {
      Block *aBlock = a->getBlock();
      Block *bBlock = bAncestor->getBlock();
      // Different blocks of the same region: compare block order.
      if (aBlock != bBlock)
        return getBlockPos(aRegion, aBlock) < getBlockPos(bRegion, bBlock);
      // Same block: compare operation order.
      return a->isBeforeInBlock(bAncestor);
    }
    // Different regions of the same op: compare region order.
    return getRegionPos(parent, aRegion) < getRegionPos(parent, bRegion);
  }
  bAncestor = parent;
}

// Symmetric case: walk up a's ancestor chain to b's level.
Operation *bParent = b->getParentOp();
Operation *aAncestor = a;
while (Operation *parent = aAncestor->getParentOp()) {
  if (parent == bParent) {
    Region *aRegion = aAncestor->getParentRegion();
    Region *bRegion = b->getParentRegion();
    if (aRegion == bRegion) {
      Block *aBlock = aAncestor->getBlock();
      Block *bBlock = b->getBlock();
      if (aBlock != bBlock)
        return getBlockPos(aRegion, aBlock) < getBlockPos(bRegion, bBlock);
      return aAncestor->isBeforeInBlock(b);
    }
    return getRegionPos(parent, aRegion) < getRegionPos(parent, bRegion);
  }
  aAncestor = parent;
}

llvm_unreachable("Operations do not share a common ancestor");
}
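To illustrate the ordering contract, here is a minimal hand-written example (hypothetical IR with unregistered `test.*` ops, not part of the test suite): the two users sit in sibling scf.if regions, where `Operation::isBeforeInBlock` alone is not applicable, so `comesBefore` compares their scf.if ancestors within the function body.

// comesBefore(userA, userB) should return true: the users' sibling ancestors
// are the two scf.if ops, and the first scf.if precedes the second.
func.func @order(%v: memref<8xf32>, %c: i1) {
  scf.if %c {
    "test.userA"(%v) : (memref<8xf32>) -> ()
  }
  scf.if %c {
    "test.userB"(%v) : (memref<8xf32>) -> ()
  }
  return
}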

std::vector<Operation *> getSortedUsers(Value val) {
std::vector<Operation*> users;
for (Operation *user : val.getUsers()) {
users.push_back(user);
}

// Sort the users into their textual (source) order; comesBefore handles
// users that live in different blocks or regions.
std::sort(users.begin(), users.end(),
          [](Operation *a, Operation *b) { return comesBefore(a, b); });

return users;
@@ -70,6 +269,27 @@ std::vector<Operation *> getSortedUsers(Operation *op) {
return sortedUsers;
}

Region *findCommonAncestorRegion(Operation *a, Operation *b) {
  // Collect every region that encloses a (the map is used as a set).
  DenseMap<Region *, size_t> aRegions;
  Operation *currentOp = a;
  while (Region *region = currentOp->getParentRegion()) {
    aRegions[region]++;
    currentOp = region->getParentOp();
  }

  // Walk up from b; the first enclosing region that also encloses a is the
  // nearest common ancestor.
  currentOp = b;
  while (Region *region = currentOp->getParentRegion()) {
    if (aRegions.count(region))
      return region;
    currentOp = region->getParentOp();
  }
  return nullptr;
}
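A sketch of what this computes (again hypothetical IR): for a producer nested inside an scf.if and a consumer directly in the function body, the nearest common ancestor region is the function's body region, so a tensor produced inside the scf.if must be yielded through exactly one region to reach the consumer.

// findCommonAncestorRegion(producer, consumer) returns the func body region.
func.func @common(%m: memref<8xf32>, %c: i1) {
  scf.if %c {
    "test.producer"(%m) : (memref<8xf32>) -> ()
  }
  "test.consumer"(%m) : (memref<8xf32>) -> ()
  return
}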


struct debufferizationAllocaRemoval : public OpRewritePattern<memref::AllocaOp> {
using OpRewritePattern<memref::AllocaOp>::OpRewritePattern;

@@ -109,6 +329,20 @@ struct debufferizationAllocaRemoval : public OpRewritePattern<memref::AllocaOp>
}
};

// Limitations of this implementation: it works by iterating over the users of
// an alloca/argument. Those users are not returned in any particular order, so
// getSortedUsers orders them across regions, blocks and ops, provided they
// share a common ancestor.
// As each user is rewritten, its output tensor is fed to the next user. This
// works for simple straight-line cases, but becomes more involved with nested
// regions such as scf.if and scf.for, because those ops must be updated to
// yield the correct tensors for the next user outside the region.
// The cleanest approach would be a single walk over the IR: whenever a user is
// a linalg.generic, rewrite its operands to tensors and produce an output
// tensor, then continue until the end of the enclosing region. At that point
// we must decide whether to yield the tensor, which depends on whether any
// user of the original value remains outside the region; this can be decided
// by tracking which users have already been rewritten.
// In the current scheme we instead look at the next user: if the previous
// user's output tensor is not visible in the next user's block, we propagate
// it outward through the enclosing regions via scf.yield. If a user produces
// no output tensor, the previously available tensor is propagated instead.
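As a concrete sketch of the intended rewrite, here is hand-written expected output for the `@in_place_cond_add` test added below (not CHECK lines from this commit; `to_tensor` syntax may differ by MLIR version): the scf.if is rebuilt to yield the tensor produced in its then-branch, and a materialized else-branch forwards the incoming tensor unchanged.

func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) {
  %t0 = bufferization.to_tensor %buffer : memref<128xf32>
  %t1 = scf.if %cond -> (tensor<128xf32>) {
    %r = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]
    } ins(%t0 : tensor<128xf32>) outs(%t0 : tensor<128xf32>) {
    ^bb0(%in: f32, %out: f32):
      %sum = arith.addf %in, %value : f32
      linalg.yield %sum : f32
    } -> tensor<128xf32>
    scf.yield %r : tensor<128xf32>
  } else {
    scf.yield %t0 : tensor<128xf32>
  }
  return
}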
struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
using OpRewritePattern<func::FuncOp>::OpRewritePattern;

@@ -153,7 +387,7 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
// If the value is noalias and not captured, we can reason about all of its
// users; otherwise some op may read or write it behind our back, so bail.
// TODO: isCaptured still needs to model linalg.generic users precisely.
if ((!isNoalias) || isCaptured(memVal)) {
  return failure();
}

@@ -185,6 +419,7 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(
memVal.getLoc(), tensorType, memVal);
Value currentTensor = toTensorOp;
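// prevTensor trails currentTensor by one rewrite: it is the value that an
// untaken path (e.g. a newly materialized else branch) must keep yielding.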
Value prevTensor = toTensorOp;

auto sortedUsers = getSortedUsers(memVal);

@@ -202,6 +437,82 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
SmallVector<Type> resultTypes;
// Create a new linalg.generic in destination-passing style (DPS).

// If currentTensor is not visible from this user, propagate it out to their
// common ancestor region: find that region, yield the tensor through every
// region on the way up, and make the propagated result the new currentTensor.
auto commonRegion =
    findCommonAncestorRegion(currentTensor.getDefiningOp(), user);
if (!commonRegion)
  return failure();
// Collect regions from the source up to (excluding) the common ancestor,
// innermost first.
SmallVector<Region *> regions;
for (Region *r = currentTensor.getParentRegion(); r != commonRegion;
     r = r->getParentOp()->getParentRegion())
  regions.push_back(r);

// Propagate the value through each region, innermost first: a value defined
// in a nested region is not visible outside it, so each scf.if must yield
// it before the next outer region can use it.
Value currentValue = currentTensor;
for (Region *region : regions) {
  Operation *parentOp = region->getParentOp();

  if (auto prevIf = dyn_cast_or_null<scf::IfOp>(parentOp)) {
    // New result list = the old results plus the propagated tensor.
    SmallVector<Type> newResultTypes;
    for (auto res : prevIf.getResults())
      newResultTypes.push_back(res.getType());
    newResultTypes.push_back(currentValue.getType());

    // The then-branch yields its original operands plus the new value.
    SmallVector<Value> thenYieldValues(prevIf.thenYield().getOperands().begin(),
                                       prevIf.thenYield().getOperands().end());
    thenYieldValues.push_back(currentValue);

    // The else-branch yields its original operands (if it exists) plus the
    // tensor that was live before this rewrite. Capture hadElse now, since
    // takeBody below leaves prevIf's regions empty.
    bool hadElse = !prevIf.getElseRegion().empty();
    SmallVector<Value> elseYieldValues;
    if (hadElse)
      elseYieldValues.append(prevIf.elseYield().getOperands().begin(),
                             prevIf.elseYield().getOperands().end());
    elseYieldValues.push_back(prevTensor);

    // Create the replacement scf.if and steal the old op's bodies.
    rewriter.setInsertionPoint(prevIf);
    auto newIf = rewriter.create<scf::IfOp>(prevIf.getLoc(), newResultTypes,
                                            prevIf.getCondition(),
                                            /*withElseRegion=*/true);
    if (newIf.thenBlock())
      rewriter.eraseBlock(newIf.thenBlock());
    newIf.getThenRegion().takeBody(prevIf.getThenRegion());
    if (hadElse)
      newIf.getElseRegion().takeBody(prevIf.getElseRegion());

    // Update the yield ops to return the extra tensor.
    rewriter.setInsertionPointToEnd(newIf.thenBlock());
    rewriter.replaceOpWithNewOp<scf::YieldOp>(newIf.thenYield(),
                                              thenYieldValues);
    rewriter.setInsertionPointToEnd(newIf.elseBlock());
    if (hadElse)
      rewriter.replaceOpWithNewOp<scf::YieldOp>(newIf.elseYield(),
                                                elseYieldValues);
    else
      rewriter.create<scf::YieldOp>(newIf.getLoc(), elseYieldValues);

    // Forward the old results to the new op and erase the now-empty prevIf.
    rewriter.replaceOp(prevIf,
                       newIf.getResults().take_front(prevIf.getNumResults()));

    currentValue = newIf->getResult(newIf->getNumResults() - 1);
  }
}
currentTensor = currentValue;

ArrayAttr indexingMaps = genericOp.getIndexingMaps();
for (auto input : genericOp.getInputs()) {
newInputs.push_back(input == memVal ? currentTensor : input);
@@ -220,6 +531,7 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
index++;
}

rewriter.setInsertionPointAfter(genericOp);
StringAttr empty = StringAttr::get(genericOp.getContext());
ArrayRef<Type> resultTypesRef(resultTypes);
auto newGenericOp = rewriter.create<linalg::GenericOp>(
@@ -239,14 +551,16 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
}

if (newCurrentTensorIndex != -1) {
  prevTensor = currentTensor;
  currentTensor = newGenericOp.getResult(newCurrentTensorIndex);
}

processedGenericOps.insert(genericOp.getOperation());
// Delete the original genericOp now that its replacement is in place.
genericOp.erase();
}
}

@@ -259,19 +573,17 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {


// Instead of walking and rewriting in one pass (deleting ops invalidates the
// walk), collect the alloca ops first, then process the list.
SmallVector<memref::AllocaOp> listOfAllocaOps;
funcOp.walk(
    [&](memref::AllocaOp alloca) { listOfAllocaOps.push_back(alloca); });

for (auto alloca : listOfAllocaOps) {
  if (handleMemref(alloca).succeeded())
    passResult = success();
}

if (llvm::any_of(llvm::map_range(funcOp.getArguments(), handleMemref),
                 [](LogicalResult res) { return res.succeeded(); }))
  passResult = success();
216 changes: 154 additions & 62 deletions test/polygeist-opt/debufferize.mlir
@@ -40,68 +40,160 @@
}
}

module @in_place_cond_add{
func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) {
%c0 = arith.constant 0 : index
//%buffer = memref.alloca() : memref<128xf32>
scf.if %cond {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
linalg.yield %sum : f32
}
}
return
}
}

module @in_place_add_for{
func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
//%buffer = memref.alloca() : memref<128xf32>
scf.for %i = %c0 to %c10 step %c1 {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
linalg.yield %sum : f32
}
}
return
}
}

//Case when the buffer is captured (passed as a loop-carried iter_arg), so
//debufferization must leave it alone.
module @in_place_add_for_loop_carried{
func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
//%buffer = memref.alloca() : memref<128xf32>
%result = scf.for %i = %c0 to %c10 step %c1 iter_args(%buf = %buffer) -> (memref<128xf32>) {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buf : memref<128xf32>)
outs(%buf : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
linalg.yield %sum : f32
}
scf.yield %buf : memref<128xf32>
}
return
}
}

module @in_place_cond_add_followed_by_add{
func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1) {
%c0 = arith.constant 0 : index
//%buffer = memref.alloca() : memref<128xf32>
scf.if %cond {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
%sum2 = arith.addf %sum, %value : f32
linalg.yield %sum2 : f32
}
}
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
linalg.yield %sum : f32
}
return
}
}

// module @conv_2 {
// func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
// %c0_i32 = arith.constant 0 : i32
// linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) {
// ^bb0(%in: i32, %in_0: i32, %out: i32):
// %3 = arith.muli %in, %in_0 : i32
// %4 = arith.addi %out, %3 : i32
// linalg.yield %4 : i32
// }
// return %c0_i32 : i32
// }
// }

// module @harris_score_with_gradient_extra_kernel {
// //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1>
// //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
// //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
// func.func @main(%0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
// %c4_i32 = arith.constant 4 : i32
// %c0_i32 = arith.constant 0 : i32
// %alloca = memref.alloca() : memref<512x512xi32>
// %alloca_0 = memref.alloca() : memref<512x512xi32>
// %alloca_1 = memref.alloca() : memref<512x512xi32>
// %alloca_2 = memref.alloca() : memref<516x516xi32>
// %alloca_3 = memref.alloca() : memref<516x516xi32>
// %alloca_4 = memref.alloca() : memref<518x518xi32>
// //%score = memref.alloca() : memref<512x512xi32>
// //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32>
// //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32>
// //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32>
// linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) {
// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32):
// %4 = arith.muli %in, %in_5 : i32
// %5 = arith.addi %out_7, %4 : i32
// %6 = arith.muli %in, %in_6 : i32
// %7 = arith.addi %out, %6 : i32
// linalg.yield %7, %5 : i32, i32
// }
// linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) {
// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32):
// %4 = arith.muli %in, %in : i32
// %5 = arith.muli %4, %in_6 : i32
// %6 = arith.addi %out_8, %5 : i32
// %7 = arith.muli %in_5, %in_5 : i32
// %8 = arith.muli %7, %in_6 : i32
// %9 = arith.addi %out_7, %8 : i32
// %10 = arith.muli %in, %in_5 : i32
// %11 = arith.muli %10, %in_6 : i32
// %12 = arith.addi %out, %11 : i32
// linalg.yield %12, %9, %6 : i32, i32, i32
// }
// linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) {
// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32):
// %4 = arith.muli %in, %in_5 : i32
// %5 = arith.muli %in_6, %in_6 : i32
// %6 = arith.subi %4, %5 : i32
// %7 = arith.addi %in, %in_5 : i32
// %8 = arith.muli %7, %c4_i32 : i32
// %9 = arith.muli %8, %7 : i32
// %10 = arith.subi %6, %9 : i32
// linalg.yield %10 : i32
// }
// return %c0_i32 : i32
// }
// }
