Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Ops.convolution #428

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,49 @@ extern "C" MlirAttribute mlirComplexAttrDoubleGetChecked(MlirLocation loc,
// wrap(complex::NumberAttr::getTypeID()); }
#pragma endregion


#include "stablehlo/dialect/TypeInference.h"


typedef struct {
MlirAttribute windowStrides;
MlirAttribute padding;
MlirAttribute lhsDilation;
MlirAttribute rhsDilation;
MlirAttribute windowReversal;
MlirAttribute dimensionNumber;
int64_t featureGroupCount;
int64_t batchGroupCount;
} ConvolutionParams;


extern "C" MlirType
inferConvolutionOp(MlirLocation location, MlirType lhsType, MlirType rhsType, ConvolutionParams *cp) {
auto windowStrides = dyn_cast<DenseI64ArrayAttr>(unwrap(cp->windowStrides));
auto lhsDilation = dyn_cast<DenseI64ArrayAttr>(unwrap(cp->lhsDilation));
auto rhsDilation = dyn_cast<DenseI64ArrayAttr>(unwrap(cp->rhsDilation));
auto windowReversal = dyn_cast<DenseBoolArrayAttr>(unwrap(cp->windowReversal));

std::optional<DenseIntElementsAttr> padding_mlir = dyn_cast_if_present<DenseIntElementsAttr>(unwrap(cp->padding));
auto dn_mlir = cast<stablehlo::ConvDimensionNumbersAttr>(unwrap(cp->dimensionNumber));
auto lhsType_ = dyn_cast<RankedTensorType>(unwrap(lhsType));

SmallVector<ShapedTypeComponents> inferredReturnShapes;
auto result = mlir::hlo::inferConvolutionOp(
unwrap(location), lhsType_, unwrap(rhsType), windowStrides, padding_mlir,
lhsDilation, rhsDilation, windowReversal, dn_mlir.getInputBatchDimension(),
dn_mlir.getInputFeatureDimension(), dn_mlir.getInputSpatialDimensions(),
dn_mlir.getKernelInputFeatureDimension(),
dn_mlir.getKernelOutputFeatureDimension(),
dn_mlir.getKernelSpatialDimensions(), dn_mlir.getOutputBatchDimension(),
dn_mlir.getOutputFeatureDimension(), dn_mlir.getOutputSpatialDimensions(),
cp->featureGroupCount, cp->batchGroupCount, nullptr, inferredReturnShapes);
if (result.failed())
return MlirType();
return wrap(RankedTensorType::get(inferredReturnShapes[0].getDims(),
lhsType_.getElementType()));
}

// auxiliar functions
#pragma region utils
template <typename T> const char *cstr_from_string(T text) {
Expand Down
121 changes: 92 additions & 29 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,35 +569,98 @@ end
return clamp(constant(min), x, constant(max))
end

# function convolution(
# lhs::TracedRArray{T,N},
# rhs::TracedRArray{T,N};
# dimension_numbers,
# feature_group_count,
# batch_group_count,
# window_strides=nothing,
# padding=nothing,
# lhs_dilation=nothing,
# rhs_dilation=nothing,
# location=mlir_stacktrace(
# "convolution", @__FILE__, @__LINE__
# ),
# ) where {T,N}
# res = MLIR.IR.result(
# stablehlo.convolution(
# lhs.mlir_data,
# rhs.mlir_data;
# result=mlir_type(TracedRArray{T,N}, ...), # TODO size of result
# window_strides, #*MLIR.IR.DenseArrayAttribute(window_strides)*#,
# padding, #*MLIR.IR.DenseArrayAttribute(padding)*#,
# lhs_dilation, #*MLIR.IR.DenseArrayAttribute(lhs_dilation)*#,
# rhs_dilation, #*MLIR.IR.DenseArrayAttribute(rhs_dilation)*#,
# feature_group_count=feature_group_count,
# location,
# ),
# )
# return TracedRArray{T,N}((), res, size(lhs))
# end
mutable struct ConvolutionParams
windowStrides::MLIR.API.MlirAttribute
padding::MLIR.API.MlirAttribute
lhsDilation::MLIR.API.MlirAttribute
rhsDilation::MLIR.API.MlirAttribute
windowReversal::MLIR.API.MlirAttribute
dimensionNumber::MLIR.API.MlirAttribute
featureGroupCount::Int64
batchGroupCount::Int64
end

function inferConvolutionOp(
loc::MLIR.IR.Location,
lhsType::MLIR.IR.Type,
rhsType::MLIR.IR.Type,
windowStrides::MLIR.IR.Attribute,
padding::MLIR.IR.Attribute,
lhsDilation::MLIR.IR.Attribute,
rhsDilation::MLIR.IR.Attribute,
windowReversal::MLIR.IR.Attribute,
dimensionNumber::MLIR.IR.Attribute,
featureGroupCount::Int,
batchGroupCount::Int,
)
cp = ConvolutionParams(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to me but yeah I am confused about the arg thing

windowStrides,
padding,
lhsDilation,
rhsDilation,
windowReversal,
dimensionNumber,
featureGroupCount,
batchGroupCount,
)
@ccall MLIR.API.mlir_c.inferConvolutionOp(
loc::MLIR.API.MlirLocation, lhsType::MLIR.API.MlirType, rhsType::MLIR.API.MlirType, cp::Ref{ConvolutionParams}
)::MLIR.API.MlirType
end

function convolution(
input::TracedRArray{T,N},
kernel::TracedRArray{T,N},
dimension_numbers;
feature_group_count=1,
batch_group_count=1,
window_strides::AbstractArray{Int}=Int[],
padding::AbstractArray{Int}=Int[],
lhs_dilation::AbstractArray{Int}=Int[],
rhs_dilation::AbstractArray{Int}=Int[],
window_reversal::AbstractArray{Bool}=Bool[],
location=mlir_stacktrace("convolution", @__FILE__, @__LINE__),
) where {T,N}
padding =
isempty(padding) ? MLIR.IR.Attribute() : MLIR.IR.DenseElementsAttribute(padding)
window_strides = MLIR.IR.DenseArrayAttribute(window_strides)
lhs_dilation = MLIR.IR.DenseArrayAttribute(lhs_dilation)
rhs_dilation = MLIR.IR.DenseArrayAttribute(rhs_dilation)
window_reversal = MLIR.IR.DenseArrayAttribute(window_reversal)
output_type = inferConvolutionOp(
location,
mlir_type(input),
mlir_type(kernel),
window_strides,
padding,
lhs_dilation,
rhs_dilation,
window_reversal,
dimension_numbers,
feature_group_count,
batch_group_count,
)
@assert output_type.ptr != Ptr{Nothing}() "cannot infer result type"
output_type = MLIR.IR.Type(output_type)

res = MLIR.IR.result(
MLIR.Dialects.stablehlo.convolution(
input.mlir_data,
kernel.mlir_data;
result_0=output_type,
window_strides,
padding,
lhs_dilation,
rhs_dilation,
window_reversal,
dimension_numbers,
feature_group_count,
batch_group_count,
location,
),
)
return TracedRArray{T,N}((), res, size(output_type))
end

@noinline function dot_general(
lhs::TracedRArray{T},
Expand Down
Loading