Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…_2023 into neiroyt/feature18_tbbopt
  • Loading branch information
NeiroYT committed Jul 26, 2024
2 parents b49ab7a + b8e7ba0 commit 880acfd
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 1 deletion.
18 changes: 18 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,24 @@ cmake_minimum_required(VERSION 3.20)

project(itlab_2023)

# Optional diagnostics: each flag is an independent cache option and, when
# enabled, defines a preprocessor symbol of the same name project-wide
# (consumed by the ENABLE_STATISTIC_* #ifdef blocks in the headers/tests).
option(ENABLE_STATISTIC_TENSORS "Enable statistic tensors" OFF)
option(ENABLE_STATISTIC_TIME "Enable statistic time" OFF)
option(ENABLE_STATISTIC_WEIGHTS "Enable statistic weights" OFF)

foreach(stat_flag
        ENABLE_STATISTIC_TENSORS
        ENABLE_STATISTIC_TIME
        ENABLE_STATISTIC_WEIGHTS)
  if(${stat_flag})
    # add_compile_definitions (CMake >= 3.12) is the modern replacement
    # for add_definitions and needs no leading -D.
    add_compile_definitions(${stat_flag})
  endif()
endforeach()

set(CMAKE_CXX_STANDARD 17)

enable_testing()
Expand Down
36 changes: 36 additions & 0 deletions include/graph/graph.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#pragma once
#include <algorithm>
#include <chrono>
#include <queue>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

#include "layers/Layer.hpp"
Expand All @@ -19,6 +21,15 @@ class Graph {
Tensor* outten_;
int start_;
int end_;
// Per-inference statistics, collected only when the matching
// compile-time flag (CMake option of the same name) is defined.
#ifdef ENABLE_STATISTIC_TENSORS
// Input and output tensor of every executed layer, in traversal order
// (two entries pushed per layer — see inference()).
std::vector<Tensor> tensors_;
#endif
#ifdef ENABLE_STATISTIC_TIME
// Wall-clock duration of each layer's run(), in whole milliseconds.
std::vector<int> time_;
#endif
#ifdef ENABLE_STATISTIC_WEIGHTS
// Weights reported by each executed layer via Layer::get_weights();
// layers without weights report a one-element placeholder tensor.
std::vector<Tensor> weights_;
#endif

public:
Graph(int vertices) : BiggestSize_(vertices) {
Expand Down Expand Up @@ -90,13 +101,38 @@ class Graph {
}
}
// Execute the layers in BFS traversal order, threading the output of
// each layer into the input of the next and (optionally) recording
// statistics around every run() call.
for (int i : traversal) {
#ifdef ENABLE_STATISTIC_TIME
// NOTE(review): steady_clock is the recommended clock for interval
// measurement; high_resolution_clock may alias a non-monotonic clock.
auto start = std::chrono::high_resolution_clock::now();
#endif
layers_[i]->run(inten_, *outten_);
#ifdef ENABLE_STATISTIC_TENSORS
// Record both sides of the layer call; copies each tensor, which can
// be memory-heavy for large networks.
tensors_.push_back(inten_);
tensors_.push_back(*outten_);
#endif
#ifdef ENABLE_STATISTIC_WEIGHTS
weights_.push_back(layers_[i]->get_weights());
#endif
// The produced output becomes the next layer's input.
inten_ = *outten_;
#ifdef ENABLE_STATISTIC_TIME
auto end = std::chrono::high_resolution_clock::now();
// NOTE(review): millisecond resolution rounds fast layers down to 0;
// consider microseconds if finer timing is needed.
auto elapsed =
std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
time_.push_back(static_cast<int>(elapsed.count()));
#endif
}
}
// Registers the graph's sink: remembers the id of the terminal layer
// and the tensor object that will receive the final inference result.
// The referenced tensor must outlive any subsequent inference() call.
void setOutput(const Layer& lay, Tensor& vec) {
  outten_ = &vec;
  end_ = lay.getID();
}
#ifdef ENABLE_STATISTIC_TENSORS
// Returns a copy of all tensors recorded during inference()
// (input and output of each executed layer, in traversal order).
std::vector<Tensor> getTensors() { return tensors_; }
#endif
#ifdef ENABLE_STATISTIC_TIME
// Returns a copy of the per-layer run() durations, in milliseconds.
std::vector<int> getTime() { return time_; }
#endif
#ifdef ENABLE_STATISTIC_WEIGHTS
// Returns a copy of the weights recorded during inference().
// NOTE(review): name breaks the camelCase convention of the sibling
// getters (getTensors/getTime); kept as-is because existing tests call
// getWEIGHTS() — renaming would break callers.
std::vector<Tensor> getWEIGHTS() { return weights_; }
#endif
};
} // namespace itlab_2023
3 changes: 3 additions & 0 deletions include/layers/ConvLayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class ConvolutionalLayer : public Layer {
kernel_ = kernel;
}
void run(const Tensor& input, Tensor& output) override;
#ifdef ENABLE_STATISTIC_WEIGHTS
// A convolution's weights for statistics purposes are its kernel.
Tensor get_weights() override { return kernel_; }
#endif
};

