-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtensor.cpp
121 lines (97 loc) · 1.78 KB
/
tensor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include "tensor.h"
#include "utils.h"
// Builds an empty tensor: CPU device type, no device-side buffer, and a
// freshly allocated, empty float vector for the data.
Tensor::Tensor()
{
    _device_type = Tensor_CPU;
    // Start without a device buffer so get_device_data() never returns an
    // indeterminate pointer before one has been attached.
    _device_data = nullptr;
    // Heap-allocated here; clear() is the place that deletes it.
    _data = new vector<float>();
}
// Destructor: all teardown is delegated to clear(), which empties the shape
// and deletes the data vector.
Tensor::~Tensor()
{
    this->clear();
}
int Tensor::get_device_type()
{
return _device_type;
}
void Tensor::set_device_type(int device_type)
{
_device_type = device_type;
}
float* Tensor::get_device_data()
{
return _device_data;
}
// Sets the tensor's shape and, when the element count implied by the shape
// changes, resizes the data vector to match.
//
// shape: the new dimension sizes; every entry must be >= 0.
//
// Returns 0 on success (including the early exit taken when is_same_shape()
// reports the requested shape as equal to the current one), or -1 if any
// dimension is negative, in which case the tensor is left untouched.
int Tensor::set_shape(const vector<int> &shape)
{
    if (is_same_shape(shape, _shape))
    {
        return 0;
    }

    // A negative extent has no meaningful element count: once converted to
    // the unsigned size resize() expects, it would request an enormous
    // allocation. Reject it up front instead.
    for (int dim : shape)
    {
        if (dim < 0)
        {
            return -1;
        }
    }

    // Element count implied by the current shape. A tensor that has no shape
    // yet is treated as holding zero elements.
    int oldSize = 0;
    if (!_shape.empty())
    {
        oldSize = 1;
        for (int dim : _shape)
        {
            oldSize *= dim;
        }
    }

    // Element count implied by the requested shape (1 for an empty shape).
    int newSize = 1;
    for (int dim : shape)
    {
        newSize *= dim;
    }

    _shape = shape;

    // Only touch the buffer when the element count actually changed.
    if (newSize != oldSize)
    {
        _data->resize(static_cast<size_t>(newSize));
    }
    return 0;
}
// Replaces the contents of the tensor's data vector with a copy of `data`.
// The stored shape is not modified by this call.
//
// data: values to copy; may be empty, in which case the tensor's data vector
//       becomes empty as well.
void Tensor::set_data(const vector<float> &data)
{
    // Vector assignment handles the resize and the element copy in one step
    // and is well-defined for empty input and for self-assignment, unlike a
    // raw memcpy over data() pointers.
    *_data = data;
}
void Tensor::set_shape_data(const vector<int> &shape, const vector<float> *data)
{
_shape.clear();
for(int i=0; i<shape.size(); i++)
{
_shape.push_back(shape[i]);
}
_data = (vector<float>*)data;
}
// Returns the pointer to the tensor's data vector. The caller receives the
// tensor's own pointer, not a copy of the values.
vector<float>* Tensor::get_data()
{
    return this->_data;
}
// Returns a copy of the tensor's shape (returned by value, so later changes
// to the tensor do not affect the returned vector).
vector<int> Tensor::get_shape()
{
    return this->_shape;
}
// Returns the product of all dimensions in the stored shape.
// An empty shape yields 1, since the product starts from 1.
int Tensor::get_size()
{
    int count = 1;
    for (int dim : _shape)
    {
        count *= dim;
    }
    return count;
}
// Fills the tensor's data vector with raw floats read from a file.
//
// fp:     open stream to read from.
// offset: byte position, measured from the start of the file, where the
//         values begin.
//
// Exactly _data->size() floats are read, so the data vector must already
// have its final size (e.g. via set_shape) before calling this.
//
// Returns 0 on success, or -1 if fp is null, the seek fails, or fewer floats
// than requested could be read.
int Tensor::load_data(FILE *fp, long offset)
{
    if (fp == nullptr)
    {
        return -1;
    }
    if (fseek(fp, offset, SEEK_SET) != 0)
    {
        return -1;
    }

    const size_t wanted = _data->size();
    if (wanted == 0)
    {
        // Nothing to read into an empty buffer.
        return 0;
    }

    // Element size first, element count second, so the return value is the
    // number of floats actually read and a short read can be detected.
    const size_t got = fread(_data->data(), sizeof(float), wanted, fp);
    return (got == wanted) ? 0 : -1;
}
// Empties the shape and releases the data vector.
//
// Safe to call more than once: the pointer is reset after deletion, so a
// later clear() (including the one issued by the destructor) is a no-op for
// the data vector. After this call the tensor has no data vector; methods
// that dereference _data must not be used until a new one is supplied
// (e.g. via set_shape_data).
void Tensor::clear()
{
    _shape.clear();
    // Deleting the vector destroys its elements, so no separate clear() of
    // the vector is needed; deleting a null pointer is a no-op.
    delete _data;
    _data = nullptr;
}