[bugfix] fix scatter op accuracy (#1295)
fix scatter op accuracy
Yancey1989 authored May 20, 2024
1 parent 59c9279 commit 52cb669
Showing 3 changed files with 20 additions and 56 deletions.
53 changes: 12 additions & 41 deletions pytorch_blade/tests/disc/ops/test_scatter.py
@@ -25,62 +25,33 @@ def scatter_func(destination, place_at, source):
a = torch.scatter(destination, 0, place_at, source)
return a

-        destination = torch.tensor(
-            [
-                [4.0, 0.0, 3.0, 1.0, 0.0],
-                [0.0, 5.0, 8.0, 0.0, 0.0],
-                [6.0, 0.0, 0.0, 9.0, 0.0]
-            ], dtype=torch.float32, device=self.device)
+        destination = torch.rand(3, 5, dtype=torch.float32, device=self.device)

-        source = torch.tensor(
-            [
-                [0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
-                [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]
-            ], dtype=torch.float32, device=self.device)
-
-        place_at = torch.tensor(
-            [
-                [0, 1, 2, 0],
-                [2, 0, 0, 1]
-            ], dtype=torch.int64, device=self.device)
+        source = torch.rand(4, 5, dtype=torch.float32, device=self.device)
+        indices = torch.randint(0, 3, (2, 4), dtype=torch.int64, device=self.device)

        annotations = [(list(destination.shape), torch.float32), (list(
-            place_at.shape), torch.int64), (list(source.shape), torch.float32)]
+            indices.shape), torch.int64), (list(source.shape), torch.float32)]
        self._test_disc(scatter_func, annotations,
-                        (destination, place_at, source))
+                        (destination, indices, source))

def test_scatteradd(self):
if self.device != torch.device('cuda'):
return

@torch.jit.script
def scatter_func(destination, place_at, source):
a = torch.scatter_add(destination, 0, place_at, source)
return a

-        destination = torch.tensor(
-            [
-                [4.0, 0.0, 3.0, 1.0, 0.0],
-                [0.0, 5.0, 8.0, 0.0, 0.0],
-                [6.0, 0.0, 0.0, 9.0, 0.0]
-            ], dtype=torch.float32, device=self.device)
-
-        source = torch.tensor(
-            [
-                [0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
-                [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]
-            ], dtype=torch.float32, device=self.device)
-
-        place_at = torch.tensor(
-            [
-                [0, 1, 2, 0],
-                [2, 0, 0, 1]
-            ], dtype=torch.int64, device=self.device)
+        destination = torch.rand(3, 5, dtype=torch.float32, device=self.device)
+        source = torch.rand(2, 5, dtype=torch.float32, device=self.device)
+        indices = torch.randint(0, 3, (2, 4), dtype=torch.int64, device=self.device)

        annotations = [(list(destination.shape), torch.float32), (list(
-            place_at.shape), torch.int64), (list(source.shape), torch.float32)]
+            indices.shape), torch.int64), (list(source.shape), torch.float32)]
        self._test_disc(scatter_func, annotations,
-                        (destination, place_at, source))
+                        (destination, indices, source))


if __name__ == "__main__":
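The updated tests swap the hard-coded tensors for random inputs (torch.rand / torch.randint), so the comparison in _test_disc now exercises random index patterns, including colliding indices, rather than one fixed pattern. For reference, a minimal standalone sketch of the eager semantics the lowering must reproduce is shown below; the tensor values, shapes, and variable names are illustrative only and are not part of the test suite.

import torch

# Two updates in column 0 target the same destination row (index 0 appears twice),
# which is the colliding-index case the accuracy fix is about.
destination = torch.zeros(3, 5)
source = torch.ones(2, 5)
indices = torch.tensor([[0, 1, 2, 0, 1],
                        [0, 2, 1, 2, 0]], dtype=torch.int64)

# torch.scatter overwrites: one of the colliding writes simply wins.
overwritten = torch.scatter(destination, 0, indices, source)

# torch.scatter_add accumulates: destination[0, 0] receives both updates.
accumulated = torch.scatter_add(destination, 0, indices, source)

print(overwritten[0, 0].item())  # 1.0
print(accumulated[0, 0].item())  # 2.0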
20 changes: 6 additions & 14 deletions tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc
@@ -152,7 +152,7 @@ Operation* getReduceOperator(Region& body) {
int64_t num_calc_ops = 0;
body.walk([&](Operation* op) {
if (isa<memref::AllocOp>(op) || isa<lmhlo::CopyOp>(op) ||
-        isa<lmhlo::TerminatorOp>(op)) {
+        isa<lmhlo::TerminatorOp>(op) || isa<memref::DeallocOp>(op)) {
return;
}
num_calc_ops++;
@@ -683,6 +683,7 @@ Value LowerInplaceScatterOp(OpBuilder* b, Location loc, lmhlo::ScatterOp op,
if_inbound_op.getThenRegion().front().clear();
if_inbound_op.getElseRegion().front().clear();
b->setInsertionPointToEnd(&if_inbound_op.getThenRegion().front());

if (calc_op == nullptr) {
b->create<memref::StoreOp>(loc, update_value, result_memref, result_index);
} else {
@@ -699,18 +700,10 @@ Value LowerInplaceScatterOp(OpBuilder* b, Location loc, lmhlo::ScatterOp op,
SmallVector<Value, 4> operand_values;
operand_values.push_back(original_value);
operand_values.push_back(update_value);
-    auto num_operands = calc_op->getNumOperands();
-    auto result_type = calc_op->getOperand(num_operands - 1).getType();
-    auto result_elem_type = result_type.cast<MemRefType>().getElementType();
-    if (isa<lmhlo::AddOp>(calc_op)) {
-      updated_value = LhloOpToStdScalarOp::map<lmhlo::AddOp>(
-          llvm::cast<lmhlo::AddOp>(calc_op), result_elem_type, operand_values,
-          b);
-    } else {
-      assert(false && "unexpected update computation in scatter op");
-    }
-
-    b->create<memref::StoreOp>(loc, updated_value, result_memref, result_index);
+    // atomic add to original_value
+    b->create<memref::AtomicRMWOp>(loc, result_types[0],
+                                   getAtomicRMWKind(op.getUpdateComputation()),
+                                   update_value, result_memref, result_index);
}
b->create<scf::YieldOp>(loc, update_value);
b->setInsertionPointToEnd(&if_inbound_op.getElseRegion().front());
@@ -719,7 +712,6 @@ Value LowerInplaceScatterOp(OpBuilder* b, Location loc, lmhlo::ScatterOp op,
b->setInsertionPointAfter(if_inbound_op);

Value result = *(if_inbound_op.getResults().begin());

return result;
}

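The change in LowerInplaceScatterOp is the core of the accuracy fix: the old code loaded the destination element, computed original + update through LhloOpToStdScalarOp, and stored the sum back, a non-atomic read-modify-write that can drop updates when several parallel lanes scatter into the same element. The new code emits a single memref::AtomicRMWOp whose kind is derived from the scatter's update computation, so colliding updates accumulate correctly. The plain-Python emulation below illustrates the difference under the assumption that all colliding lanes read the destination before any of them writes back; it is a sketch of the failure mode, not the MLIR lowering itself.

# Emulation of the two lowering strategies for scatter_add with a colliding index.
destination = [0.0, 0.0, 0.0]
updates = [(0, 1.0), (0, 2.0), (2, 5.0)]  # (row index, value); index 0 collides

# Old behaviour: every lane reads the original value, adds its update, stores.
snapshot = list(destination)   # models lanes that all read before anyone writes
racy = list(destination)
for idx, val in updates:
    racy[idx] = snapshot[idx] + val      # colliding writes overwrite each other
print(racy)    # [2.0, 0.0, 5.0] -- the 1.0 update is lost

# New behaviour: memref.atomic_rmw makes each update one read-modify-write unit.
atomic = list(destination)
for idx, val in updates:
    atomic[idx] += val                   # every colliding update is accumulated
print(atomic)  # [3.0, 0.0, 5.0]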
@@ -169,7 +169,8 @@ LogicalResult miscLowerHelper(OpBuilder& b, Location loc, Operation* opaque_op,
memref = cast<lmhlo::LmhloOp>(&*op).getOperation()->getOperand(1);
}

-  // for inplace scatter op, output_index according to operand(3)
+  // for inplace scatter op, output_index according to update_index, the
+  // operand(2) of lmhlo::ScatterOp
if (isa<lmhlo::ScatterOp>(op) && isInplaceOperator(op)) {
memref = cast<lmhlo::ScatterOp>(&*op).getOperation()->getOperand(2);
}
