forked from seung-lab/znn-release
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsquare_cost_fn.hpp
89 lines (78 loc) · 2.93 KB
/
square_cost_fn.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
//
// Copyright (C) 2014 Kisuk Lee <[email protected]>
// ----------------------------------------------------------
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//
#ifndef ZNN_SQUARE_COST_FN_HPP_INCLUDED
#define ZNN_SQUARE_COST_FN_HPP_INCLUDED
#include "cost_fn.hpp"
#include "../core/volume_pool.hpp"

#include <cassert>
#include <cmath>
#include <iostream>
#include <list>
namespace zi {
namespace znn {
class square_cost_fn: virtual public cost_fn
{
public:
virtual double3d_ptr gradient( double3d_ptr out,
double3d_ptr lbl,
bool3d_ptr msk )
{
double3d_ptr ret = volume_pool.get_double3d(out);
volume_utils::sub_from_mul(ret,out,lbl,2);
volume_utils::elementwise_masking(ret,msk);
return ret;
}
virtual std::list<double3d_ptr> gradient( std::list<double3d_ptr> outs,
std::list<double3d_ptr> lbls,
std::list<bool3d_ptr> msks )
{
std::list<double3d_ptr> ret;
std::list<double3d_ptr>::iterator lit = lbls.begin();
std::list<bool3d_ptr>::iterator mit = msks.begin();
FOR_EACH( it, outs )
{
ret.push_back(gradient(*it,*lit++,*mit++));
}
return ret;
}
virtual double compute_cost( double3d_ptr out,
double3d_ptr lbl,
bool3d_ptr msk )
{
double3d_ptr err = volume_pool.get_double3d(out);
volume_utils::sub_from_mul(err,out,lbl,1);
volume_utils::elementwise_masking(err,msk);
return volume_utils::square_sum(err);
}
virtual double compute_cost( std::list<double3d_ptr> outs,
std::list<double3d_ptr> lbls,
std::list<bool3d_ptr> msks )
{
double ret = static_cast<double>(0);
std::list<double3d_ptr>::iterator lit = lbls.begin();
std::list<bool3d_ptr>::iterator mit = msks.begin();
FOR_EACH( it, outs )
{
ret += compute_cost(*it,*lit++,*mit++);
}
return ret;
}
virtual void print_cost( double cost )
{
// std::cout << "MSE: " << cost << std::endl;
std::cout << "RMSE: " << std::sqrt(cost);
}
}; // class square_cost_fn
}} // namespace zi::znn
#endif // ZNN_SQUARE_COST_FN_HPP_INCLUDED