Commit

fix ut
Yancey1989 committed Mar 19, 2024
1 parent 9dba20a commit 81aa949
Showing 2 changed files with 31 additions and 22 deletions.
51 changes: 30 additions & 21 deletions tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc
@@ -1090,7 +1090,7 @@ void maybeEmitInitLoops(OpBuilder& b,
  b.setInsertionPoint(root_ops.back());
  SmallVector<Operation*, 4> col_reduction_ops;
  for (Operation* root_op : root_ops) {
- if (isRank2ColReduction(root_op)) {
+ if (isRank2ColReduction(root_op) || isRank2ScalarReduction(root_op)) {
  col_reduction_ops.emplace_back(root_op);
  }
  }
@@ -1294,7 +1294,7 @@ LogicalResult lowerWithScheduleParallelReduction(
  std::back_inserter(scalar_reduction_roots),
  [](Operation* operation) { return isRank2ScalarReduction(operation); });
  auto root_op = scalar_reduction_roots.back();
- const int thread_per_block = getCTASize(dominant_op);
+ const int thread_per_block = 256;
  Location loc = dominant_op->getLoc();
  OpBuilder b(root_ops.back());

@@ -1304,10 +1304,10 @@ LogicalResult lowerWithScheduleParallelReduction(
  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
  Value two = b.create<arith::ConstantIndexOp>(loc, 2);
  auto elemFloatType = getLhloOpsElementType(root_op).cast<FloatType>();
- Value zero_f = b.create<arith::ConstantFloatOp>(
- loc, llvm::APFloat(0.0f), elemFloatType); // b.getF32Type());
- Value one_f = b.create<arith::ConstantFloatOp>(
- loc, llvm::APFloat(1.0f), elemFloatType); // b.getF32Type());
+ Value zero_f = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(getLhloOpsElementType(root_op), 0));
+ Value one_f = b.create<arith::ConstantOp>(
+ loc, b.getFloatAttr(getLhloOpsElementType(root_op), 1));

  // Start to emit.
  Value num_blocks = b.create<arith::ConstantIndexOp>(loc, 1024);
@@ -1333,8 +1333,8 @@ LogicalResult lowerWithScheduleParallelReduction(
  b.create<gpu::ThreadIdOp>(loc, b.getIndexType(), gpu::Dimension::x);
  Value grid_dim =
  b.create<gpu::GridDimOp>(loc, b.getIndexType(), gpu::Dimension::x);
- tid = b.create<arith::RemSIOp>(loc, tid, block_dim);
- // i = blockIdx.x * block_size * 2 + tid;
+ // tid = b.create<arith::RemSIOp>(loc, tid, block_dim);
+ // i = blockIdx.x * block_size * 2 + tid;
  Value i = b.create<arith::AddIOp>(
  loc,
  b.create<arith::MulIOp>(
@@ -1348,11 +1348,7 @@ LogicalResult lowerWithScheduleParallelReduction(
  Value n = b.create<memref::DimOp>(loc, lhs, zero);
  Value m = b.create<memref::DimOp>(loc, lhs, one);
  Value mn = b.create<arith::MulIOp>(loc, m, n);
- // __shared__ float shm[block_size];
- auto shared_mem = createSharedMemory(b, loc, thread_per_block,
- getLhloOpsElementType(dominant_op));
- Value acc_value;
- // fused accumulation
+
  // acc: init_values[num_col_reductions]
  SmallVector<AccumulatorFactory, 4> accum_factory(
  scalar_reduction_roots.size());
@@ -1393,14 +1389,29 @@ LogicalResult lowerWithScheduleParallelReduction(
  for (auto* root_op : root_ops) {
  if (isRank2ScalarReduction(root_op)) {
  auto lhs = root_op->getOperands().begin();
- SmallVector<Value, 2> load_index({i, zero});
+ SmallVector<Value, 2> load_index({var_j, zero});
  Value data = createLoadOrUseCachedValue(
  loc, &b, root_op, *lhs, load_index, b.saveInsertionPoint());
- SmallVector<Value, 2> load_index2(
- {b.create<arith::AddIOp>(loc, i, block_dim), zero});
+ Value index2 = b.create<arith::AddIOp>(loc, var_j, block_dim);
+ // if (i + grid_size < n)
+ scf::IfOp if_tid_valid_op = b.create<scf::IfOp>(
+ loc, /*resultTypes*/ init_values_types,
+ b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, index2, n),
+ /*hasElseRegion*/ true);
+ if_tid_valid_op.getThenRegion().front().clear();
+ if_tid_valid_op.getElseRegion().front().clear();
+ b.setInsertionPointToStart(&if_tid_valid_op.getThenRegion().front());
+ SmallVector<Value, 2> load_index2({index2, zero});
  Value data1 = createLoadOrUseCachedValue(
  loc, &b, root_op, *lhs, load_index2, b.saveInsertionPoint());
- Value sum = (accum_factory[scalar_red_root_op_idx])(data, data1);
+ b.setInsertionPointToEnd(&if_tid_valid_op.getThenRegion().front());
+ b.create<scf::YieldOp>(loc, data1);
+ b.setInsertionPointToStart(&if_tid_valid_op.getElseRegion().front());
+ b.create<scf::YieldOp>(loc, zero_f);
+ b.setInsertionPointAfter(if_tid_valid_op);
+ Value sum = (accum_factory[scalar_red_root_op_idx])(
+ data, if_tid_valid_op.getResults().front());
+
  auto acc = (accum_factory[scalar_red_root_op_idx])(
  *(for_op_k.getRegionIterArgs().begin() + scalar_red_root_op_idx),
  sum);
@@ -1424,7 +1435,7 @@ LogicalResult lowerWithScheduleParallelReduction(
  }
  b.create<scf::YieldOp>(loc, yield_values_for_if);
  b.setInsertionPointAfter(for_op_k);
-
+ b.create<gpu::BarrierOp>(loc);
  for (auto root_pair : llvm::enumerate(scalar_reduction_roots)) {
  Operation* root_op = root_pair.value();
  int idx = root_pair.index();
@@ -1433,7 +1444,6 @@ LogicalResult lowerWithScheduleParallelReduction(
  }
  }
  {
- Value var_j = nullptr;
  SmallVector<Value, 4> init_values = {};
  for (int stride = 128; stride > 16; stride /= 2) {
  b.create<gpu::BarrierOp>(loc);
@@ -1463,7 +1473,6 @@ LogicalResult lowerWithScheduleParallelReduction(
  b.setInsertionPointAfter(if_tid_valid_op);
  }
  }
- b.create<gpu::BarrierOp>(loc);
  {
  // warp reduce
  // if (tid < 32)
@@ -1507,7 +1516,7 @@ LogicalResult lowerWithScheduleParallelReduction(
  for (auto root_pair : llvm::enumerate(scalar_reduction_roots)) {
  Operation* root_op = root_pair.value();
  int idx = root_pair.index();
- Value val = b.create<memref::LoadOp>(loc, shared_mem_map[root_op], tid);
+ Value val = b.create<memref::LoadOp>(loc, shared_mem_map[root_op], zero);
  Type root_element_type = getLhloOpsElementType(root_op);
  b.create<memref::AtomicRMWOp>(
  loc, root_element_type,
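For orientation, the lowering changed above emits a two-loads-per-step grid-stride accumulation, a shared-memory tree reduction, a warp reduce, and a final atomic update. Below is a minimal CUDA sketch of that schedule, assuming a plain float sum reduction; it is not the generated IR and not part of this commit, and the kernel name and the 1-D in/out/total signature are hypothetical stand-ins for the 2-D memref operands and the accumulator factory.

// Hypothetical sketch only: the reduction shape the lowered IR above appears
// to follow (256 threads/block, 1024 blocks, two loads per step,
// shared-memory tree down to 32 lanes, warp reduce, atomic finish).
#include <cuda_runtime.h>

__global__ void scalar_reduce_sketch(const float* in, float* out, int total) {
  constexpr int kBlockSize = 256;          // thread_per_block in the pass
  __shared__ float shm[kBlockSize];        // per-block partial results

  int tid = threadIdx.x;
  // i = blockIdx.x * block_size * 2 + tid;
  int i = blockIdx.x * kBlockSize * 2 + tid;
  int stride_total = gridDim.x * kBlockSize * 2;

  float acc = 0.0f;                        // reduction init value
  for (int j = i; j < total; j += stride_total) {
    float v = in[j];
    // Guarded second load, mirroring the new scf.if around j + blockDim.x.
    if (j + kBlockSize < total) v += in[j + kBlockSize];
    acc += v;
  }
  shm[tid] = acc;
  __syncthreads();

  // Shared-memory tree reduction (strides 128, 64, 32, as in the pass).
  for (int stride = kBlockSize / 2; stride > 16; stride /= 2) {
    if (tid < stride) shm[tid] += shm[tid + stride];
    __syncthreads();
  }

  // Warp-level reduction of the remaining 32 partials.
  if (tid < 32) {
    float v = shm[tid];
    for (int offset = 16; offset > 0; offset /= 2)
      v += __shfl_down_sync(0xffffffffu, v, offset);
    // One atomic per block accumulates into the final scalar result.
    if (tid == 0) atomicAdd(out, v);
  }
}

The new scf.IfOp that yields zero_f on its else branch plays the role of the if (j + kBlockSize < total) guard here: the second load of each step is skipped (replaced by the init value) when it would read past the end of the input.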
