
Commit e38e8ca

Regenerate MLIR Bindings (#644)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 6951708 commit e38e8ca

6 files changed, +337 -26 lines changed

src/mlir/Dialects/Nvvm.jl

Lines changed: 161 additions & 0 deletions
@@ -688,6 +688,132 @@ function cp_async_bulk_commit_group(; location=Location())
     )
 end
 
+"""
+`cp_async_bulk_shared_cluster_global`
+
+Initiates an asynchronous copy operation from global memory to cluster\'s
+shared memory.
+
+The `multicastMask` operand is optional. When it is present, the Op copies
+data from global memory to shared memory of multiple CTAs in the cluster.
+Operand `multicastMask` specifies the destination CTAs in the cluster such
+that each bit position in the 16-bit `multicastMask` operand corresponds to
+the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
+
+The `l2CacheHint` operand is optional, and it is used to specify cache
+eviction policy that may be used during the memory access.
+[For more information, see PTX ISA]
+(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
+"""
+function cp_async_bulk_shared_cluster_global(
+    dstMem::Value,
+    srcMem::Value,
+    mbar::Value,
+    size::Value,
+    multicastMask=nothing::Union{Nothing,Value};
+    l2CacheHint=nothing::Union{Nothing,Value},
+    location=Location(),
+)
+    op_ty_results = IR.Type[]
+    operands = Value[dstMem, srcMem, mbar, size]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(multicastMask) && push!(operands, multicastMask)
+    !isnothing(l2CacheHint) && push!(operands, l2CacheHint)
+    push!(attributes, operandsegmentsizes([
+        1,
+        1,
+        1,
+        1,
+        (multicastMask == nothing) ? 0 : 1,
+        (l2CacheHint == nothing) ? 0 : 1,
+    ]))
+
+    return create_operation(
+        "nvvm.cp.async.bulk.shared.cluster.global",
+        location;
+        operands,
+        owned_regions,
+        successors,
+        attributes,
+        results=op_ty_results,
+        result_inference=false,
+    )
+end
+
+"""
+`cp_async_bulk_global_shared_cta`
+
+Initiates an asynchronous copy operation from Shared CTA memory to
+global memory.
+
+The `l2CacheHint` operand is optional, and it is used to specify cache
+eviction policy that may be used during the memory access.
+[For more information, see PTX ISA]
+(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
+"""
+function cp_async_bulk_global_shared_cta(
+    dstMem::Value,
+    srcMem::Value,
+    size::Value,
+    l2CacheHint=nothing::Union{Nothing,Value};
+    location=Location(),
+)
+    op_ty_results = IR.Type[]
+    operands = Value[dstMem, srcMem, size]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(l2CacheHint) && push!(operands, l2CacheHint)
+
+    return create_operation(
+        "nvvm.cp.async.bulk.global.shared.cta",
+        location;
+        operands,
+        owned_regions,
+        successors,
+        attributes,
+        results=op_ty_results,
+        result_inference=false,
+    )
+end
+
+"""
+`cp_async_bulk_shared_cluster_shared_cta`
+
+Initiates an asynchronous copy operation from Shared CTA memory to Shared
+cluster memory.
+
+[For more information, see PTX ISA]
+(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
+"""
+function cp_async_bulk_shared_cluster_shared_cta(
+    dstMem::Value, srcMem::Value, mbar::Value, size::Value; location=Location()
+)
+    op_ty_results = IR.Type[]
+    operands = Value[dstMem, srcMem, mbar, size]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+
+    return create_operation(
+        "nvvm.cp.async.bulk.shared.cluster.shared.cta",
+        location;
+        operands,
+        owned_regions,
+        successors,
+        attributes,
+        results=op_ty_results,
+        result_inference=false,
+    )
+end
+
 """
 `cp_async_bulk_tensor_shared_cluster_global`
 
@@ -1063,6 +1189,41 @@ function cp_async_wait_group(; n, location=Location())
     )
 end
 
+"""
+`cvt_float_to_tf32`
+
+This Op converts the given f32 input to tf32.
+The result `res` is represented as an i32 type.
+The `relu` attribute, when set, lowers to the \'.relu\' variant of
+the cvt instruction. The `rnd` and `sat` attributes specify the
+rounding and saturation modes respectively.
+[For more information, see PTX ISA]
+(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+"""
+function cvt_float_to_tf32(
+    src::Value; res::IR.Type, rnd=nothing, sat=nothing, relu=nothing, location=Location()
+)
+    op_ty_results = IR.Type[res,]
+    operands = Value[src,]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(rnd) && push!(attributes, namedattribute("rnd", rnd))
+    !isnothing(sat) && push!(attributes, namedattribute("sat", sat))
+    !isnothing(relu) && push!(attributes, namedattribute("relu", relu))
+
+    return create_operation(
+        "nvvm.cvt.float.to.tf32",
+        location;
+        operands,
+        owned_regions,
+        successors,
+        attributes,
+        results=op_ty_results,
+        result_inference=false,
+    )
+end
+
 """
 `elect_sync`
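The wrappers added above follow the usual pattern of these generated bindings: required operands are positional `Value` arguments, optional operands follow with `nothing` defaults, and results/attributes are keywords. A minimal calling sketch (not part of the commit), assuming the generated `nvvm` functions are in scope and that `dst`, `src`, `mbar`, `sz`, `mask`, `x`, and `i32` are `Value`s and an `IR.Type` created earlier while building the surrounding IR:

```julia
# Sketch only; every name below is assumed to already exist in the
# IR-building scope, none of them are defined by this commit.
cp_async_bulk_shared_cluster_global(dst, src, mbar, sz)        # plain global -> cluster-shared copy
cp_async_bulk_shared_cluster_global(dst, src, mbar, sz, mask)  # multicast copy; one bit of `mask` per destination CTA
tf32op = cvt_float_to_tf32(x; res=i32)                         # f32 -> tf32; the op's result has i32 type
```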

src/mlir/Dialects/Shardy.jl

Lines changed: 70 additions & 8 deletions
@@ -18,10 +18,11 @@ import ...API
 
 Gathers chunks of a tensor along axes specified in `gathering_axes`.
 
-The `gathering_axes` is a list of lists of axes. Each inner list specifies
-the axes along which a separate gather should be performed. The outer list
-is over the dimensions of the tensor. It will be applied to the sharding of
-the operand (`tensor`) to obtain the sharding of the result (`out_sharding`).
+The `gathering_axes` is a list of lists of axes. The outer list is over the
+dimensions of the tensor. Each inner list specifies the axes along which a
+separate gather should be performed on the respective dimension. It will be
+applied to the sharding of the operand (`tensor`) to obtain the sharding of
+the result (`out_sharding`).
 
 Note that `out_sharding` is not used to determine the sharding of the
 result. Instead, the sharding of the result is determined by the sharding of
@@ -35,7 +36,7 @@ inferred sharding.
 ```
 
 **Constraints:**
-- Elements in `gatheringAxes` must satisfy the constraints listed in
+- Elements in `gathering_axes` must satisfy the constraints listed in
   `AxisRefListAttr`.
 - `out_sharding` must satisfy the constraints listed in
   `TensorShardingAttr`.
@@ -72,6 +73,67 @@ function all_gather(
     )
 end
 
+"""
+`all_slice`
+
+Slices chunks of a tensor along axes specified in `slicing_axes`. There is
+an algebraic duality between `sdy.all_slice` and `sdy.all_gather`.
+
+The `slicing_axes` is a list of lists of axes. The outer list is over the
+dimensions of the tensor. Each inner list specifies the axes along which a
+slice should be performed on the respective dimension. It will be applied to
+the sharding of the operand (`tensor`) to obtain the sharding of the result
+(`out_sharding`).
+
+Note that `out_sharding` is not used to determine the sharding of the
+result. Instead, the sharding of the result is determined by the sharding of
+the operand and the `slicing_axes`, and `out_sharding` must match this
+inferred sharding.
+
+# Example
+```mlir
+%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{\"a\"}, {}, {}\\]>]>} : tensor<8x8xf32>
+%2 = sdy.all_slice [{\"b\", \"c\"}, {}, {\"d\"}\\] %1 to_sharding=<@mesh, [{\"a\", \"b\", \"c\"}, {}, {\"d\"}\\]> : tensor<8x8xf32>
+```
+
+**Constraints:**
+- Elements in `slicing_axes` must satisfy the constraints listed in
+  `AxisRefListAttr`.
+- `out_sharding` must satisfy the constraints listed in
+  `TensorShardingAttr`.
+- The operand must have a sharding.
+- Both operand and result shardings should be bound to the same `MeshAttr`.
+- Applying `slicing_axes` to the operand sharding gets `out_sharding`.
+"""
+function all_slice(
+    tensor::Value;
+    result=nothing::Union{Nothing,IR.Type},
+    slicing_axes,
+    out_sharding,
+    location=Location(),
+)
+    op_ty_results = IR.Type[]
+    operands = Value[tensor,]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[
+        namedattribute("slicing_axes", slicing_axes),
+        namedattribute("out_sharding", out_sharding),
+    ]
+    !isnothing(result) && push!(op_ty_results, result)
+
+    return create_operation(
+        "sdy.all_slice",
+        location;
+        operands,
+        owned_regions,
+        successors,
+        attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false),
+    )
+end
+
 """
 `constant`
 
@@ -136,7 +198,7 @@ This while op has n data flow edges, the i-th data flow edges is between
 sources `x_i`, `return_value_i` and targets `y_i`, `pred_arg_i`,
 `body_arg_i`.
 
-An `sdy.data_flow_edge` takes as input the root target of an edge (can be
+An `sdy.data_flow_edge` takes as input the owner of an edge (can be
 any of the targets, but preferably an op result rather than a block
 argument), which shouldn\'t have any other uses. This op isn\'t pure because
 it can take an input that originally didn\'t have any uses.
@@ -163,8 +225,8 @@ We don\'t allow the input of a `sdy.data_flow_edge` to be defined by an
 unregistered `sdy.sharding` attribute.
 
 NOTE: it\'s NOT the responsibility of the `sdy.data_flow_edge` to link
-between sources and targets, it\'s simply attached to the root target of the
-edge. The op that this edge is bound to (while in the example above) is
+between sources and targets, it\'s simply attached to the owner of the edge.
+The op that this edge is bound to (while in the example above) is
 responsible for providing this information.
 """
 function data_flow_edge(
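The `all_slice` docstring above (and the reworded `all_gather` one) describes `slicing_axes`/`gathering_axes` as one list of axes per tensor dimension, applied to the operand sharding. The effect on shapes is easy to state in plain Julia; the snippet below only illustrates that semantics with made-up axis sizes and is not part of the bindings:

```julia
# Plain-Julia illustration of slicing_axes semantics; the mesh axis sizes
# here are hypothetical.
axis_size = Dict("a" => 2, "b" => 2, "c" => 2, "d" => 4)

# Per-dimension slicing axes in the style of the docstring example:
# [{"b", "c"}, {}, {"d"}] applied to a rank-3 tensor.
slicing_axes = [["b", "c"], String[], ["d"]]
global_shape = (8, 8, 8)

# sdy.all_slice divides each dimension by the product of the sizes of the
# axes sliced along it; sdy.all_gather multiplies them back (the duality).
sliced_shape = ntuple(length(global_shape)) do i
    global_shape[i] ÷ prod((axis_size[a] for a in slicing_axes[i]); init=1)
end
@assert sliced_shape == (2, 8, 2)

gathered_shape = ntuple(length(sliced_shape)) do i
    sliced_shape[i] * prod((axis_size[a] for a in slicing_axes[i]); init=1)
end
@assert gathered_shape == global_shape
```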

src/mlir/Dialects/StableHLO.jl

Lines changed: 6 additions & 1 deletion
@@ -1994,14 +1994,19 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential
 ```
 """
 function exponential(
-    operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()
+    operand::Value;
+    result=nothing::Union{Nothing,IR.Type},
+    result_accuracy=nothing,
+    location=Location(),
 )
     op_ty_results = IR.Type[]
     operands = Value[operand,]
     owned_regions = Region[]
     successors = Block[]
     attributes = NamedAttribute[]
     !isnothing(result) && push!(op_ty_results, result)
+    !isnothing(result_accuracy) &&
+        push!(attributes, namedattribute("result_accuracy", result_accuracy))
 
     return create_operation(
         "stablehlo.exponential",

src/mlir/Dialects/Triton.jl

Lines changed: 11 additions & 1 deletion
@@ -471,14 +471,17 @@ function dot_scaled(
     d::IR.Type,
     lhs_type,
     rhs_type,
+    fastMath,
     location=Location(),
 )
     op_ty_results = IR.Type[d,]
     operands = Value[lhs, rhs, c]
     owned_regions = Region[]
     successors = Block[]
     attributes = NamedAttribute[
-        namedattribute("lhs_type", lhs_type), namedattribute("rhs_type", rhs_type)
+        namedattribute("lhs_type", lhs_type),
+        namedattribute("rhs_type", rhs_type),
+        namedattribute("fastMath", fastMath),
     ]
     !isnothing(lhs_scale) && push!(operands, lhs_scale)
     !isnothing(rhs_scale) && push!(operands, rhs_scale)
@@ -785,12 +788,17 @@ tensor. The input and indices tensors must have the same number of
 dimension, and each dimension of the indices tensor that is not the gather
 dimension cannot be greater than the corresponding dimension in the input
 tensor.
+
+The `efficient_layout` attribute is set when the compiler has determined an
+optimized layout for the operation, indicating that it should not be
+changed.
 """
 function gather(
     src::Value,
     indices::Value;
     result=nothing::Union{Nothing,IR.Type},
     axis,
+    efficient_layout=nothing,
     location=Location(),
 )
     op_ty_results = IR.Type[]
@@ -799,6 +807,8 @@ function gather(
     successors = Block[]
     attributes = NamedAttribute[namedattribute("axis", axis),]
     !isnothing(result) && push!(op_ty_results, result)
+    !isnothing(efficient_layout) &&
+        push!(attributes, namedattribute("efficient_layout", efficient_layout))
 
     return create_operation(
         "tt.gather",

src/mlir/Dialects/VHLO.jl

Lines changed: 21 additions & 0 deletions
@@ -1452,6 +1452,27 @@ function exponential_v1(operand::Value; result::IR.Type, location=Location())
     )
 end
 
+function exponential_v2(
+    operand::Value; result::IR.Type, result_accuracy, location=Location()
+)
+    op_ty_results = IR.Type[result,]
+    operands = Value[operand,]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("result_accuracy", result_accuracy),]
+
+    return create_operation(
+        "vhlo.exponential_v2",
+        location;
+        operands,
+        owned_regions,
+        successors,
+        attributes,
+        results=op_ty_results,
+        result_inference=false,
+    )
+end
+
 function exponential_minus_one_v1(operand::Value; result::IR.Type, location=Location())
     op_ty_results = IR.Type[result,]
     operands = Value[operand,]
