forked from seung-lab/znn-release
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGaussian_init.hpp
81 lines (68 loc) · 2.14 KB
/
Gaussian_init.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
//
// Copyright (C) 2014 Kisuk Lee <[email protected]>
// ----------------------------------------------------------
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//
#ifndef ZNN_GAUSSIAN_INIT_HPP_INCLUDED
#define ZNN_GAUSSIAN_INIT_HPP_INCLUDED
#include "initializer.hpp"
#include "../core/volume_utils.hpp"
#include <boost/random/normal_distribution.hpp>
namespace zi {
namespace znn {
class Gaussian_init : virtual public initializer
{
private:
typedef boost::normal_distribution<> norm_dist;
typedef boost::variate_generator<boost::mt19937&, norm_dist> generator;
private:
double mu; // mean
double sigma; // standard deviation
public:
virtual void initialize( double3d_ptr w )
{
volume_utils::zero_out(w);
norm_dist dist(mu,sigma);
generator gen(rng,dist);
volume_utils::random_initialization(gen,w);
}
virtual void init( const std::string& params )
{
// parser for parsing arguments
std::vector<double> args;
zi::zargs_::parser<std::vector<double> > arg_parser;
bool parsed = arg_parser.parse(&args,params);
if ( parsed )
{
if ( args.size() == 1 )
{
sigma = args[0];
}
else if ( args.size() == 2 )
{
mu = args[0];
sigma = args[1];
}
}
}
public:
Gaussian_init( double _mu = 0.0,
double _sigma = 0.01 )
: mu(_mu)
, sigma(_sigma)
{}
}; // class Gaussian_init
}} // namespace zi::znn
#endif // ZNN_GAUSSIAN_INIT_HPP_INCLUDED