Skip to content

Commit b18e821

Browse files
committed
[mlir][OpenMP] Added omp.task
This patch adds the `task` construct according to Section 2.10.1 of the OpenMP 5.0 specification. Reviewed By: peixin, kiranchandramohan, abidmalikwaterloo. Differential Revision: https://reviews.llvm.org/D123575
1 parent fdd424e commit b18e821

File tree

4 files changed

+249
-0
lines changed

4 files changed

+249
-0
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,91 @@ def YieldOp : OpenMP_Op<"yield",
463463
let assemblyFormat = [{ ( `(` $results^ `:` type($results) `)` )? attr-dict}];
464464
}
465465

466+
//===----------------------------------------------------------------------===//
467+
// 2.10.1 task Construct
468+
//===----------------------------------------------------------------------===//
469+
470+
def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
471+
OutlineableOpenMPOpInterface, AutomaticAllocationScope,
472+
ReductionClauseInterface]> {
473+
let summary = "task construct";
474+
let description = [{
475+
The task construct defines an explicit task.
476+
477+
For definitions of "undeferred task", "included task", "final task" and
478+
"mergeable task", please check the OpenMP Specification.
479+
480+
When an `if` clause is present on a task construct, and the value of
481+
`if_expr` evaluates to `false`, an "undeferred task" is generated, and the
482+
encountering thread must suspend the current task region, for which
483+
execution cannot be resumed until execution of the structured block that is
484+
associated with the generated task is completed.
485+
486+
When a `final` clause is present on a task construct and the `final_expr`
487+
evaluates to `true`, the generated task will be a "final task". All task
488+
constructs encountered during execution of a final task will generate final
489+
and included tasks.
490+
491+
If the `untied` clause is present on a task construct, any thread in the
492+
team can resume the task region after a suspension. The `untied` clause is
493+
ignored if a `final` clause is present on the same task construct and the
494+
`final_expr` evaluates to `true`, or if a task is an included task.
495+
496+
When the `mergeable` clause is present on a task construct, the generated
497+
task is a "mergeable task".
498+
499+
The `in_reduction` clause specifies that this particular task (among all the
500+
tasks in the current taskgroup, if any) participates in a reduction.
501+
502+
The `priority` clause is a hint for the priority of the generated task.
503+
The `priority` is a non-negative integer expression that provides a hint for
504+
task execution order. Among all tasks ready to be executed, higher priority
505+
tasks (those with a higher numerical value in the priority clause
506+
expression) are recommended to execute before lower priority ones. The
507+
default priority-value when no priority clause is specified should be
508+
assumed to be zero (the lowest priority).
509+
510+
The `allocators_vars` and `allocate_vars` parameters are variadic lists of
511+
values that specify the memory allocator to be used to obtain storage for
512+
private values.
513+
514+
}];
515+
516+
// TODO: Add support for the depend, affinity and detach clauses.
517+
let arguments = (ins Optional<I1>:$if_expr,
518+
Optional<I1>:$final_expr,
519+
UnitAttr:$untied,
520+
UnitAttr:$mergeable,
521+
Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
522+
OptionalAttr<SymbolRefArrayAttr>:$in_reductions,
523+
Optional<I32>:$priority,
524+
Variadic<AnyType>:$allocate_vars,
525+
Variadic<AnyType>:$allocators_vars);
526+
let regions = (region AnyRegion:$region);
527+
let assemblyFormat = [{
528+
oilist(`if` `(` $if_expr `)`
529+
|`final` `(` $final_expr `)`
530+
|`untied` $untied
531+
|`mergeable` $mergeable
532+
|`in_reduction` `(`
533+
custom<ReductionVarList>(
534+
$in_reduction_vars, type($in_reduction_vars), $in_reductions
535+
) `)`
536+
|`priority` `(` $priority `)`
537+
|`allocate` `(`
538+
custom<AllocateAndAllocator>(
539+
$allocate_vars, type($allocate_vars),
540+
$allocators_vars, type($allocators_vars)
541+
) `)`
542+
) $region attr-dict
543+
}];
544+
let extraClassDeclaration = [{
545+
/// Returns the reduction variables, i.e. the operands of the `in_reduction` clause.
546+
operand_range getReductionVars() { return in_reduction_vars(); }
547+
}];
548+
let hasVerifier = 1;
549+
}
550+
466551
//===----------------------------------------------------------------------===//
467552
// 2.10.4 taskyield Construct
468553
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,13 @@ LogicalResult ReductionOp::verify() {
725725
return emitOpError() << "the accumulator is not used by the parent";
726726
}
727727

