Skip to content

Commit

Permalink
[mlir][acc] Consistency between acc.loop and acc compute ops (llvm#11…
Browse files Browse the repository at this point in the history
…4549)

- GangPrivate and GangFirstPrivate renamed to just Private and
Firstprivate respectively. This is makes compute ops consistent with the
loop op (and also with the acc spec wording for the clause).
- Added getBody to all compute ops
- Verifier for firstprivate ops / recipes is enabled
  • Loading branch information
razvanlupusoru authored Nov 1, 2024
1 parent c5a254c commit c0a1597
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 20 deletions.
34 changes: 22 additions & 12 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1114,9 +1114,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
UnitAttr:$selfAttr,
Variadic<AnyType>:$reductionOperands,
OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
Variadic<OpenACC_PointerLikeTypeInterface>:$firstprivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr,
Expand All @@ -1134,8 +1134,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
CArg<"mlir::Value", "{}">:$ifCond,
CArg<"mlir::Value", "{}">:$selfCond,
CArg<"mlir::ValueRange", "{}">:$reductionOperands,
CArg<"mlir::ValueRange", "{}">:$gangPrivateOperands,
CArg<"mlir::ValueRange", "{}">:$gangFirstPrivateOperands,
CArg<"mlir::ValueRange", "{}">:$privateOperands,
CArg<"mlir::ValueRange", "{}">:$firstprivateOperands,
CArg<"mlir::ValueRange", "{}">:$dataClauseOperands)>];

let extraClassDeclaration = [{
Expand All @@ -1145,6 +1145,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);

/// Used to retrieve the block inside the op's region.
Block &getBody() { return getRegion().front(); }

/// Return true if the op has the async attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasAsyncOnly();
Expand Down Expand Up @@ -1202,15 +1205,15 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
type($asyncOperands), $asyncOperandsDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
type($gangFirstPrivateOperands), $firstprivatizations)
| `firstprivate` `(` custom<SymOperandList>($firstprivateOperands,
type($firstprivateOperands), $firstprivatizations)
`)`
| `num_gangs` `(` custom<NumGangs>($numGangs,
type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
| `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
type($numWorkers), $numWorkersDeviceType) `)`
| `private` `(` custom<SymOperandList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
$privateOperands, type($privateOperands), $privatizations)
`)`
| `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
type($vectorLength), $vectorLengthDeviceType) `)`
Expand Down Expand Up @@ -1271,9 +1274,9 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
UnitAttr:$selfAttr,
Variadic<AnyType>:$reductionOperands,
OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
Variadic<OpenACC_PointerLikeTypeInterface>:$firstprivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr,
Expand All @@ -1288,6 +1291,9 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);

/// Used to retrieve the block inside the op's region.
Block &getBody() { return getRegion().front(); }

/// Return true if the op has the async attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasAsyncOnly();
Expand Down Expand Up @@ -1326,11 +1332,11 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
type($asyncOperands), $asyncOperandsDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
type($gangFirstPrivateOperands), $firstprivatizations)
| `firstprivate` `(` custom<SymOperandList>($firstprivateOperands,
type($firstprivateOperands), $firstprivatizations)
`)`
| `private` `(` custom<SymOperandList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
$privateOperands, type($privateOperands), $privatizations)
`)`
| `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
$waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
Expand Down Expand Up @@ -1410,6 +1416,9 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);

/// Used to retrieve the block inside the op's region.
Block &getBody() { return getRegion().front(); }

/// Return true if the op has the async attribute for the
/// mlir::acc::DeviceType::None device_type.
bool hasAsyncOnly();
Expand Down Expand Up @@ -1824,6 +1833,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
/// The i-th data operand passed.
Value getDataOperand(unsigned i);

/// Used to retrieve the block inside the op's region.
Block &getBody() { return getLoopRegions().front()->front(); }

/// Return true if the op has the auto attribute for the
Expand Down
20 changes: 14 additions & 6 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -730,8 +730,8 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
}

unsigned ParallelOp::getNumDataOperands() {
return getReductionOperands().size() + getGangPrivateOperands().size() +
getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
return getReductionOperands().size() + getPrivateOperands().size() +
getFirstprivateOperands().size() + getDataClauseOperands().size();
}

Value ParallelOp::getDataOperand(unsigned i) {
Expand Down Expand Up @@ -783,9 +783,13 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch(

LogicalResult acc::ParallelOp::verify() {
if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
*this, getPrivatizations(), getGangPrivateOperands(), "private",
*this, getPrivatizations(), getPrivateOperands(), "private",
"privatizations", /*checkOperandType=*/false)))
return failure();
if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
*this, getFirstprivatizations(), getFirstprivateOperands(),
"firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
return failure();
if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
*this, getReductionRecipes(), getReductionOperands(), "reduction",
"reductions", false)))
Expand Down Expand Up @@ -1361,8 +1365,8 @@ printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
//===----------------------------------------------------------------------===//

unsigned SerialOp::getNumDataOperands() {
return getReductionOperands().size() + getGangPrivateOperands().size() +
getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
return getReductionOperands().size() + getPrivateOperands().size() +
getFirstprivateOperands().size() + getDataClauseOperands().size();
}

Value SerialOp::getDataOperand(unsigned i) {
Expand Down Expand Up @@ -1420,9 +1424,13 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {

LogicalResult acc::SerialOp::verify() {
if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
*this, getPrivatizations(), getGangPrivateOperands(), "private",
*this, getPrivatizations(), getPrivateOperands(), "private",
"privatizations", /*checkOperandType=*/false)))
return failure();
if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
*this, getFirstprivatizations(), getFirstprivateOperands(),
"firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
return failure();
if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
*this, getReductionRecipes(), getReductionOperands(), "reduction",
"reductions", false)))
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
!std::is_same_v<Op, acc::DataOp> &&
!std::is_same_v<Op, acc::DeclareOp>) {
collectPtrs(op.getReductionOperands(), values, hostToDevice);
collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
collectPtrs(op.getPrivateOperands(), values, hostToDevice);
collectPtrs(op.getFirstprivateOperands(), values, hostToDevice);
}
}

Expand Down

0 comments on commit c0a1597

Please sign in to comment.