template <typename ValueType>
Expand Down
8 changes: 7 additions & 1 deletion include/layers/EWLayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ class EWLayer : public Layer {
: func_(std::move(function)), alpha_(alpha), beta_(beta) {}
static std::string get_name() { return "Element-wise layer"; }
void run(const Tensor& input, Tensor& output) override;

#ifdef ENABLE_STATISTIC_WEIGHTS
// An element-wise layer carries no trainable weights; report a
// one-element placeholder tensor so statistics collection stays uniform.
Tensor get_weights() override {
  std::vector<int> placeholder = {0};
  return make_tensor(placeholder);
}
#endif
private:
std::string func_;
float alpha_;
Expand Down
3 changes: 3 additions & 0 deletions include/layers/FCLayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class FCLayer : public Layer {
: weights_(std::move(weights)), bias_(bias), implType_(implType) {}
static std::string get_name() { return "Fully-connected layer"; }
void run(const Tensor& input, Tensor& output) override;
#ifdef ENABLE_STATISTIC_WEIGHTS
// A fully-connected layer's statistics weights are its weight tensor.
Tensor get_weights() override { return weights_; }
#endif
};

template <typename ValueType>
Expand Down
7 changes: 7 additions & 0 deletions include/layers/InputLayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ class InputLayer : public Layer {
mean_ = mean;
std_ = std;
} // layout = kNchw(0), kNhwc(1)
#ifdef ENABLE_STATISTIC_WEIGHTS
// The input layer only normalizes data and holds no trainable weights;
// return a one-element placeholder tensor for statistics collection.
Tensor get_weights() override {
  std::vector<int> dummy = {0};
  return make_tensor(dummy);
}
#endif
void run(const Tensor& input, Tensor& output) override {
switch (input.get_type()) {
case Type::kInt: {
Expand Down
3 changes: 3 additions & 0 deletions include/layers/Layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ class Layer {
LayerType getName() const { return type_; }
void setName(LayerType type) { type_ = type; }
virtual void run(const Tensor& input, Tensor& output) = 0;
#ifdef ENABLE_STATISTIC_WEIGHTS
// Reports this layer's weights for statistics collection; layers with
// no trainable weights return a one-element placeholder tensor.
// NOTE(review): a macro-conditional virtual changes the vtable layout —
// every translation unit must be compiled with the same
// ENABLE_STATISTIC_WEIGHTS setting, or this is an ODR violation.
virtual Tensor get_weights() = 0;
#endif

private:
int id_;
Expand Down
7 changes: 7 additions & 0 deletions include/layers/OutputLayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ class OutputLayer : public Layer {
std::vector<std::string> get_labels() const { return labels_; }
std::pair<std::vector<std::string>, Tensor> top_k(const Tensor& input,
size_t k) const;
#ifdef ENABLE_STATISTIC_WEIGHTS
// The output layer has no trainable weights; hand back a one-element
// placeholder tensor so the statistics pipeline stays uniform.
Tensor get_weights() override {
  std::vector<int> stub = {0};
  return make_tensor(stub);
}
#endif

private:
std::vector<std::string> labels_;
Expand Down
7 changes: 7 additions & 0 deletions include/layers/PoolingLayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ class PoolingLayer : public Layer {
implType_(implType) {}
static std::string get_name() { return "Pooling layer"; }
void run(const Tensor& input, Tensor& output) override;
#ifdef ENABLE_STATISTIC_WEIGHTS
// Pooling has no trainable weights; produce a one-element placeholder
// tensor for the statistics collector.
Tensor get_weights() override {
  std::vector<int> none = {0};
  return make_tensor(none);
}
#endif

private:
Shape poolingShape_;
Expand Down
63 changes: 63 additions & 0 deletions test/inference/test_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,46 @@ TEST(bfs, check_result_vec) {
graph.inference();
std::vector<int> tmp = *output.as<int>();
std::vector<int> res = {81, 81, 81};
#ifdef ENABLE_STATISTIC_TENSORS
  // Dump every tensor the graph recorded during inference (assumes the
  // recorded tensors hold int data in this test).
  std::vector<Tensor> tensors = graph.getTensors();
  for (size_t i = 0; i < tensors.size(); i++) {  // size_t: avoids -Wsign-compare
    std::vector<int> ten = *tensors[i].as<int>();
    for (size_t j = 0; j < ten.size(); j++) {
      std::cout << ten[j] << ' ';
    }
    std::cout << '\n';
  }
#endif
#ifdef ENABLE_STATISTIC_TIME
  // Dump the per-layer wall-clock timings (milliseconds).
  std::vector<int> times = graph.getTime();
  for (size_t j = 0; j < times.size(); j++) {
    std::cout << times[j] << ' ';
  }
  std::cout << '\n';
#endif
#ifdef ENABLE_STATISTIC_WEIGHTS
  // Dump the recorded layer weights, dispatching on element type.
  std::vector<Tensor> weights = graph.getWEIGHTS();
  for (size_t i = 0; i < weights.size(); i++) {
    switch (weights[i].get_type()) {
      case Type::kInt: {
        std::vector<int> ten = *weights[i].as<int>();
        for (size_t j = 0; j < ten.size(); j++) {
          std::cout << ten[j] << ' ';
        }
        std::cout << '\n';
        break;
      }
      case Type::kFloat: {
        std::vector<float> ten = *weights[i].as<float>();
        for (size_t j = 0; j < ten.size(); j++) {
          std::cout << ten[j] << ' ';
        }
        std::cout << '\n';
        break;
      }
        // NOTE(review): any other Type value is silently skipped.
    }
  }
#endif
ASSERT_EQ(tmp, res);
}
TEST(bfs, check_end_to_end) {
Expand Down Expand Up @@ -66,6 +106,29 @@ TEST(bfs, check_end_to_end) {
graph.makeConnection(a5, a6);
graph.setOutput(a5, output);
graph.inference();
#ifdef ENABLE_STATISTIC_WEIGHTS
  // Dump the weights recorded during inference, dispatching on the
  // tensor's element type.
  std::vector<Tensor> weights = graph.getWEIGHTS();
  for (size_t i = 0; i < weights.size(); i++) {  // size_t: avoids -Wsign-compare
    switch (weights[i].get_type()) {
      case Type::kInt: {
        std::vector<int> ten = *weights[i].as<int>();
        for (size_t j = 0; j < ten.size(); j++) {
          std::cout << ten[j] << ' ';
        }
        std::cout << '\n';
        break;
      }
      case Type::kFloat: {
        std::vector<float> ten = *weights[i].as<float>();
        for (size_t j = 0; j < ten.size(); j++) {
          std::cout << ten[j] << ' ';
        }
        std::cout << '\n';
        break;
      }
        // NOTE(review): any other Type value is silently skipped.
    }
  }
#endif
std::vector<float> tmp = *output.as<float>();
std::vector<float> tmp_output = softmax<float>(*output.as<float>());
std::vector<float> res(3, 21);
Expand Down

0 comments on commit 880acfd

Please sign in to comment.