forked from seung-lab/znn-release
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhyperbolic_tangent_error_fn.hpp
81 lines (70 loc) · 2.56 KB
/
hyperbolic_tangent_error_fn.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
//
// Copyright (C) 2014 Kisuk Lee <[email protected]>
// ----------------------------------------------------------
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//
#ifndef ZNN_HYPERBOLIC_TANGENT_ERROR_FN_HPP_INCLUDED
#define ZNN_HYPERBOLIC_TANGENT_ERROR_FN_HPP_INCLUDED
#include "error_fn.hpp"
#include "../core/volume_pool.hpp"
namespace zi {
namespace znn {
// Scaled hyperbolic-tangent transfer function, F(x) = a * tanh(b * x).
//
// Implements the error_fn interface with element-wise operations on 3D
// volumes of doubles. Every method walks the volume's storage linearly
// through data(), visiting shape()[0] * shape()[1] * shape()[2] elements.
// NOTE(review): this presumes double3d_ptr is a pointer-like handle to a
// contiguous array whose data() yields double* -- confirm against the
// project's type definitions.
class hyperbolic_tangent_error_fn: virtual public error_fn
{
private:
    // Parameters of F(x) = a_ * tanh(b_ * x).
    double a_; // multiplier applied to the tanh output
    double b_; // multiplier applied to the tanh argument

private:
    // Number of elements in a volume: the product of its three extents.
    static std::size_t element_count(const double3d_ptr& vol)
    {
        return vol->shape()[0]*vol->shape()[1]*vol->shape()[2];
    }

public:
    // Builds the function a * tanh(b * x); both parameters default to 1.
    //
    // Alternative defaults kept from the original source, currently disabled:
    // hyperbolic_tangent_error_fn(double a = 1.7159, double b = 0.6666)
    hyperbolic_tangent_error_fn(double a = static_cast<double>(1),
                                double b = static_cast<double>(1))
        : a_(a)
        , b_(b)
    {
        // Disabled debug trace:
        // std::cout << a << "*tanh(" << b << " * x)" << std::endl;
    }

public:
    // Back-propagates a gradient through the transfer function.
    //
    // dEdF : gradient with respect to the function's output.
    // F    : the function's output values, i.e. a * tanh(b * x).
    //
    // Returns a volume obtained from volume_pool with the same extents as F,
    // whose elements are dEdF * (b/a) * (a - F) * (a + F). That factor is the
    // derivative of a * tanh(b * x) written in terms of the output F.
    //
    // No check is made that dEdF holds at least as many elements as F, and
    // a_ == 0 is not guarded against in the division below.
    virtual double3d_ptr gradient(double3d_ptr dEdF, double3d_ptr F)
    {
        const std::size_t total = element_count(F);

        double3d_ptr result = volume_pool.get_double3d(F->shape()[0],
                                                       F->shape()[1],
                                                       F->shape()[2]);

        // Loop-invariant part of the derivative.
        const double slope = b_/a_;

        const double* upstream = dEdF->data();
        const double* output   = F->data();
        double*       grad     = result->data();

        for ( std::size_t k = 0; k != total; ++k )
        {
            const double f = output[k];
            grad[k] = upstream[k] * slope * (a_ - f)*(a_ + f);
        }

        return result;
    }

    // Replaces every element x of v, in place, with a * tanh(b * x).
    virtual void apply(double3d_ptr v)
    {
        double* cur = v->data();
        double* const stop = cur + element_count(v);

        for ( ; cur != stop; ++cur )
        {
            *cur = a_*std::tanh(b_*(*cur));
        }
    }

    // Replaces every element x of v, in place, with a * tanh(b * (c + x)):
    // the scalar c is added to the element before the function is applied.
    virtual void add_apply(double c, double3d_ptr v)
    {
        double* cur = v->data();
        double* const stop = cur + element_count(v);

        for ( ; cur != stop; ++cur )
        {
            *cur = a_*std::tanh(b_*(c + *cur));
        }
    }

}; // class hyperbolic_tangent_error_fn
}} // namespace zi::znn
#endif // ZNN_HYPERBOLIC_TANGENT_ERROR_FN_HPP_INCLUDED