Regenerate MLIR Bindings (#644)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
github-actions[bot] and enzyme-ci-bot[bot] authored Jan 28, 2025
1 parent 6951708 commit e38e8ca
Showing 6 changed files with 337 additions and 26 deletions.
161 changes: 161 additions & 0 deletions src/mlir/Dialects/Nvvm.jl
@@ -688,6 +688,132 @@ function cp_async_bulk_commit_group(; location=Location())
)
end

"""
`cp_async_bulk_shared_cluster_global`
Initiates an asynchronous copy operation from global memory to cluster\'s
shared memory.
The `multicastMask` operand is optional. When it is present, the Op copies
data from global memory to shared memory of multiple CTAs in the cluster.
Operand `multicastMask` specifies the destination CTAs in the cluster such
that each bit position in the 16-bit `multicastMask` operand corresponds to
the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.
[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
"""
function cp_async_bulk_shared_cluster_global(
dstMem::Value,
srcMem::Value,
mbar::Value,
size::Value,
multicastMask=nothing::Union{Nothing,Value};
l2CacheHint=nothing::Union{Nothing,Value},
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[dstMem, srcMem, mbar, size]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(multicastMask) && push!(operands, multicastMask)
!isnothing(l2CacheHint) && push!(operands, l2CacheHint)
push!(attributes, operandsegmentsizes([
    1,
    1,
    1,
    1,
    (multicastMask == nothing) ? 0 : 1,
    (l2CacheHint == nothing) ? 0 : 1,
]))

return create_operation(
"nvvm.cp.async.bulk.shared.cluster.global",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end
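A hypothetical usage sketch (not part of this commit): it assumes `dst`, `src`, `mbar`, `sz`, `mask`, and `hint` are MLIR `Value`s already available in the surrounding IR (for example block arguments), and that the `nvvm` module path below matches how these generated bindings are imported.

```julia
# Sketch only: all Values and the import path are assumptions about the caller's context.
using Reactant.MLIR.Dialects: nvvm

# Plain bulk copy: global -> shared::cluster, no optional operands.
copy_op = nvvm.cp_async_bulk_shared_cluster_global(dst, src, mbar, sz)

# With the optional multicast mask (positional) and L2 cache hint (keyword);
# the operand-segment sizes then become [1, 1, 1, 1, 1, 1].
copy_multicast = nvvm.cp_async_bulk_shared_cluster_global(
    dst, src, mbar, sz, mask; l2CacheHint=hint
)
```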

"""
`cp_async_bulk_global_shared_cta`
Initiates an asynchronous copy operation from Shared CTA memory to
global memory.
The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.
[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
"""
function cp_async_bulk_global_shared_cta(
dstMem::Value,
srcMem::Value,
size::Value,
l2CacheHint=nothing::Union{Nothing,Value};
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[dstMem, srcMem, size]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(l2CacheHint) && push!(operands, l2CacheHint)

return create_operation(
"nvvm.cp.async.bulk.global.shared.cta",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

"""
`cp_async_bulk_shared_cluster_shared_cta`
Initiates an asynchronous copy operation from Shared CTA memory to Shared
cluster memory.
[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk)
"""
function cp_async_bulk_shared_cluster_shared_cta(
dstMem::Value, srcMem::Value, mbar::Value, size::Value; location=Location()
)
op_ty_results = IR.Type[]
operands = Value[dstMem, srcMem, mbar, size]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]

return create_operation(
"nvvm.cp.async.bulk.shared.cluster.shared.cta",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end
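For symmetry with the sketch above, a hedged example of the other two bulk-copy directions added in this hunk; again, every `Value` and the module path are assumptions about the caller's context, not part of this commit.

```julia
# Sketch only: Values and import path are assumptions, as above.
using Reactant.MLIR.Dialects: nvvm

# Shared CTA -> global, with an optional L2 cache-eviction hint (positional operand).
out_copy = nvvm.cp_async_bulk_global_shared_cta(dst_global, src_cta, sz, hint)

# Shared CTA -> shared cluster; completion is tracked through the mbarrier `mbar`.
cluster_copy = nvvm.cp_async_bulk_shared_cluster_shared_cta(dst_cluster, src_cta, mbar, sz)
```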

"""
`cp_async_bulk_tensor_shared_cluster_global`
@@ -1063,6 +1189,41 @@ function cp_async_wait_group(; n, location=Location())
)
end

"""
`cvt_float_to_tf32`
This Op converts the given f32 input to tf32.
The result `res` is represented as an i32 type.
The `relu` attribute, when set, lowers to the \'.relu\' variant of
the cvt instruction. The `rnd` and `sat` attributes specify the
rounding and saturation modes, respectively.
[For more information, see PTX ISA]
(https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
"""
function cvt_float_to_tf32(
src::Value; res::IR.Type, rnd=nothing, sat=nothing, relu=nothing, location=Location()
)
op_ty_results = IR.Type[res,]
operands = Value[src,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(rnd) && push!(attributes, namedattribute("rnd", rnd))
!isnothing(sat) && push!(attributes, namedattribute("sat", sat))
!isnothing(relu) && push!(attributes, namedattribute("relu", relu))

return create_operation(
"nvvm.cvt.float.to.tf32",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end
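A hedged call sketch (not from this commit): `x` stands for an existing f32 `Value`, and `i32_ty` / `rnd_attr` are placeholders for an MLIR i32 `IR.Type` and an NVVM rounding-mode attribute constructed elsewhere.

```julia
# Sketch only: `x`, `i32_ty`, and `rnd_attr` are hypothetical placeholders; the
# import path is an assumption about how this generated module is exposed.
using Reactant.MLIR.Dialects: nvvm

# The result is an i32 Value carrying the tf32 bit pattern of `x`.
tf32_op = nvvm.cvt_float_to_tf32(x; res=i32_ty, rnd=rnd_attr)
```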

"""
`elect_sync`
78 changes: 70 additions & 8 deletions src/mlir/Dialects/Shardy.jl
@@ -18,10 +18,11 @@ import ...API
Gathers chunks of a tensor along axes specified in `gathering_axes`.
The `gathering_axes` is a list of lists of axes. Each inner list specifies
the axes along which a separate gather should be performed. The outer list
is over the dimensions of the tensor. It will be applied to the sharding of
the operand (`tensor`) to obtain the sharding of the result (`out_sharding`).
The `gathering_axes` is a list of lists of axes. The outer list is over the
dimensions of the tensor. Each inner list specifies the axes along which a
separate gather should be performed on the respective dimension. It will be
applied to the sharding of the operand (`tensor`) to obtain the sharding of
the result (`out_sharding`).
Note that `out_sharding` is not used to determine the sharding of the
result. Instead, the sharding of the result is determined by the sharding of
@@ -35,7 +36,7 @@ inferred sharding.
```
**Constraints:**
- Elements in `gatheringAxes` must satisfy the constraints listed in
- Elements in `gathering_axes` must satisfy the constraints listed in
`AxisRefListAttr`.
- `out_sharding` must satisfy the constraints listed in
`TensorShardingAttr`.
@@ -72,6 +73,67 @@ function all_gather(
)
end

"""
`all_slice`
Slices chunks of a tensor along axes specified in `slicing_axes`. There is
an algebraic duality between `sdy.all_slice` and `sdy.all_gather`.
The `slicing_axes` is a list of lists of axes. The outer list is over the
dimensions of the tensor. Each inner list specifies the axes along which a
slice should be performed on the respective dimension. It will be applied to
the sharding of the operand (`tensor`) to obtain the sharding of the result
(`out_sharding`).
Note that `out_sharding` is not used to determine the sharding of the
result. Instead, the sharding of the result is determined by the sharding of
the operand and the `slicing_axes`, and `out_sharding` must match this
inferred sharding.
# Example
```mlir
%1 = stablehlo.tanh(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{\"a\"}, {}, {}\\]>]>} : tensor<8x8xf32>
%2 = sdy.all_slice [{\"b\", \"c\"}, {}, {\"d\"}\\] %1 to_sharding=<@mesh, [{\"a\", \"b\", \"c\"}, {}, {\"d\"}\\]> : tensor<8x8xf32>
```
**Constraints:**
- Elements in `slicing_axes` must satisfy the constraints listed in
`AxisRefListAttr`.
- `out_sharding` must satisfy the constraints listed in
`TensorShardingAttr`.
- The operand must have a sharding.
- Both operand and result shardings should be bound to the same `MeshAttr`.
- Applying `slicing_axes` to the operand sharding gets `out_sharding`.
"""
function all_slice(
tensor::Value;
result=nothing::Union{Nothing,IR.Type},
slicing_axes,
out_sharding,
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[tensor,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[
namedattribute("slicing_axes", slicing_axes),
namedattribute("out_sharding", out_sharding),
]
!isnothing(result) && push!(op_ty_results, result)

return create_operation(
"sdy.all_slice",
location;
operands,
owned_regions,
successors,
attributes,
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
result_inference=(length(op_ty_results) == 0 ? true : false),
)
end
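A hedged sketch mirroring the MLIR example in the docstring: `t` is an existing tensor `Value`, and `slicing_attr` / `sharding_attr` are placeholders for already-constructed `#sdy` attributes (an axis-ref list and a tensor sharding); the `sdy` import path is an assumption.

```julia
# Sketch only: `t`, `slicing_attr`, and `sharding_attr` are hypothetical,
# pre-built values/attributes from the caller's context.
using Reactant.MLIR.Dialects: sdy

# Roughly corresponds to:
#   %2 = sdy.all_slice [{"b", "c"}, {}, {"d"}] %1 to_sharding=<@mesh, [{"a","b","c"}, {}, {"d"}]>
slice_op = sdy.all_slice(t; slicing_axes=slicing_attr, out_sharding=sharding_attr)
```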

"""
`constant`
@@ -136,7 +198,7 @@ This while op has n data flow edges, the i-th data flow edge is between
sources `x_i`, `return_value_i` and targets `y_i`, `pred_arg_i`,
`body_arg_i`.
An `sdy.data_flow_edge` takes as input the root target of an edge (can be
An `sdy.data_flow_edge` takes as input the owner of an edge (can be
any of the targets, but preferably an op result rather than a block
argument), which shouldn\'t have any other uses. This op isn\'t pure because
it can take an input that originally didn\'t have any uses.
@@ -163,8 +225,8 @@ We don\'t allow the input of a `sdy.data_flow_edge` to be defined by an
unregistered `sdy.sharding` attribute.
NOTE: it\'s NOT the responsibility of the `sdy.data_flow_edge` to link
between sources and targets, it\'s simply attached to the root target of the
edge. The op that this edge is bound to (while in the example above) is
between sources and targets, it\'s simply attached to the owner of the edge.
The op that this edge is bound to (while in the example above) is
responsible for providing this information.
"""
function data_flow_edge(
7 changes: 6 additions & 1 deletion src/mlir/Dialects/StableHLO.jl
@@ -1994,14 +1994,19 @@ https://github.com/openxla/stablehlo/blob/main/docs/spec.md#exponential
```
"""
function exponential(
operand::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()
operand::Value;
result=nothing::Union{Nothing,IR.Type},
result_accuracy=nothing,
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[operand,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(result) && push!(op_ty_results, result)
!isnothing(result_accuracy) &&
push!(attributes, namedattribute("result_accuracy", result_accuracy))

return create_operation(
"stablehlo.exponential",
12 changes: 11 additions & 1 deletion src/mlir/Dialects/Triton.jl
@@ -471,14 +471,17 @@ function dot_scaled(
d::IR.Type,
lhs_type,
rhs_type,
fastMath,
location=Location(),
)
op_ty_results = IR.Type[d,]
operands = Value[lhs, rhs, c]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[
namedattribute("lhs_type", lhs_type), namedattribute("rhs_type", rhs_type)
namedattribute("lhs_type", lhs_type),
namedattribute("rhs_type", rhs_type),
namedattribute("fastMath", fastMath),
]
!isnothing(lhs_scale) && push!(operands, lhs_scale)
!isnothing(rhs_scale) && push!(operands, rhs_scale)
@@ -785,12 +788,17 @@ tensor. The input and indices tensors must have the same number of
dimensions, and each dimension of the indices tensor that is not the gather
dimension cannot be greater than the corresponding dimension in the input
tensor.
The `efficient_layout` attribute is set when the compiler has determined an
optimized layout for the operation, indicating that it should not be
changed.
"""
function gather(
src::Value,
indices::Value;
result=nothing::Union{Nothing,IR.Type},
axis,
efficient_layout=nothing,
location=Location(),
)
op_ty_results = IR.Type[]
@@ -799,6 +807,8 @@ function gather(
successors = Block[]
attributes = NamedAttribute[namedattribute("axis", axis),]
!isnothing(result) && push!(op_ty_results, result)
!isnothing(efficient_layout) &&
push!(attributes, namedattribute("efficient_layout", efficient_layout))

return create_operation(
"tt.gather",
21 changes: 21 additions & 0 deletions src/mlir/Dialects/VHLO.jl
@@ -1452,6 +1452,27 @@ function exponential_v1(operand::Value; result::IR.Type, location=Location())
)
end

function exponential_v2(
operand::Value; result::IR.Type, result_accuracy, location=Location()
)
op_ty_results = IR.Type[result,]
operands = Value[operand,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("result_accuracy", result_accuracy),]

return create_operation(
"vhlo.exponential_v2",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end
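A hedged sketch of the new versioned op (not from this commit): `x` is an existing `Value`, `f32_tensor_ty` an `IR.Type`, and `acc_attr` a placeholder for the required `result_accuracy` attribute; the `vhlo` import path is an assumption.

```julia
# Sketch only: `x`, `f32_tensor_ty`, and `acc_attr` are hypothetical placeholders.
using Reactant.MLIR.Dialects: vhlo

# Unlike stablehlo.exponential above, both the result type and result_accuracy
# are required keyword arguments here.
exp_op = vhlo.exponential_v2(x; result=f32_tensor_ty, result_accuracy=acc_attr)
```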

function exponential_minus_one_v1(operand::Value; result::IR.Type, location=Location())
op_ty_results = IR.Type[result,]
operands = Value[operand,]