Skip to content

Commit

Permalink
Added [scale_channels] layer for squeeze-and-excitation blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexeyAB committed Jun 19, 2019
1 parent 8c80ba6 commit cc41339
Show file tree
Hide file tree
Showing 14 changed files with 253 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ LDFLAGS+= -L/usr/local/zed/lib -lsl_core -lsl_input -lsl_zed
#-lstdc++ -D_GLIBCXX_USE_CXX11_ABI=0
endif

OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o upsample_layer.o lstm_layer.o conv_lstm_layer.o
OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o upsample_layer.o lstm_layer.o conv_lstm_layer.o scale_channels_layer.o
ifeq ($(GPU), 1)
LDFLAGS+= -lstdc++
OBJ+=convolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
Expand Down
2 changes: 2 additions & 0 deletions build/darknet/darknet.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@
<ClCompile Include="..\..\src\rnn_layer.c" />
<ClCompile Include="..\..\src\rnn_vid.c" />
<ClCompile Include="..\..\src\route_layer.c" />
<ClCompile Include="..\..\src\scale_channels_layer.c" />
<ClCompile Include="..\..\src\shortcut_layer.c" />
<ClCompile Include="..\..\src\softmax_layer.c" />
<ClCompile Include="..\..\src\super.c" />
Expand Down Expand Up @@ -284,6 +285,7 @@
<ClInclude Include="..\..\src\reorg_old_layer.h" />
<ClInclude Include="..\..\src\rnn_layer.h" />
<ClInclude Include="..\..\src\route_layer.h" />
<ClInclude Include="..\..\src\scale_channels_layer.h" />
<ClInclude Include="..\..\src\shortcut_layer.h" />
<ClInclude Include="..\..\src\softmax_layer.h" />
<ClInclude Include="..\..\src\stb_image.h" />
Expand Down
2 changes: 2 additions & 0 deletions include/darknet.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ typedef enum {
AVGPOOL,
LOCAL,
SHORTCUT,
SCALE_CHANNELS,
ACTIVE,
RNN,
GRU,
Expand All @@ -153,6 +154,7 @@ typedef enum {
UPSAMPLE,
LOGXENT,
L2NORM,
EMPTY,
BLANK
} LAYER_TYPE;

Expand Down
5 changes: 5 additions & 0 deletions src/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ void add_3_arrays_activate(float *a1, float *a2, float *a3, size_t size, ACTIVAT
void sum_of_mults(float *a1, float *a2, float *b1, float *b2, size_t size, float *dst);
void activate_and_mult(float *a1, float *a2, size_t size, ACTIVATION a, float *dst);

// GPU forward pass of SE-style channel scaling:
// out[i] = in_w_h_c[i] * scales_c[i / channel_size], where channel_size = w*h
// and scales_c holds one coefficient per (batch, channel) slice.
void scale_channels_gpu(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out);
// GPU backward pass of channel scaling: accumulates the gradient w.r.t. the
// per-channel scales into out_state_delta (reduced over each spatial slice)
// and the gradient w.r.t. the scaled features into out_from_delta.
void backward_scale_channels_gpu(float *in_w_h_c_delta, int size, int channel_size,
    float *in_scales_c, float *out_from_delta,
    float *in_from_output, float *out_state_delta);

#endif
#ifdef __cplusplus
}
Expand Down
46 changes: 46 additions & 0 deletions src/blas_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1099,3 +1099,49 @@ extern "C" void activate_and_mult(float *a1, float *a2, size_t size, ACTIVATION
}
activate_and_mult_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> >(a1, a2, size, a, dst);
}



// Forward SE-scaling kernel: one thread per output element.
// Each element of in_w_h_c is multiplied by the scale coefficient covering
// its spatial slice (channel_size = w*h elements share one coefficient).
// Launched 1-D with a bounds guard, so any grid >= size is valid.
__global__ void scale_channels_kernel(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= size) return;
    const int scale_index = i / channel_size;   // which (batch, channel) slice this element belongs to
    out[i] = scales_c[scale_index] * in_w_h_c[i];
}

