diff --git a/src/04kernel/src/attributes/transpose_info.cc b/src/04kernel/src/attributes/transpose_info.cc
index 2b563d03..9ae385a9 100644
--- a/src/04kernel/src/attributes/transpose_info.cc
+++ b/src/04kernel/src/attributes/transpose_info.cc
@@ -35,7 +35,7 @@ namespace refactor::kernel {
                 }
             }
         }
-        if (rank == 0) {
+        if (rank <= 1) {
             dims = {{1, 1}};
             blockSize *= blockCount;
             blockCount = 1;
@@ -73,6 +73,12 @@ namespace refactor::kernel {
             }
             perm.resize(rank);
         }
+        if (rank <= 1) {
+            dims = {{1, 1}};
+            blockSize *= blockCount;
+            blockCount = 1;
+            return;
+        }
         // 合并末尾连续访存
         if (perm.back() == rank - 1) {
             blockSize *= shape.back();
diff --git a/src/04kernel/test/attributes/test_transpose_info.cpp b/src/04kernel/test/attributes/test_transpose_info.cpp
new file mode 100644
index 00000000..fd735801
--- /dev/null
+++ b/src/04kernel/test/attributes/test_transpose_info.cpp
@@ -0,0 +1,39 @@
+#include "kernel/attributes/transpose_info.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+
+TEST(kernel, TransposeInfo) {
+    {
+        TransposeInfo info(
+            DataType::F32,
+            {1, 2, 3, 2, 1},
+            {1, 2, 3, 0, 4});
+        EXPECT_EQ(info.blockSize, 48);
+        EXPECT_EQ(info.blockCount, 1);
+        EXPECT_EQ(info.dims.size(), 1);
+    }
+    {
+        TransposeInfo info(
+            DataType::F32,
+            {1, 1, 2, 1, 1},
+            {1, 2, 3, 0, 4});
+        EXPECT_EQ(info.blockSize, 8);
+        EXPECT_EQ(info.blockCount, 1);
+        EXPECT_EQ(info.dims.size(), 1);
+    }
+    {
+        TransposeInfo info(
+            DataType::F32,
+            {1, 2, 3, 4, 5},
+            {2, 3, 1, 0, 4});
+        EXPECT_EQ(info.blockSize, 20);
+        EXPECT_EQ(info.blockCount, 24);
+        EXPECT_EQ(info.dims.size(), 2);
+        EXPECT_EQ(info.dims[1].strideI, 12);
+        EXPECT_EQ(info.dims[1].strideO, 1);
+        EXPECT_EQ(info.dims[0].strideI, 1);
+        EXPECT_EQ(info.dims[0].strideO, 2);
+    }
+}
diff --git a/src/06frontend/src/graph.cc b/src/06frontend/src/graph.cc
index cb1104a4..7fd2eb57 100644
--- a/src/06frontend/src/graph.cc
+++ b/src/06frontend/src/graph.cc
@@ -102,7 +102,7 @@ namespace refactor::frontend {
         for (auto i : range0_(inputs.size())) {
             auto j = inputs[i];
             auto const &input = _internal.edges[j].tensor;
-            ASSERT(input, "The {}th input of \"{}\" is nullptr", i, _internal.nodes[nodeIdx].name);
+            ASSERT(input, "The input[{}] of \"{}\" is nullptr", i, _internal.nodes[nodeIdx].name);
             auto checked = edgeChanged[2 * j]; // NOTICE `std::vector::operator[]` 产生常引用!!!
             auto changed = edgeChanged[2 * j + 1];// NOTICE `std::vector::operator[]` 产生常引用!!!
             if (!checked) {
diff --git a/src/07onnx/src/operators/gather.cc b/src/07onnx/src/operators/gather.cc
index 3aa08890..6373bda2 100644
--- a/src/07onnx/src/operators/gather.cc
+++ b/src/07onnx/src/operators/gather.cc
@@ -1,6 +1,8 @@
 #include "computation/operators/gather.h"
 #include "common.h"
 #include "gather.hh"
+#include "kernel/collectors/gather.h"
+#include "runtime/resource.h"
 #include <execution>
 
 namespace refactor::onnx {
@@ -42,41 +44,34 @@ namespace refactor::onnx {
         if (!options.shouldCalculate(inputs, {*ans})) {
             return Ok(Tensors{std::move(ans)});
         }
+        {
+            using Shape = kernel::Shape;
+            using Tensor = kernel::Tensor;
+            using LayoutType = kernel::LayoutType;
 
-        std::for_each_n(std::execution::unseq, natural_t(0), ans->elementsSize(),
-                        [&data, &indices, &output,
-                         axis_,
-                         q = indices.shape.size(),
-                         ssz = output.size(),
-                         src = data.data->get<uint8_t>(),
-                         dst = reinterpret_cast<uint8_t *>(ans->malloc()),
-                         eleSize = data.dataType.size()](auto const i) {
-                            auto indices_ = locateN(output, i);
-                            int64_t k;
-                            {
-                                size_t ii = 0, mul = 1;
-                                for (auto j : range0_(q).rev()) {
-                                    ii += indices_[j] * mul;
-                                    mul *= indices.shape[j].value();
-                                }
-                                k = indices.dataType == DataType::I64
-                                        ? indices.data->get<int64_t>()[ii]
-                                        : indices.data->get<int32_t>()[ii];
-                            }
-                            {
-                                size_t ii = 0, mul = 1;
-                                for (auto j : range(static_cast<size_t>(axis_) + q, ssz).rev()) {
-                                    ii += indices_[j] * mul;
-                                    mul *= data.shape[j - q + 1].value();
-                                }
-                                ii += k * mul;
-                                for (auto j : range0_(axis_).rev()) {
-                                    ii += indices_[j] * mul;
-                                    mul *= data.shape[j].value();
-                                }
-                                std::memcpy(dst + i * eleSize, src + ii * eleSize, eleSize);
-                            }
-                        });
+            Shape t1Shape(data.shape.size(), 1);
+            Shape t2Shape(indices.shape.size(), 1);
+            Shape oShape(ans->shape.size(), 1);
+            std::transform(std::execution::unseq,
+                           data.shape.begin(), data.shape.end(), t1Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            std::transform(std::execution::unseq,
+                           indices.shape.begin(), indices.shape.end(), t2Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto t1 = Tensor::share(data.dataType, t1Shape, LayoutType::Others, data.data);
+            auto t2 = Tensor::share(indices.dataType, t2Shape, LayoutType::Others, indices.data);
+            std::transform(std::execution::unseq,
+                           ans->shape.begin(), ans->shape.end(), oShape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto o = Tensor::share(data.dataType, oShape, LayoutType::Others);
+            runtime::Resources res;
+            const auto collector = kernel::GatherCollector(computation::Target::Cpu, axis_);
+            auto routine = std::move(collector.filter({*t1, *t2}, {*o}).at(0))->lower(res).routine;
+            void const *inputsCpu[]{*t1->data, *t2->data};
+            void *outputsCpu[]{o->malloc()};
+            routine(res, nullptr, inputsCpu, outputsCpu);
+            ans->data = o->data;
+        }
 
         return Ok(Tensors{std::move(ans)});
     }
diff --git a/src/07onnx/src/operators/reduce.cc b/src/07onnx/src/operators/reduce.cc
index a8ae15b5..329d059a 100644
--- a/src/07onnx/src/operators/reduce.cc
+++ b/src/07onnx/src/operators/reduce.cc
@@ -20,13 +20,23 @@ namespace refactor::onnx {
         auto noopWithEmptyAxes = false;
         decltype(Op::axes) axes = std::nullopt;
-        if (opsetVer >= 18) {
-            noopWithEmptyAxes = attributes.getOrInsert( "noop_with_empty_axes", {0}).int_() != 0;
+
+        // 针对ReduceSum做特判
+        if (opType == "onnx::ReduceSum") {
+            if (opsetVer >= 13) {
+                noopWithEmptyAxes = attributes.getOrInsert("noop_with_empty_axes", {0}).int_() != 0;
+            } else {
+                axes.emplace(attributes.getOrInsert("axes", {{}}).ints());
+            }
         } else {
-            axes.emplace(attributes.getOrInsert( "axes", {{}}).ints());
+            if (opsetVer >= 18) {
+                noopWithEmptyAxes = attributes.getOrInsert("noop_with_empty_axes", {0}).int_() != 0;
+            } else {
+                axes.emplace(attributes.getOrInsert("axes", {{}}).ints());
+            }
         }
-        auto keepDims = attributes.getOrInsert( "keepdims", {1}).int_();
+        auto keepDims = attributes.getOrInsert("keepdims", {1}).int_();
         Ty ty;
         if (opType == "onnx::ReduceMean") {
             ty = Ty::Mean;
diff --git a/src/07onnx/src/operators/simple_binary.cc b/src/07onnx/src/operators/simple_binary.cc
index 2db99bdd..bccc99ad 100644
--- a/src/07onnx/src/operators/simple_binary.cc
+++ b/src/07onnx/src/operators/simple_binary.cc
@@ -1,6 +1,9 @@
 #include "simple_binary.hh"
 #include "common.h"
 #include "computation/operators/simple_binary.h"
+#include "kernel/collectors/simple_binary.h"
+#include "runtime/resource.h"
+#include <execution>
 
 namespace refactor::onnx {
     using Op = SimpleBinary;
@@ -10,7 +13,7 @@ namespace refactor::onnx {
         : Operator(), type(type_) {}
 
     auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
-        auto fmod = attributes.getOrInsert( "fmod", {0}).int_();
+        auto fmod = attributes.getOrInsert("fmod", {0}).int_();
         // clang-format off
         auto type = opType == "onnx::Add" ? Ty::Add :
@@ -93,30 +96,6 @@ namespace refactor::onnx {
         // clang-format on
     }
 
-    template<decltype(DataType::internal) T>
-    void calculate(Ty ty, void *dst, void const *a, void const *b) {
-        using T_ = typename primitive<T>::type;
-        auto a_ = *reinterpret_cast<T_ const *>(a);
-        auto b_ = *reinterpret_cast<T_ const *>(b);
-        auto dst_ = reinterpret_cast<T_ *>(dst);
-        switch (ty) {
-            case Ty::Add:
-                *dst_ = a_ + b_;
-                break;
-            case Ty::Sub:
-                *dst_ = a_ - b_;
-                break;
-            case Ty::Mul:
-                *dst_ = a_ * b_;
-                break;
-            case Ty::Div:
-                *dst_ = a_ / b_;
-                break;
-            default:
-                UNREACHABLE();
-        }
-    }
-
     auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
         EXPECT_SIZE(2)
@@ -139,35 +118,36 @@ namespace refactor::onnx {
             return Ok(Tensors{std::move(ans)});
         }
 
-        auto eleSize = dataType.size();
-        auto dst = reinterpret_cast<uint8_t *>(ans->malloc());
-        for (auto i : range0_(ans->elementsSize())) {
-            auto indices = locateN(ans->shape, i);
-            auto a_ = locate1(a, indices),
-                 b_ = locate1(b, indices);
-            auto dst_ = dst + i * eleSize;
-            //-------------------------------------
-#define CASE(T)                                     \
-    case DataType::T:                               \
-        calculate<DataType::T>(type, dst_, a_, b_); \
-        break
-            //-------------------------------------
-            switch (dataType.internal) {
-                CASE(F32);
-                CASE(F64);
-                CASE(I32);
-                CASE(I64);
-                CASE(I8);
-                CASE(I16);
-                CASE(U8);
-                CASE(U16);
-                CASE(U32);
-                CASE(U64);
-                default:
-                    ans->free();
-                    break;
-            }
+        {
+            using Shape = kernel::Shape;
+            using Tensor = kernel::Tensor;
+            using LayoutType = kernel::LayoutType;
+
+            Shape t1Shape(a.shape.size(), 1);
+            Shape t2Shape(b.shape.size(), 1);
+            Shape oShape(ans->shape.size(), 1);
+            std::transform(std::execution::unseq,
+                           a.shape.begin(), a.shape.end(), t1Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            std::transform(std::execution::unseq,
+                           b.shape.begin(), b.shape.end(), t2Shape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto t1 = Tensor::share(a.dataType, t1Shape, LayoutType::Others, a.data);
+            auto t2 = Tensor::share(b.dataType, t2Shape, LayoutType::Others, b.data);
+            std::transform(std::execution::unseq,
+                           ans->shape.begin(), ans->shape.end(), oShape.begin(),
+                           [](auto const &i) { return static_cast<dim_t>(i.value()); });
+            auto o = Tensor::share(a.dataType, oShape, LayoutType::Others);
+            runtime::Resources res;
+            auto type_ = static_cast<kernel::SimpleBinaryType>(type);
+            const auto collector = kernel::SimpleBinaryCollector(computation::Target::Cpu, type_);
+            auto routine = std::move(collector.filter({*t1, *t2}, {*o}).at(0))->lower(res).routine;
+            void const *inputsCpu[]{*t1->data, *t2->data};
+            void *outputsCpu[]{o->malloc()};
+            routine(res, nullptr, inputsCpu, outputsCpu);
+            ans->data = o->data;
         }
+
         return Ok(Tensors{std::move(ans)});
     }