Skip to content

Leaky ReLUs with test case #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/fann.c
Original file line number Diff line number Diff line change
Expand Up @@ -683,13 +683,20 @@ FANN_EXTERNAL fann_type *FANN_API fann_run(struct fann *ann, fann_type *input) {
neuron_it->value = neuron_sum;
break;
case FANN_LINEAR_PIECE:
neuron_it->value = (fann_type)(
(neuron_sum < 0) ? 0 : (neuron_sum > multiplier) ? multiplier : neuron_sum);
neuron_it->value = (fann_type)((neuron_sum < 0) ? 0
: (neuron_sum > multiplier) ? multiplier
: neuron_sum);
Comment on lines +686 to +688
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Formatting should actually be done using ./format.sh script that uses clang-format so this will get overwritten. I will run it after merging.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should really add check for that to the pipeline.

break;
case FANN_LINEAR_PIECE_SYMMETRIC:
neuron_it->value = (fann_type)((neuron_sum < -multiplier)
? -multiplier
: (neuron_sum > multiplier) ? multiplier : neuron_sum);
neuron_it->value = (fann_type)((neuron_sum < -multiplier) ? -multiplier
: (neuron_sum > multiplier) ? multiplier
: neuron_sum);
break;
case FANN_LINEAR_PIECE_RECT:
neuron_it->value = (fann_type)((neuron_sum < 0) ? 0 : neuron_sum);
break;
case FANN_LINEAR_PIECE_RECT_LEAKY:
neuron_it->value = (fann_type)((neuron_sum < 0) ? 0.01 * neuron_sum : neuron_sum);
break;
case FANN_ELLIOT:
case FANN_ELLIOT_SYMMETRIC:
Expand Down
2 changes: 2 additions & 0 deletions src/fann_cascade.c
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,8 @@ fann_type fann_train_candidates_epoch(struct fann *ann, struct fann_train_data *
case FANN_GAUSSIAN_STEPWISE:
case FANN_ELLIOT:
case FANN_LINEAR_PIECE:
case FANN_LINEAR_PIECE_RECT:
case FANN_LINEAR_PIECE_RECT_LEAKY:
case FANN_SIN:
case FANN_COS:
break;
Expand Down
6 changes: 6 additions & 0 deletions src/fann_train.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ fann_type fann_activation_derived(unsigned int activation_function, fann_type st
case FANN_LINEAR_PIECE:
case FANN_LINEAR_PIECE_SYMMETRIC:
return (fann_type)fann_linear_derive(steepness, value);
case FANN_LINEAR_PIECE_RECT:
return (fann_type)((value < 0) ? 0 : steepness);
case FANN_LINEAR_PIECE_RECT_LEAKY:
return (fann_type)((value < 0) ? steepness * 0.01 : steepness);
case FANN_SIGMOID:
case FANN_SIGMOID_STEPWISE:
value = fann_clip(value, 0.01f, 0.99f);
Expand Down Expand Up @@ -125,6 +129,8 @@ fann_type fann_update_MSE(struct fann *ann, struct fann_neuron *neuron, fann_typ
case FANN_GAUSSIAN_STEPWISE:
case FANN_ELLIOT:
case FANN_LINEAR_PIECE:
case FANN_LINEAR_PIECE_RECT:
case FANN_LINEAR_PIECE_RECT_LEAKY:
case FANN_SIN:
case FANN_COS:
break;
Expand Down
6 changes: 6 additions & 0 deletions src/include/fann_activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ __doublefann_h__ is not defined
case FANN_GAUSSIAN_STEPWISE: \
result = 0; \
break; \
case FANN_LINEAR_PIECE_RECT: \
result = (fann_type)((value < 0) ? 0 : value); \
break; \
case FANN_LINEAR_PIECE_RECT_LEAKY: \
result = (fann_type)((value < 0) ? value * 0.01 : value); \
break; \
}

#endif
18 changes: 16 additions & 2 deletions src/include/fann_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,16 @@ static char const *const FANN_TRAIN_NAMES[] = {"FANN_TRAIN_INCREMENTAL", "FANN_T
* y = cos(x*s)/2+0.5
* d = s*-sin(x*s)/2

FANN_LINEAR_PIECE_RECT - ReLU
* span: 0 <= y < inf
* y = x<0? 0: x
* d = x<0? 0: 1

FANN_LINEAR_PIECE_RECT_LEAKY - leaky ReLU
* span: -inf < y < inf
* y = x<0? 0.01*x: x
* d = x<0? 0.01: 1

See also:
<fann_set_activation_function_layer>, <fann_set_activation_function_hidden>,
<fann_set_activation_function_output>, <fann_set_activation_steepness>,
Expand Down Expand Up @@ -223,7 +233,9 @@ enum fann_activationfunc_enum {
FANN_SIN_SYMMETRIC,
FANN_COS_SYMMETRIC,
FANN_SIN,
FANN_COS
FANN_COS,
FANN_LINEAR_PIECE_RECT,
FANN_LINEAR_PIECE_RECT_LEAKY
};

/* Constant: FANN_ACTIVATIONFUNC_NAMES
Expand Down Expand Up @@ -254,7 +266,9 @@ static char const *const FANN_ACTIVATIONFUNC_NAMES[] = {"FANN_LINEAR",
"FANN_SIN_SYMMETRIC",
"FANN_COS_SYMMETRIC",
"FANN_SIN",
"FANN_COS"};
"FANN_COS",
"FANN_LINEAR_PIECE_RECT",
"FANN_LINEAR_PIECE_RECT_LEAKY"};

/* Enum: fann_errorfunc_enum
Error function used during training.
Expand Down
14 changes: 13 additions & 1 deletion src/include/fann_data_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ enum training_algorithm_enum {
* y = cos(x*s)
* d = s*-sin(x*s)

FANN_LINEAR_PIECE_RECT - ReLU
* span: 0 <= y < inf
* y = x<0? 0: x
* d = x<0? 0: 1

FANN_LINEAR_PIECE_RECT_LEAKY - leaky ReLU
* span: -inf < y < inf
* y = x<0? 0.01*x: x
* d = x<0? 0.01: 1

See also:
<neural_net::set_activation_function_hidden>,
<neural_net::set_activation_function_output>
Expand All @@ -220,7 +230,9 @@ enum activation_function_enum {
LINEAR_PIECE,
LINEAR_PIECE_SYMMETRIC,
SIN_SYMMETRIC,
COS_SYMMETRIC
COS_SYMMETRIC,
LINEAR_PIECE_RECT,
LINEAR_PIECE_RECT_LEAKY
};

/* Enum: network_type_enum
Expand Down
25 changes: 25 additions & 0 deletions tests/fann_test_train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,30 @@ TEST_F(FannTestTrain, TrainOnDateSimpleXor) {
EXPECT_LT(net.test_data(data), 0.001);
}

// Verify that a 3-layer XOR network converges when the hidden layer uses
// the ReLU activation (FANN::LINEAR_PIECE_RECT) with steepness 1.0.
TEST_F(FannTestTrain, TrainOnReLUSimpleXor) {
  neural_net ann(LAYER, 3, 2, 3, 1);

  data.set_train_data(4, 2, xorInput, 1, xorOutput);
  ann.set_activation_steepness_hidden(1.0);
  ann.set_activation_function_hidden(FANN::LINEAR_PIECE_RECT);
  ann.train_on_data(data, 100, 100, 0.001);

  // Both the training MSE and the independent test error must drop
  // below the desired-error threshold used during training.
  EXPECT_LT(ann.get_MSE(), 0.001);
  EXPECT_LT(ann.test_data(data), 0.001);
}

// Verify that a 3-layer XOR network converges when the hidden layer uses
// the leaky-ReLU activation (FANN::LINEAR_PIECE_RECT_LEAKY) with steepness 1.0.
TEST_F(FannTestTrain, TrainOnReLULeakySimpleXor) {
  neural_net ann(LAYER, 3, 2, 3, 1);

  data.set_train_data(4, 2, xorInput, 1, xorOutput);
  ann.set_activation_steepness_hidden(1.0);
  ann.set_activation_function_hidden(FANN::LINEAR_PIECE_RECT_LEAKY);
  ann.train_on_data(data, 100, 100, 0.001);

  // Both the training MSE and the independent test error must drop
  // below the desired-error threshold used during training.
  EXPECT_LT(ann.get_MSE(), 0.001);
  EXPECT_LT(ann.test_data(data), 0.001);
}

TEST_F(FannTestTrain, TrainSimpleIncrementalXor) {
neural_net net(LAYER, 3, 2, 3, 1);

Expand All @@ -41,3 +65,4 @@ TEST_F(FannTestTrain, TrainSimpleIncrementalXor) {

EXPECT_LT(net.get_MSE(), 0.01);
}