diff --git a/include/layers/FlattenLayer.hpp b/include/layers/FlattenLayer.hpp new file mode 100644 index 00000000..c29264d3 --- /dev/null +++ b/include/layers/FlattenLayer.hpp @@ -0,0 +1,15 @@ +#pragma once +#include + +#include "layers/Layer.hpp" + +namespace itlab_2023 { + +class FlattenLayer : public Layer { + public: + FlattenLayer() = default; + static std::string get_name() { return "Flatten layer"; } + void run(const Tensor& input, Tensor& output) override; +}; + +} // namespace itlab_2023 diff --git a/src/layers/FlattenLayer.cpp b/src/layers/FlattenLayer.cpp new file mode 100644 index 00000000..c32aee86 --- /dev/null +++ b/src/layers/FlattenLayer.cpp @@ -0,0 +1,23 @@ +#include "layers/FlattenLayer.hpp" + +namespace itlab_2023 { + +void FlattenLayer::run(const Tensor &input, Tensor &output) { + switch (input.get_type()) { + case Type::kInt: { + output = + make_tensor(*input.as(), Shape({input.get_shape().count()})); + break; + } + case Type::kFloat: { + output = + make_tensor(*input.as(), Shape({input.get_shape().count()})); + break; + } + default: { + throw std::runtime_error("No such type"); + } + } +} + +} // namespace itlab_2023 diff --git a/test/single_layer/test_flattenlayer.cpp b/test/single_layer/test_flattenlayer.cpp new file mode 100644 index 00000000..d83f08c7 --- /dev/null +++ b/test/single_layer/test_flattenlayer.cpp @@ -0,0 +1,30 @@ +#include + +#include "gtest/gtest.h" +#include "layers/FlattenLayer.hpp" + +using namespace itlab_2023; + +TEST(flattenlayer, new_flattenlayer_can_flatten_int) { + FlattenLayer layer; + Shape sh({2, 2}); + Tensor input = make_tensor({1, -1, 2, -2}, sh); + Tensor output; + layer.run(input, output); + EXPECT_EQ(output.get_shape().dims(), 1); + EXPECT_EQ(output.get_shape()[0], 4); +} + +TEST(flattenlayer, new_flattenlayer_can_flatten_float) { + FlattenLayer layer; + Shape sh({2, 2}); + Tensor input = make_tensor({1.0F, -1.0F, 2.0F, -2.0F}, sh); + Tensor output; + layer.run(input, output); + EXPECT_EQ(output.get_shape().dims(), 1); + EXPECT_EQ(output.get_shape()[0], 4); +} + +TEST(flattenlayer, get_layer_name) { + EXPECT_EQ(FlattenLayer::get_name(), "Flatten layer"); +}