728+
//===----------------------------------------------------------------------===//
729+
// TaskOp
730+
//===----------------------------------------------------------------------===//
731+
LogicalResult TaskOp::verify() {
732+
return verifyReductionVarList(*this, in_reductions(), in_reduction_vars());
733+
}
734+
728735
//===----------------------------------------------------------------------===//
729736
// WsLoopOp
730737
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -980,3 +980,70 @@ func @omp_single(%data_var : memref<i32>) -> () {
980980
}) {operand_segment_sizes = dense<[1,0]> : vector<2xi32>} : (memref<i32>) -> ()
981981
return
982982
}
983+
984+
// -----
985+
986+
func @omp_task(%ptr: !llvm.ptr<f32>) {
987+
// expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}}
988+
omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr<f32>) {
989+
// CHECK: "test.foo"() : () -> ()
990+
"test.foo"() : () -> ()
991+
// CHECK: omp.terminator
992+
omp.terminator
993+
}
994+
}
995+
996+
// -----
997+
998+
omp.reduction.declare @add_f32 : f32
999+
init {
1000+
^bb0(%arg: f32):
1001+
%0 = arith.constant 0.0 : f32
1002+
omp.yield (%0 : f32)
1003+
}
1004+
combiner {
1005+
^bb1(%arg0: f32, %arg1: f32):
1006+
%1 = arith.addf %arg0, %arg1 : f32
1007+
omp.yield (%1 : f32)
1008+
}
1009+
1010+
func @omp_task(%ptr: !llvm.ptr<f32>) {
1011+
// expected-error @below {{op accumulator variable used more than once}}
1012+
omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr<f32>, @add_f32 -> %ptr : !llvm.ptr<f32>) {
1013+
// CHECK: "test.foo"() : () -> ()
1014+
"test.foo"() : () -> ()
1015+
// CHECK: omp.terminator
1016+
omp.terminator
1017+
}
1018+
}
1019+
1020+
// -----
1021+
1022+
omp.reduction.declare @add_i32 : i32
1023+
init {
1024+
^bb0(%arg: i32):
1025+
%0 = arith.constant 0 : i32
1026+
omp.yield (%0 : i32)
1027+
}
1028+
combiner {
1029+
^bb1(%arg0: i32, %arg1: i32):
1030+
%1 = arith.addi %arg0, %arg1 : i32
1031+
omp.yield (%1 : i32)
1032+
}
1033+
atomic {
1034+
^bb2(%arg2: !llvm.ptr<i32>, %arg3: !llvm.ptr<i32>):
1035+
%2 = llvm.load %arg3 : !llvm.ptr<i32>
1036+
llvm.atomicrmw add %arg2, %2 monotonic : i32
1037+
omp.yield
1038+
}
1039+
1040+
func @omp_task(%mem: memref<1xf32>) {
1041+
// expected-error @below {{op expected accumulator ('memref<1xf32>') to be the same type as reduction declaration ('!llvm.ptr<i32>')}}
1042+
omp.task in_reduction(@add_i32 -> %mem : memref<1xf32>) {
1043+
// CHECK: "test.foo"() : () -> ()
1044+
"test.foo"() : () -> ()
1045+
// CHECK: omp.terminator
1046+
omp.terminator
1047+
}
1048+
return
1049+
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,96 @@ func @omp_single_allocate_nowait(%data_var: memref<i32>) {
897897
return
898898
}
899899

900+
// CHECK-LABEL: @omp_task
901+
// CHECK-SAME: (%[[bool_var:.*]]: i1, %[[i64_var:.*]]: i64, %[[i32_var:.*]]: i32, %[[data_var:.*]]: memref<i32>)
902+
func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memref<i32>) {
903+
904+
// Checking simple task
905+
// CHECK: omp.task {
906+
omp.task {
907+
// CHECK: "test.foo"() : () -> ()
908+
"test.foo"() : () -> ()
909+
// CHECK: omp.terminator
910+
omp.terminator
911+
}
912+
913+
// Checking `if` clause
914+
// CHECK: omp.task if(%[[bool_var]]) {
915+
omp.task if(%bool_var) {
916+
// CHECK: "test.foo"() : () -> ()
917+
"test.foo"() : () -> ()
918+
// CHECK: omp.terminator
919+
omp.terminator
920+
}
921+
922+
// Checking `final` clause
923+
// CHECK: omp.task final(%[[bool_var]]) {
924+
omp.task final(%bool_var) {
925+
// CHECK: "test.foo"() : () -> ()
926+
"test.foo"() : () -> ()
927+
// CHECK: omp.terminator
928+
omp.terminator
929+
}
930+
931+
// Checking `untied` clause
932+
// CHECK: omp.task untied {
933+
omp.task untied {
934+
// CHECK: "test.foo"() : () -> ()
935+
"test.foo"() : () -> ()
936+
// CHECK: omp.terminator
937+
omp.terminator
938+
}
939+
940+
// Checking `in_reduction` clause
941+
%c1 = arith.constant 1 : i32
942+
// CHECK: %[[redn_var1:.*]] = llvm.alloca %{{.*}} x f32 : (i32) -> !llvm.ptr<f32>
943+
%0 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr<f32>
944+
// CHECK: %[[redn_var2:.*]] = llvm.alloca %{{.*}} x f32 : (i32) -> !llvm.ptr<f32>
945+
%1 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr<f32>
946+
// CHECK: omp.task in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr<f32>, @add_f32 -> %[[redn_var2]] : !llvm.ptr<f32>) {
947+
omp.task in_reduction(@add_f32 -> %0 : !llvm.ptr<f32>, @add_f32 -> %1 : !llvm.ptr<f32>) {
948+
// CHECK: "test.foo"() : () -> ()
949+
"test.foo"() : () -> ()
950+
// CHECK: omp.terminator
951+
omp.terminator
952+
}
953+
954+
// Checking priority clause
955+
// CHECK: omp.task priority(%[[i32_var]]) {
956+
omp.task priority(%i32_var) {
957+
// CHECK: "test.foo"() : () -> ()
958+
"test.foo"() : () -> ()
959+
// CHECK: omp.terminator
960+
omp.terminator
961+
}
962+
963+
// Checking allocate clause
964+
// CHECK: omp.task allocate(%[[data_var]] : memref<i32> -> %[[data_var]] : memref<i32>) {
965+
omp.task allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
966+
// CHECK: "test.foo"() : () -> ()
967+
"test.foo"() : () -> ()
968+
// CHECK: omp.terminator
969+
omp.terminator
970+
}
971+
972+
// Checking multiple clauses
973+
// CHECK: omp.task if(%[[bool_var]]) final(%[[bool_var]]) untied
974+
omp.task if(%bool_var) final(%bool_var) untied
975+
// CHECK-SAME: in_reduction(@add_f32 -> %[[redn_var1]] : !llvm.ptr<f32>, @add_f32 -> %[[redn_var2]] : !llvm.ptr<f32>)
976+
in_reduction(@add_f32 -> %0 : !llvm.ptr<f32>, @add_f32 -> %1 : !llvm.ptr<f32>)
977+
// CHECK-SAME: priority(%[[i32_var]])
978+
priority(%i32_var)
979+
// CHECK-SAME: allocate(%[[data_var]] : memref<i32> -> %[[data_var]] : memref<i32>)
980+
allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
981+
// CHECK: "test.foo"() : () -> ()
982+
"test.foo"() : () -> ()
983+
// CHECK: omp.terminator
984+
omp.terminator
985+
}
986+
987+
return
988+
}
989+
900990
// -----
901991

902992
func @omp_threadprivate() {

0 commit comments

Comments
 (0)