forked from seung-lab/znn-release
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patherror_fn.hpp
51 lines (41 loc) · 1.47 KB
/
error_fn.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
//
// Copyright (C) 2014 Aleksandar Zlateski <[email protected]>
// ----------------------------------------------------------
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//
#ifndef ZNN_ERROR_FN_HPP_INCLUDED
#define ZNN_ERROR_FN_HPP_INCLUDED
#include "../core/types.hpp"
namespace zi {
namespace znn {

// Abstract interface for an error function used by the network.
//
// Concrete implementations must provide gradient() and apply();
// add_apply() has a default implementation built on top of apply().
class error_fn
{
public:
    // Polymorphic base: the destructor is virtual so that destroying a
    // concrete error function through an error_fn pointer (for example an
    // error_fn_ptr constructed from a raw error_fn*) runs the derived
    // destructor instead of being undefined behavior.
    virtual ~error_fn() {}

    // Computes a gradient volume from dEdF and F.
    // NOTE(review): the precise meaning of the arguments and of the
    // returned volume is defined by the implementations — confirm there.
    virtual double3d_ptr gradient(double3d_ptr /* dEdF */,
                                  double3d_ptr /* F */) = 0;

    // Applies the function to the given volume. Returns nothing, so any
    // effect is a side effect (presumably on the volume itself — confirm
    // in the implementations).
    virtual void apply(double3d_ptr) = 0;

    // Adds the constant c to every element of v, then calls apply(v).
    //
    // The element count is the product of the first three extents of v
    // and elements are addressed linearly through v->data(), so v is
    // assumed to be a non-null 3-D volume with contiguous storage.
    virtual void add_apply(double c, double3d_ptr v)
    {
        const std::size_t n = v->shape()[0] * v->shape()[1] * v->shape()[2];
        for ( std::size_t i = 0; i < n; ++i )
        {
            v->data()[i] += c;
        }
        this->apply(v);
    }

}; // abstract class error_fn

// Shared-ownership handle to an error function.
typedef boost::shared_ptr<error_fn> error_fn_ptr;

}} // namespace zi::znn
#endif // ZNN_ERROR_FN_HPP_INCLUDED