diff --git a/pytorch_blade/tests/disc/ops/test_scatter.py b/pytorch_blade/tests/disc/ops/test_scatter.py index 879b362488a..73f03dbfc0d 100644 --- a/pytorch_blade/tests/disc/ops/test_scatter.py +++ b/pytorch_blade/tests/disc/ops/test_scatter.py @@ -25,62 +25,33 @@ def scatter_func(destination, place_at, source): a = torch.scatter(destination, 0, place_at, source) return a - destination = torch.tensor( - [ - [4.0, 0.0, 3.0, 1.0, 0.0], - [0.0, 5.0, 8.0, 0.0, 0.0], - [6.0, 0.0, 0.0, 9.0, 0.0] - ], dtype=torch.float32, device=self.device) + destination = torch.rand(3, 5, dtype=torch.float32, device=self.device) - source = torch.tensor( - [ - [0.3992, 0.2908, 0.9044, 0.4850, 0.6004], - [0.5735, 0.9006, 0.6797, 0.4152, 0.1732] - ], dtype=torch.float32, device=self.device) - - place_at = torch.tensor( - [ - [0, 1, 2, 0], - [2, 0, 0, 1] - ], dtype=torch.int64, device=self.device) + source = torch.rand(4, 5, dtype=torch.float32, device=self.device) + indices = torch.randint(0, 3, (2, 4), dtype=torch.int64, device=self.device) annotations = [(list(destination.shape), torch.float32), (list( - place_at.shape), torch.int64), (list(source.shape), torch.float32)] + indices.shape), torch.int64), (list(source.shape), torch.float32)] self._test_disc(scatter_func, annotations, - (destination, place_at, source)) - + (destination, indices, source)) + def test_scatteradd(self): if self.device != torch.device('cuda'): return - + @torch.jit.script def scatter_func(destination, place_at, source): a = torch.scatter_add(destination, 0, place_at, source) return a - destination = torch.tensor( - [ - [4.0, 0.0, 3.0, 1.0, 0.0], - [0.0, 5.0, 8.0, 0.0, 0.0], - [6.0, 0.0, 0.0, 9.0, 0.0] - ], dtype=torch.float32, device=self.device) - - source = torch.tensor( - [ - [0.3992, 0.2908, 0.9044, 0.4850, 0.6004], - [0.5735, 0.9006, 0.6797, 0.4152, 0.1732] - ], dtype=torch.float32, device=self.device) - - place_at = torch.tensor( - [ - [0, 1, 2, 0], - [2, 0, 0, 1] - ], dtype=torch.int64, device=self.device) + destination = torch.rand(3, 5, dtype=torch.float32, device=self.device) + source = torch.rand(2, 5, dtype=torch.float32, device=self.device) + indices = torch.randint(0, 3, (2, 4), dtype=torch.int64, device=self.device) annotations = [(list(destination.shape), torch.float32), (list( - place_at.shape), torch.int64), (list(source.shape), torch.float32)] + indices.shape), torch.int64), (list(source.shape), torch.float32)] self._test_disc(scatter_func, annotations, - (destination, place_at, source)) + (destination, indices, source)) if __name__ == "__main__": diff --git a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc index e47cf39c957..745973be732 100644 --- a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc +++ b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc @@ -152,7 +152,7 @@ Operation* getReduceOperator(Region& body) { int64_t num_calc_ops = 0; body.walk([&](Operation* op) { if (isa(op) || isa(op) || - isa(op)) { + isa(op) || isa(op)) { return; } num_calc_ops++; @@ -683,6 +683,7 @@ Value LowerInplaceScatterOp(OpBuilder* b, Location loc, lmhlo::ScatterOp op, if_inbound_op.getThenRegion().front().clear(); if_inbound_op.getElseRegion().front().clear(); b->setInsertionPointToEnd(&if_inbound_op.getThenRegion().front()); + if (calc_op == nullptr) { b->create(loc, update_value, result_memref, result_index); } else { @@ -699,18 +700,10 @@ Value LowerInplaceScatterOp(OpBuilder* b, Location loc, lmhlo::ScatterOp op, SmallVector operand_values; operand_values.push_back(original_value); operand_values.push_back(update_value); - auto num_operands = calc_op->getNumOperands(); - auto result_type = calc_op->getOperand(num_operands - 1).getType(); - auto result_elem_type = result_type.cast().getElementType(); - if (isa(calc_op)) { - updated_value = LhloOpToStdScalarOp::map( - llvm::cast(calc_op), result_elem_type, operand_values, - b); - } else { - assert(false && "unexpected update computation in scatter op"); - } - - b->create(loc, updated_value, result_memref, result_index); + // atomic add to original_value + b->create(loc, result_types[0], + getAtomicRMWKind(op.getUpdateComputation()), + update_value, result_memref, result_index); } b->create(loc, update_value); b->setInsertionPointToEnd(&if_inbound_op.getElseRegion().front()); @@ -719,7 +712,6 @@ Value LowerInplaceScatterOp(OpBuilder* b, Location loc, lmhlo::ScatterOp op, b->setInsertionPointAfter(if_inbound_op); Value result = *(if_inbound_op.getResults().begin()); - return result; } diff --git a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc index 5cd454b4849..73b1feaaa68 100644 --- a/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc +++ b/tao_compiler/mlir/disc/transforms/lhlo_legalize_roots_to_loops.cc @@ -169,7 +169,8 @@ LogicalResult miscLowerHelper(OpBuilder& b, Location loc, Operation* opaque_op, memref = cast(&*op).getOperation()->getOperand(1); } - // for inplace scatter op, output_index according to operand(3) + // for inplace scatter op, output_index according to update_index, the + // operand(2) of lmhlo::ScatterOp if (isa(op) && isInplaceOperator(op)) { memref = cast(&*op).getOperation()->getOperand(2); }