Parallelization of ConstProp compilation #3042

Merged: 18 commits, Feb 6, 2025
Changes from 5 commits
117 changes: 79 additions & 38 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.cpp
@@ -11,6 +11,7 @@
#include "src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp"

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/STLExtras.h"

#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
@@ -187,18 +188,21 @@ ElementsAttr ElementsAttrBuilder::fromWideNums(
// demonstrates a speedup.
ElementsAttr ElementsAttrBuilder::combine(ElementsAttr lhs, ElementsAttr rhs,
ShapedType combinedType, WideNum (*combiner)(WideNum, WideNum)) {
MLIRContext *ctx = lhs.getElementType().getContext();
if (lhs.isSplat()) {
WideNum lhsNum = getElementsSplatWideNum(lhs);
return expandAndTransform(rhs, combinedType,
functionTransformer(
[lhsNum, combiner](WideNum n) { return combiner(lhsNum, n); }));
[lhsNum, combiner](WideNum n) { return combiner(lhsNum, n); },
ctx));
}

if (rhs.isSplat()) {
WideNum rhsNum = getElementsSplatWideNum(rhs);
return expandAndTransform(lhs, combinedType,
functionTransformer(
[rhsNum, combiner](WideNum n) { return combiner(n, rhsNum); }));
[rhsNum, combiner](WideNum n) { return combiner(n, rhsNum); },
ctx));
}

auto combinedShape = combinedType.getShape();
@@ -231,6 +235,7 @@ ElementsAttr ElementsAttrBuilder::where(ElementsAttr cond, ElementsAttr lhs,
assert(lhs.getElementType() == rhs.getElementType());
assert(lhs.getElementType() == combinedType.getElementType());

MLIRContext *ctx = lhs.getElementType().getContext();
if (cond.isSplat()) {
bool condBool = getElementsSplatWideNum(cond).u64;
return expand(condBool ? lhs : rhs, combinedType.getShape());
@@ -241,7 +246,8 @@ ElementsAttr ElementsAttrBuilder::where(ElementsAttr cond, ElementsAttr lhs,
WideNum rhsNum = getElementsSplatWideNum(rhs);
return expandAndTransform(cond, combinedType,
functionTransformer(
[lhsNum, rhsNum](WideNum n) { return n.u64 ? lhsNum : rhsNum; }));
[lhsNum, rhsNum](WideNum n) { return n.u64 ? lhsNum : rhsNum; },
ctx));
}

auto combinedShape = combinedType.getShape();
@@ -373,6 +379,7 @@ double wideToDouble(WideNum n) {
ElementsAttr ElementsAttrBuilder::castToIntElementType(
ElementsAttr elms, IntegerType newElementType, bool round) {
Type oldElementType = elms.getElementType();
MLIRContext *ctx = oldElementType.getContext();
if (newElementType == oldElementType)
return elms;

@@ -381,25 +388,27 @@ ElementsAttr ElementsAttrBuilder::castToIntElementType(
// Bool: +/-zero cast to 0, everything else including NaN cast to 1.
transformer = wideZeroDispatchNonBool(oldElementType, [&](auto wideZero) {
using cpptype = decltype(wideZero);
return functionTransformer(isWideNonZero<cpptype>);
return functionTransformer(isWideNonZero<cpptype>, ctx);
});
} else if (isa<FloatType>(oldElementType)) {
constexpr bool ROUND = false, TRUNCATE = true;
unsigned width = newElementType.getWidth();
if (newElementType.isUnsigned()) {
uint64_t min = 0;
uint64_t max = std::numeric_limits<uint64_t>::max() >> (64 - width);
transformer = round ? functionTransformer(
convertIntFromFP<ROUND, uint64_t>(min, max))
: functionTransformer(
convertIntFromFP<TRUNCATE, uint64_t>(min, max));
transformer =
round ? functionTransformer(
convertIntFromFP<ROUND, uint64_t>(min, max), ctx)
: functionTransformer(
convertIntFromFP<TRUNCATE, uint64_t>(min, max), ctx);
} else {
int64_t min = std::numeric_limits<int64_t>::min() >> (64 - width);
int64_t max = std::numeric_limits<int64_t>::max() >> (64 - width);
transformer = round ? functionTransformer(
convertIntFromFP<ROUND, int64_t>(min, max))
: functionTransformer(
convertIntFromFP<TRUNCATE, int64_t>(min, max));
transformer =
round ? functionTransformer(
convertIntFromFP<ROUND, int64_t>(min, max), ctx)
: functionTransformer(
convertIntFromFP<TRUNCATE, int64_t>(min, max), ctx);
}
} else if (isa<IntegerType>(oldElementType)) {
// We assume that casts to other integer types don't intend to truncate the
@@ -413,8 +422,8 @@ ElementsAttr ElementsAttrBuilder::castToIntElementType(
// different signs.
// TODO: Consider relaxing the requirement and omit this transformation.
transformer = newElementType.isUnsigned()
? functionTransformer(wideCast<uint64_t, int64_t>)
: functionTransformer(wideCast<int64_t, uint64_t>);
? functionTransformer(wideCast<uint64_t, int64_t>, ctx)
: functionTransformer(wideCast<int64_t, uint64_t>, ctx);
} else {
ElementsProperties props = getElementsProperties(elms);
ShapedType newType = elms.getShapedType().clone(newElementType);
@@ -433,6 +442,7 @@ ElementsAttr ElementsAttrBuilder::castToFPElementType(
if (newElementType == oldElementType)
return elms;

MLIRContext *ctx = oldElementType.getContext();
return wideZeroDispatchNonBool(oldElementType, [&](auto wideZero) {
using cpptype = decltype(wideZero);
Transformer transformer;
@@ -450,16 +460,20 @@ ElementsAttr ElementsAttrBuilder::castToFPElementType(
// See https://github.com/onnx/onnx-mlir/issues/2369
//
// TODO: Change implementation to match the spec, or change the spec.
transformer = functionTransformer([max](WideNum n) {
double d = wideToDouble<cpptype>(n);
return WideNum::widen<BType::DOUBLE>(
// Order of operations is important to ensure NaN stays NaN:
d <= -max ? -max : (d >= max ? max : d));
});
transformer = functionTransformer(
[max](WideNum n) {
double d = wideToDouble<cpptype>(n);
return WideNum::widen<BType::DOUBLE>(
// Order of operations is important to ensure NaN stays NaN:
d <= -max ? -max : (d >= max ? max : d));
},
ctx);
} else if constexpr (std::is_integral_v<cpptype>) {
transformer = functionTransformer([](WideNum n) {
return WideNum::widen<BType::DOUBLE>(wideToDouble<cpptype>(n));
});
transformer = functionTransformer(
[](WideNum n) {
return WideNum::widen<BType::DOUBLE>(wideToDouble<cpptype>(n));
},
ctx);
} else {
ElementsProperties props = getElementsProperties(elms);
ShapedType newType = elms.getShapedType().clone(newElementType);
@@ -849,6 +863,8 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,
if (axes.empty())
return elms;

Type elementType = elms.getElementType();
MLIRContext *ctx = elementType.getContext();
SmallVector<unsigned, 4> sortedAxes(axes);
std::sort(sortedAxes.begin(), sortedAxes.end());
assert(
@@ -885,22 +901,47 @@ ElementsAttr ElementsAttrBuilder::reduce(ElementsAttr elms,

ShapedType reducedType = type.clone(reducedShape);
return fromWideNums(reducedType, [&](MutableArrayRef<WideNum> dstNums) {
// Traverse and populate each element d in dstNums.
for (auto &idxoffs : StridesRange<1>(reducedShape, {reducedStrides})) {
WideNum &d = dstNums[idxoffs.flattenedIndex];
int64_t srcPos = idxoffs[0];
// Traverse all the elements that reduce together into d.
// srcNums elements may be repeated if there are zeros in axesStrides.
StridesRange<1> axesRange(axesShape, {axesStrides});
auto axesIter = axesRange.begin();
auto axesEnd = axesRange.end();
assert(axesIter->at(0) == 0 && "initial src offset must be zero");
d = srcNums.get()[srcPos];
while (++axesIter != axesEnd) {
int64_t srcOffset = axesIter->at(0);
d = reducer(d, srcNums.get()[srcPos + srcOffset]);
StridesRange<1> sRange(reducedShape, {reducedStrides});
SmallVector<std::pair<int64_t, uint64_t>, 4> batch;
for (auto &idxoffs : sRange)
batch.emplace_back(std::make_pair(idxoffs.flattenedIndex, idxoffs[0]));

std::mutex mtx;
size_t beginOffset = 0;
auto fetchBatch = [&](size_t threadNumber) {
// Each thread fetches the same batch size; the remainder is distributed
// to the threads with the smallest thread numbers.
const std::lock_guard<std::mutex> lock(mtx);
size_t batchSize = batch.size() / ctx->getNumThreads();
size_t batchSizeMod = batch.size() % ctx->getNumThreads();
if (threadNumber < batchSizeMod)
batchSize += 1;
auto batchBegin = batch.begin() + beginOffset;
auto batchEnd = batchBegin + batchSize;
beginOffset += batchSize;
return llvm::make_range(batchBegin, batchEnd);
};

auto work = [&](size_t threadNumber) {
auto batch = fetchBatch(threadNumber);
// Traverse and populate each element d in dstNums.
for (auto b : batch) {
WideNum &d = dstNums[b.first];
int64_t srcPos = b.second;
// Traverse all the elements that reduce together into d.
// srcNums elements may be repeated if there are zeros in axesStrides.
StridesRange<1> axesRange(axesShape, {axesStrides});
auto axesIter = axesRange.begin();
auto axesEnd = axesRange.end();
assert(axesIter->at(0) == 0 && "initial src offset must be zero");
d = srcNums.get()[srcPos];
while (++axesIter != axesEnd) {
int64_t srcOffset = axesIter->at(0);
d = reducer(d, srcNums.get()[srcPos + srcOffset]);
}
}
}
};
parallelFor(ctx, 0, ctx->getNumThreads(), work);
});
}
Collaborator comment:

I also assume that the code above relies on there being batch.size() independent reductions that can all be done in parallel.

Since quantization includes "whole tensor" quantization, we have cases with only 1 reduction. That can also be done in parallel: say you have 1000 elements and 10 threads; each thread processes its own 100 numbers and saves its result in its slot in an array of 10 partial sums. After the parallel region, just reduce those 10 values sequentially. You will still get a near 10x speedup.

Also, should we check whether batch.size() is small and, if so, do the work sequentially? That would probably help when we have a few very small tensors. You can easily print the sizes to stderr for a few benchmarks and see whether such cases occur.
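
A minimal sketch of that partial-result idea, assuming onnx-mlir's WideNum type and the reducer signature used in the hunk above; parallelSingleReduction, its chunking scheme, and the include paths are illustrative assumptions, not part of this PR:

#include <algorithm>
#include <cassert>

#include "mlir/IR/Threading.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

#include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp"

using onnx_mlir::WideNum;

// Reduce all of srcNums into a single WideNum: each thread reduces its own
// contiguous chunk into a per-chunk partial result, then the partials are
// combined sequentially after the parallel region.
template <typename Reducer>
WideNum parallelSingleReduction(mlir::MLIRContext *ctx,
    llvm::ArrayRef<WideNum> srcNums, Reducer reducer) {
  assert(!srcNums.empty() && "nothing to reduce");
  size_t numThreads = std::max<size_t>(1, ctx->getNumThreads());
  size_t chunk = (srcNums.size() + numThreads - 1) / numThreads;
  size_t numChunks = (srcNums.size() + chunk - 1) / chunk;
  // One slot per chunk, seeded from the first element of srcNums so that the
  // reducer needs no identity element; each slot is overwritten by its thread.
  llvm::SmallVector<WideNum> partials(numChunks, srcNums.front());
  mlir::parallelFor(ctx, 0, numChunks, [&](size_t t) {
    size_t begin = t * chunk;
    size_t end = std::min(begin + chunk, srcNums.size());
    WideNum acc = srcNums[begin];
    for (size_t i = begin + 1; i < end; ++i)
      acc = reducer(acc, srcNums[i]);
    partials[t] = acc;
  });
  // Sequential combine of at most numChunks (<= numThreads) partial results.
  WideNum result = partials[0];
  for (size_t t = 1; t < numChunks; ++t)
    result = reducer(result, partials[t]);
  return result;
}

With 1000 elements and 10 threads this runs 10 chunks of roughly 100 reductions each in parallel, followed by 9 sequential combines, matching the near 10x speedup described above.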


35 changes: 30 additions & 5 deletions src/Dialect/ONNX/ElementsAttr/ElementsAttrBuilder.hpp
@@ -11,6 +11,8 @@
#ifndef ONNX_MLIR_ELEM_ATTR_BUILDER_H
#define ONNX_MLIR_ELEM_ATTR_BUILDER_H

#include "mlir/IR/Threading.h"

#include "src/Dialect/ONNX/ElementsAttr/BType.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposableElementsAttr.hpp"
#include "src/Dialect/ONNX/ElementsAttr/DisposablePool.hpp"
@@ -91,8 +93,9 @@ class ElementsAttrBuilder {
template <typename Function = WideNum (*)(WideNum)>
mlir::ElementsAttr transform(mlir::ElementsAttr elms,
mlir::Type transformedElementType, Function fun) {
mlir::MLIRContext *ctx = elms.getElementType().getContext();
return doTransform(
elms, transformedElementType, functionTransformer(std::move(fun)));
elms, transformedElementType, functionTransformer(std::move(fun), ctx));
}

// Returns an ElementsAttr that is the result of applying a binary function
@@ -244,10 +247,32 @@ class ElementsAttrBuilder {
// Constructs a transformer that changes every element to the result of
// applying the given function to the element.
template <typename Function = WideNum (*)(WideNum)>
static inline Transformer functionTransformer(Function fun) {
return [fun = std::move(fun)](llvm::MutableArrayRef<WideNum> data) -> void {
for (WideNum &n : data)
n = fun(n);
static inline Transformer functionTransformer(
Function fun, mlir::MLIRContext *ctx) {
return [fun = std::move(fun), ctx](
llvm::MutableArrayRef<WideNum> data) -> void {
std::mutex mtx;
size_t beginOffset = 0;
auto fetchBatch = [&](size_t threadNumber) {
// Each thread fetches the same batch size; the remainder is distributed
// to the threads with the smallest thread numbers.
const std::lock_guard<std::mutex> lock(mtx);
size_t batchSize = data.size() / ctx->getNumThreads();
size_t batchSizeMod = data.size() % ctx->getNumThreads();
if (threadNumber < batchSizeMod)
batchSize += 1;
auto batchBegin = data.begin() + beginOffset;
auto batchEnd = batchBegin + batchSize;
beginOffset += batchSize;
return llvm::make_range(batchBegin, batchEnd);
};

auto work = [&](size_t threadNumber) {
auto batch = fetchBatch(threadNumber);
for (WideNum &n : batch)
n = fun(n);
};
parallelFor(ctx, 0, ctx->getNumThreads(), work);
Collaborator comment:

As mentioned before, please check that there is enough work to justify going parallel. I suspect that if the reduction is very small, we really want to do it sequentially, and that will be faster. (See the sketch after this hunk.)

};
}
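
One way to add the suggested guard, sketched under the assumption of an untuned threshold constant kMinParallelElements; transformMaybeParallel is a hypothetical free function for illustration, not this PR's implementation:

#include <algorithm>

#include "mlir/IR/Threading.h"
#include "llvm/ADT/ArrayRef.h"

#include "src/Dialect/ONNX/ElementsAttr/WideNum.hpp"

using onnx_mlir::WideNum;

// Assumed cutoff below which threading overhead outweighs any speedup.
constexpr size_t kMinParallelElements = 1024;

template <typename Function>
void transformMaybeParallel(mlir::MLIRContext *ctx,
    llvm::MutableArrayRef<WideNum> data, Function fun) {
  size_t numThreads = ctx->getNumThreads();
  if (data.size() < kMinParallelElements || numThreads <= 1) {
    // Small buffer or single-threaded context: plain sequential loop.
    for (WideNum &n : data)
      n = fun(n);
    return;
  }
  size_t chunk = (data.size() + numThreads - 1) / numThreads;
  mlir::parallelFor(ctx, 0, numThreads, [&](size_t t) {
    size_t begin = t * chunk;
    size_t end = std::min(begin + chunk, data.size());
    for (size_t i = begin; i < end; ++i) // empty range when begin >= end
      data[i] = fun(data[i]);
  });
}

The same kind of check could guard the reduce() path in the .cpp file, with the threshold picked from the batch sizes observed in real models.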
