forked from handspeaker/RandomForests
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathNode.h
103 lines (91 loc) · 2.54 KB
/
Node.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/************************************************
*Random Forest Program
*Function: implementation of two kinds of nodes,
one for classification and one for regression
*Author: [email protected]
*CreateTime: 2014.7.10
*Version: V0.1
*************************************************/
#ifndef NODE_H
#define NODE_H
#include"Sample.h"
#include<stdio.h>
// Prediction record that Node::getResult() writes into.
struct Result
{
    float label; // classification: class label / regression: predicted value
    float prob;  // classification: probability / regression: 1
};
// A (feature value, sample id) couple. NOTE(review): presumably this is the
// element type ordered by compare_pair when sorting samples by one feature.
struct Pair
{
    float feature; // value of the feature for this sample
    int   id;      // identifier of the sample the value belongs to
};
// Comparator with the qsort() signature (two const void* operands, int result).
// NOTE(review): presumably compares two Pair objects by `feature` so samples
// can be sorted in ascending order; its definition is not in this header.
int compare_pair(const void *lhs, const void *rhs);
// Abstract base class for a tree node of the random forest. ClasNode
// (classification) and RegrNode (regression) derive from it and implement
// the pure virtual training/prediction hooks declared below.
class Node
{
public:
    Node();
    virtual ~Node();

    // Sort the selected samples in ascending order of the feature `featureId`.
    void sortIndex(int featureId);

    // The samples held by this node. NOTE(review): ownership is not visible
    // from this header -- check ~Node() before freeing or reassigning it.
    Sample*_samples;

    // Mark (flag == true) or unmark this node as a leaf.
    void setLeaf(bool flag){_isLeaf=flag;}
    // Whether this node is marked as a leaf. const so it is usable through
    // const Node references/pointers.
    bool isLeaf() const{return _isLeaf;}

    // Calculate the information gain for this node.
    // NOTE(review): presumably `_cartreeArray` is the tree's node array, `id`
    // this node's index in it, and `minInfoGain` the smallest gain that still
    // justifies a split -- confirm against the implementations.
    virtual void calculateInfoGain(Node**_cartreeArray,int id,float minInfoGain)=0;
    // NOTE(review): presumably computes the subclass's training parameters
    // (see ClasNode/RegrNode) -- confirm in the implementations.
    virtual void calculateParams()=0;
    // Create a leaf node (turn this node into a leaf).
    virtual void createLeaf()=0;
    // Predict for one data vector. NOTE(review): the meaning of `id` and of
    // the int return value is defined by the implementations -- confirm
    // before relying on it.
    virtual int predict(float*data,int id)=0;
    // Fill `r` with this node's result (label/value and probability).
    virtual void getResult(Result&r)=0;

    // Feature index used by this node (presumably the split feature).
    int getFeatureIndex() const{return _featureIndex;}
    void setFeatureIndex(int featureIndex){_featureIndex=featureIndex;}
    // Threshold used by this node (presumably the split threshold).
    float getThreshold() const{return _threshold;}
    void setThreshold(float threshold){_threshold=threshold;}
protected:
    bool _isLeaf;      // written by setLeaf(), read by isLeaf()
    int _featureIndex; // written by setFeatureIndex()
    float _threshold;  // written by setThreshold()
};
// Classification node: stores a class label and its probability.
class ClasNode:public Node
{
public:
    ClasNode();
    ~ClasNode();
    // Implementations of Node's pure virtual hooks.
    void calculateInfoGain(Node**_cartreeArray,int id,float minInfoGain);
    void calculateParams();
    void createLeaf();
    int predict(float*data,int id);
    void getResult(Result&r);
    // Read-only accessors are const so they work on const ClasNode objects.
    float getClass() const{return _class;}
    float getProb() const{return _prob;}
    void setClass(float clas){_class=clas;}
    void setProb(float prob){_prob=prob;}
    // Parameters for training.
    float _gini;   // presumably the Gini impurity of this node -- confirm
    float*_probs;  // presumably per-class probabilities; length/ownership not visible here
private:
    float _class; // the class
    float _prob;  // the probability
};
// Regression node: stores a single regression value.
class RegrNode:public Node
{
public:
    RegrNode();
    ~RegrNode();
    // Implementations of Node's pure virtual hooks.
    void calculateInfoGain(Node**_cartreeArray,int id,float minInfoGain);
    void calculateParams();
    void createLeaf();
    int predict(float*data,int id);
    void getResult(Result&r);
    // Read-only accessor is const so it works on const RegrNode objects.
    float getValue() const{return _value;}
    void setValue(float value){_value=value;}
    // Parameters for training.
    float _mean;     // presumably the mean target value of this node's samples -- confirm
    float _variance; // presumably the variance of those target values -- confirm
private:
    float _value; // the regression value
};
#endif//NODE_H