// Host launcher for scale_channels_kernel on the shared darknet CUDA stream.
// in_w_h_c : device buffer of features to scale (size elements)
// scales_c : device buffer of per-(batch,channel) coefficients; channel_size = w*h
// out      : device buffer receiving the scaled result (size elements)
extern "C" void scale_channels_gpu(float *in_w_h_c, int size, int channel_size, float *scales_c, float *out)
{
    const int threads = BLOCK;
    const int blocks = get_number_of_blocks(size, threads);
    scale_channels_kernel <<< blocks, threads, 0, get_cuda_stream() >>> (in_w_h_c, size, channel_size, scales_c, out);
    CHECK_CUDA(cudaPeekAtLastError());   // catch launch-configuration errors without syncing
}


// Backward SE-scaling kernel: one thread per feature element.
// Produces two gradients:
//   out_state_delta[s] += sum over the slice of (l.delta * from_output)  -- grad w.r.t. the scales
//   out_from_delta[i]  += scales[s] * l.delta[i]                         -- grad w.r.t. the features
// where s = i / channel_size indexes the (batch, channel) slice.
__global__ void backward_scale_channels_kernel(float *in_w_h_c_delta, int size, int channel_size,
    float *in_scales_c, float *out_from_delta,
    float *in_from_output, float *out_state_delta)
{
    const int index = blockIdx.x*blockDim.x + threadIdx.x;
    if (index < size) {
        const int osd_index = index / channel_size;
        // Fix: all channel_size threads of a slice write the same out_state_delta
        // element, so a plain `+=` is a cross-thread data race that silently
        // corrupts the scale gradient. atomicAdd makes the reduction correct.
        atomicAdd(&out_state_delta[osd_index], in_w_h_c_delta[index] * in_from_output[index]); // l.delta * from (should be divided by channel_size?)
        out_from_delta[index] += in_scales_c[osd_index] * in_w_h_c_delta[index]; // input * l.delta
    }
}

// Host launcher for backward_scale_channels_kernel on the shared CUDA stream.
// in_w_h_c_delta : incoming gradient from this layer (size elements)
// in_scales_c    : forward scale coefficients (one per batch*channel slice)
// out_from_delta : gradient buffer of the scaled ("from") layer, accumulated into
// in_from_output : forward output of the "from" layer
// out_state_delta: gradient buffer of the scales input, accumulated into
extern "C" void backward_scale_channels_gpu(float *in_w_h_c_delta, int size, int channel_size,
    float *in_scales_c, float *out_from_delta,
    float *in_from_output, float *out_state_delta)
{
    const int threads = BLOCK;
    const int blocks = get_number_of_blocks(size, threads);
    backward_scale_channels_kernel <<< blocks, threads, 0, get_cuda_stream() >>> (
        in_w_h_c_delta, size, channel_size, in_scales_c, out_from_delta, in_from_output, out_state_delta);
    CHECK_CUDA(cudaPeekAtLastError());   // report launch errors immediately
}
4 changes: 2 additions & 2 deletions src/classifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus,
#else
loss = train_network(net, train);
#endif
if(avg_loss == -1) avg_loss = loss;
if(avg_loss == -1 || isnan(avg_loss) || isinf(avg_loss)) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;

i = get_current_batch(net);

