Skip to content

Commit

Permalink
Implement flatten layer (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
NeiroYT authored Oct 27, 2024
1 parent d169635 commit b6c2736
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
15 changes: 15 additions & 0 deletions include/layers/FlattenLayer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once
#include <string>

#include "layers/Layer.hpp"

namespace itlab_2023 {

class FlattenLayer : public Layer {
public:
FlattenLayer() = default;
static std::string get_name() { return "Flatten layer"; }
void run(const Tensor& input, Tensor& output) override;
};

} // namespace itlab_2023
23 changes: 23 additions & 0 deletions src/layers/FlattenLayer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#include "layers/FlattenLayer.hpp"

namespace itlab_2023 {

void FlattenLayer::run(const Tensor &input, Tensor &output) {
switch (input.get_type()) {
case Type::kInt: {
output =
make_tensor(*input.as<int>(), Shape({input.get_shape().count()}));
break;
}
case Type::kFloat: {
output =
make_tensor(*input.as<float>(), Shape({input.get_shape().count()}));
break;
}
default: {
throw std::runtime_error("No such type");
}
}
}

} // namespace itlab_2023
30 changes: 30 additions & 0 deletions test/single_layer/test_flattenlayer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include <vector>

#include "gtest/gtest.h"
#include "layers/FlattenLayer.hpp"

using namespace itlab_2023;

TEST(flattenlayer, new_flattenlayer_can_flatten_int) {
FlattenLayer layer;
Shape sh({2, 2});
Tensor input = make_tensor<int>({1, -1, 2, -2}, sh);
Tensor output;
layer.run(input, output);
EXPECT_EQ(output.get_shape().dims(), 1);
EXPECT_EQ(output.get_shape()[0], 4);
}

TEST(flattenlayer, new_flattenlayer_can_flatten_float) {
FlattenLayer layer;
Shape sh({2, 2});
Tensor input = make_tensor<float>({1.0F, -1.0F, 2.0F, -2.0F}, sh);
Tensor output;
layer.run(input, output);
EXPECT_EQ(output.get_shape().dims(), 1);
EXPECT_EQ(output.get_shape()[0], 4);
}

TEST(flattenlayer, get_layer_name) {
EXPECT_EQ(FlattenLayer::get_name(), "Flatten layer");
}

0 comments on commit b6c2736

Please sign in to comment.