int calc_topk_for_each = iter_topk + 2 * train_images_num / (net.batch * net.subdivisions); // calculate TOPk for each 2 Epochs
calc_topk_for_each = fmax(calc_topk_for_each, net.burn_in);
calc_topk_for_each = fmax(calc_topk_for_each, 1000);
calc_topk_for_each = fmax(calc_topk_for_each, 100);
if (i % 10 == 0) {
if (calc_topk) {
fprintf(stderr, "\n (next TOP5 calculation at %d iterations) ", calc_topk_for_each);
Expand Down
7 changes: 5 additions & 2 deletions src/convolutional_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
//if (state.use_mixed_precision) {
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3*state.net.burn_in) &&
l.c % 8 == 0 && l.n % 8 == 0)
(l.c / l.groups) % 8 == 0 && l.n % 8 == 0 && !state.train)
{
//printf("\n CUDNN_HALF!!! state.index = %d \n", state.index);

Expand Down Expand Up @@ -629,7 +629,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state
//#ifdef CUDNN_HALF
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
if (state.index != 0 && state.net.cudnn_half && !l.xnor && (!state.train || iteration_num > 3*state.net.burn_in) &&
l.c % 8 == 0 && l.n % 8 == 0)
(l.c / l.groups) % 8 == 0 && l.n % 8 == 0 && !state.train)
{
const size_t input16_size = l.batch*l.c*l.w*l.h;
const size_t delta16_size = l.batch*l.n*l.out_w*l.out_h;
Expand Down Expand Up @@ -910,6 +910,9 @@ void update_convolutional_layer_gpu(layer l, int batch, float learning_rate_init
//float decay = a.decay;
//int batch = a.batch;

fix_nan_and_inf(l.weight_updates_gpu, l.nweights);
fix_nan_and_inf(l.weights_gpu, l.nweights);

if (l.adam) {
//adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, a.B1, a.B2, a.eps, decay, learning_rate, l.nweights, batch, a.t);
adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.nweights, batch, l.t);
Expand Down
7 changes: 7 additions & 0 deletions src/dropout_layer_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ __global__ void yoloswag420blazeit360noscope(float *input, int size, float *rand
void forward_dropout_layer_gpu(dropout_layer layer, network_state state)
{
if (!state.train) return;
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
//if (iteration_num < state.net.burn_in) return;


int size = layer.inputs*layer.batch;
cuda_random(layer.rand_gpu, size);
/*
Expand All @@ -32,6 +36,9 @@ void forward_dropout_layer_gpu(dropout_layer layer, network_state state)
void backward_dropout_layer_gpu(dropout_layer layer, network_state state)
{
if(!state.delta) return;
int iteration_num = (*state.net.seen) / (state.net.batch*state.net.subdivisions);
//if (iteration_num < state.net.burn_in) return;

int size = layer.inputs*layer.batch;

yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >>>(state.delta, size, layer.rand_gpu, layer.probability, layer.scale);
Expand Down
1 change: 1 addition & 0 deletions src/network.c
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ void backward_network(network net, network_state state)
}
layer l = net.layers[i];
if (l.stopbackward) break;
if (l.onlyforward) continue;
l.backward(l, state);
}
}
Expand Down
1 change: 1 addition & 0 deletions src/network_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ void backward_network_gpu(network net, network_state state)
state.input = prev.output_gpu;
state.delta = prev.delta_gpu;
}
if (l.onlyforward) continue;
l.backward_gpu(l, state);

/*
Expand Down
38 changes: 38 additions & 0 deletions src/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "rnn_layer.h"
#include "route_layer.h"
#include "shortcut_layer.h"
#include "scale_channels_layer.h"
#include "softmax_layer.h"
#include "utils.h"
#include "upsample_layer.h"
Expand All @@ -48,6 +49,7 @@ LAYER_TYPE string_to_layer_type(char * type)
{

if (strcmp(type, "[shortcut]")==0) return SHORTCUT;
if (strcmp(type, "[scale_channels]") == 0) return SCALE_CHANNELS;
if (strcmp(type, "[crop]")==0) return CROP;
if (strcmp(type, "[cost]")==0) return COST;
if (strcmp(type, "[detection]")==0) return DETECTION;
Expand Down Expand Up @@ -80,6 +82,7 @@ LAYER_TYPE string_to_layer_type(char * type)
|| strcmp(type, "[softmax]")==0) return SOFTMAX;
if (strcmp(type, "[route]")==0) return ROUTE;
if (strcmp(type, "[upsample]") == 0) return UPSAMPLE;
if (strcmp(type, "[empty]") == 0) return EMPTY;
return BLANK;
}

Expand Down Expand Up @@ -600,6 +603,24 @@ layer parse_shortcut(list *options, size_params params, network net)
}


// Parses a [scale_channels] cfg section and builds the layer.
// "from" names the layer whose output will be scaled by this layer's
// 1x1xC per-channel input (negative values are relative to this layer).
layer parse_scale_channels(list *options, size_params params, network net)
{
    char *l = option_find(options, "from");
    // NOTE(review): option_find presumably returns NULL when "from" is absent,
    // making atoi(l) crash — confirm against the other parse_* functions and
    // whether the cfg format guarantees the key.
    int index = atoi(l);
    if (index < 0) index = params.index + index;   // relative index, e.g. from=-3

    int batch = params.batch;
    layer from = net.layers[index];

    layer s = make_scale_channels_layer(batch, index, params.w, params.h, params.c, from.out_w, from.out_h, from.out_c);

    // Optional post-scaling activation; defaults to linear (identity).
    char *activation_s = option_find_str_quiet(options, "activation", "linear");
    ACTIVATION activation = get_activation(activation_s);
    s.activation = activation;
    return s;
}


layer parse_activation(list *options, size_params params)
{
char *activation_s = option_find_str(options, "activation", "linear");
Expand Down Expand Up @@ -895,13 +916,30 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
l = parse_shortcut(options, params, net);
net.layers[count - 1].use_bin_output = 0;
net.layers[l.index].use_bin_output = 0;
}else if (lt == SCALE_CHANNELS) {
l = parse_scale_channels(options, params, net);
net.layers[count - 1].use_bin_output = 0;
net.layers[l.index].use_bin_output = 0;
}else if(lt == DROPOUT){
l = parse_dropout(options, params);
l.output = net.layers[count-1].output;
l.delta = net.layers[count-1].delta;
#ifdef GPU
l.output_gpu = net.layers[count-1].output_gpu;
l.delta_gpu = net.layers[count-1].delta_gpu;
#endif
}
else if (lt == EMPTY) {
layer empty_layer;
empty_layer.out_w = params.w;
empty_layer.out_h = params.h;
empty_layer.out_c = params.c;
l = empty_layer;
l.output = net.layers[count - 1].output;
l.delta = net.layers[count - 1].delta;
#ifdef GPU
l.output_gpu = net.layers[count - 1].output_gpu;
l.delta_gpu = net.layers[count - 1].delta_gpu;
#endif
}else{
fprintf(stderr, "Type not recognized: %s\n", s->type);
Expand Down
118 changes: 118 additions & 0 deletions src/scale_channels_layer.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
#include "scale_channels_layer.h"
#include "dark_cuda.h"
#include "blas.h"
#include <stdio.h>
#include <assert.h>

// Creates a [scale_channels] layer (squeeze-and-excitation block):
// multiplies each channel of the output of layer `index` (w2 x h2 x c2)
// by the matching scalar of this layer's 1x1xC input.
//   batch      - mini-batch size
//   index      - index of the "from" layer whose output is scaled
//   w, h, c    - dimensions of the incoming scales; must be 1 x 1 x c
//   w2, h2, c2 - dimensions of the "from" layer's output; c2 must equal c
// Returns the initialized layer (aborts via assert on invalid shapes).
layer make_scale_channels_layer(int batch, int index, int w, int h, int c, int w2, int h2, int c2)
{
    fprintf(stderr,"scale Layer: %d\n", index);
    layer l = { (LAYER_TYPE)0 };
    l.type = SCALE_CHANNELS;
    l.batch = batch;
    l.w = w;
    l.h = h;
    l.c = c;
    // The scales input must be a 1x1 spatial map (one scalar per channel),
    // e.g. the output of a global-average-pool + FC squeeze path.
    // Fix: original used bitwise `&` where logical `&&` was intended.
    assert(w == 1 && h == 1);

    l.out_w = w2;
    l.out_h = h2;
    l.out_c = c2;
    assert(l.out_c == l.c);   // exactly one scale per scaled channel

    l.outputs = l.out_w*l.out_h*l.out_c;
    l.inputs = l.outputs;     // consumes as many "from" elements as it outputs
    l.index = index;

    l.delta = (float*)calloc(l.outputs * batch, sizeof(float));
    l.output = (float*)calloc(l.outputs * batch, sizeof(float));

    l.forward = forward_scale_channels_layer;
    l.backward = backward_scale_channels_layer;
#ifdef GPU
    l.forward_gpu = forward_scale_channels_layer_gpu;
    l.backward_gpu = backward_scale_channels_layer_gpu;

    l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch);
    l.output_gpu = cuda_make_array(l.output, l.outputs*batch);
#endif
    return l;
}

// Resizes the layer's spatial output to w x h (channel count unchanged)
// and regrows the host and device buffers to match.
// NOTE(review): realloc results are assigned back unchecked — on failure this
// leaks the old buffer and later dereferences NULL. Matches the file's
// existing allocation style, but worth confirming against project conventions.
void resize_scale_channels_layer(layer *l, int w, int h)
{
    l->out_w = w;
    l->out_h = h;
    l->outputs = l->out_w*l->out_h*l->out_c;
    l->inputs = l->outputs;
    l->delta = (float*)realloc(l->delta, l->outputs * l->batch * sizeof(float));
    l->output = (float*)realloc(l->output, l->outputs * l->batch * sizeof(float));

#ifdef GPU
    // Device buffers are freed and recreated; old contents are not preserved.
    cuda_free(l->output_gpu);
    cuda_free(l->delta_gpu);
    l->output_gpu = cuda_make_array(l->output, l->outputs*l->batch);
    l->delta_gpu = cuda_make_array(l->delta, l->outputs*l->batch);
#endif

}

// CPU forward pass: output[i] = scales[slice_of(i)] * from_output[i], where
// the scales come from state.input (1x1xC) and the features from layer l.index;
// an optional activation is applied afterwards.
void forward_scale_channels_layer(const layer l, network_state state)
{
    const int channel_size = l.out_w * l.out_h;          // elements sharing one scale
    const int total = l.batch * l.out_c * channel_size;
    float *from_output = state.net.layers[l.index].output;

    int i;
    #pragma omp parallel for
    for (i = 0; i < total; ++i) {
        l.output[i] = from_output[i] * state.input[i / channel_size];
    }

    activate_array(l.output, l.outputs*l.batch, l.activation);
}

// CPU backward pass. Gradients produced:
//   state.delta[s] += sum over slice of (l.delta * from_output)   -- w.r.t. the scales
//   from_delta[i]   = scales[s] * l.delta[i]                      -- w.r.t. the features
// where s indexes the (batch, channel) slice of channel_size = out_w*out_h elements.
void backward_scale_channels_layer(const layer l, network_state state)
{
    gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);

    int channel_size = l.out_w * l.out_h;
    int num_slices = l.batch * l.out_c;
    float *from_output = state.net.layers[l.index].output;
    float *from_delta = state.net.layers[l.index].delta;

    // Fix: the original parallelized over elements, so every thread of a slice
    // did `state.delta[i / channel_size] += ...` — a data race under OpenMP.
    // Parallelize over slices instead: one thread owns each state.delta element
    // and reduces its slice into a local accumulator.
    int s;
    #pragma omp parallel for
    for (s = 0; s < num_slices; ++s) {
        float sum = 0;
        int j;
        for (j = 0; j < channel_size; ++j) {
            int index = s*channel_size + j;
            sum += l.delta[index] * from_output[index]; // l.delta * from (should be divided by channel_size?)
            from_delta[index] = state.input[s] * l.delta[index]; // input * l.delta
        }
        state.delta[s] += sum;
    }
}

#ifdef GPU
// GPU forward pass: delegate the elementwise scaling to the CUDA kernel,
// then apply the layer's activation in place.
void forward_scale_channels_layer_gpu(const layer l, network_state state)
{
    const int channel_size = l.out_w * l.out_h;          // elements per scale coefficient
    const int total = l.batch * l.out_c * channel_size;

    // state.input holds the 1x1xC scales; the features come from layer l.index.
    scale_channels_gpu(state.net.layers[l.index].output_gpu, total, channel_size, state.input, l.output_gpu);

    activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
}

// GPU backward pass: backprop through the activation, then let the CUDA
// kernel route gradients to the scales input (state.delta) and to the
// scaled "from" layer (its delta_gpu buffer).
void backward_scale_channels_layer_gpu(const layer l, network_state state)
{
    gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);

    const int channel_size = l.out_w * l.out_h;
    const int total = l.batch * l.out_c * channel_size;
    layer from = state.net.layers[l.index];

    backward_scale_channels_gpu(l.delta_gpu, total, channel_size,
        state.input, from.delta_gpu, from.output_gpu, state.delta);
}
#endif
Loading

0 comments on commit cc41339

Please sign in to comment.