Segmentation positive and negative weighting balance
eric612 committed Mar 22, 2019
1 parent ca3759a commit 26d3be4
Showing 12 changed files with 745 additions and 181 deletions.
4 changes: 2 additions & 2 deletions include/caffe/data_transformer.hpp
@@ -39,7 +39,7 @@ class DataTransformer {
* This is destination blob. It can be part of top blob's data if
* set_cpu_data() is used. See data_layer.cpp for an example.
*/
void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
void Transform(const Datum& datum, Blob<Dtype>* transformed_blob, int policy_num=0);

/**
* @brief Applies the transformation defined in the data layer's
@@ -111,7 +111,7 @@ class DataTransformer {
* @brief Crops the datum and AnnotationGroup according to bbox.
*/
void CropImage(const AnnotatedDatum& anno_datum, const NormalizedBBox& bbox,
AnnotatedDatum* cropped_anno_datum);
AnnotatedDatum* cropped_anno_datum , bool has_anno = true);

/**
* @brief Expand the datum.
54 changes: 54 additions & 0 deletions include/caffe/layers/lane_data_layer.hpp
@@ -0,0 +1,54 @@
#ifndef CAFFE_DATA_LAYER_HPP_
#define CAFFE_DATA_LAYER_HPP_

#include <string>
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/data_reader.hpp"
#include "caffe/data_transformer.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
#include "caffe/layers/base_data_layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"

namespace caffe {

template <typename Dtype>
class LaneDataLayer : public BasePrefetchingDataLayer<Dtype> {
public:
explicit LaneDataLayer(const LayerParameter& param);
virtual ~LaneDataLayer();
virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
// AnnotatedDataLayer uses DataReader instead for sharing for parallelism
virtual inline bool ShareInParallel() const { return false; }
virtual inline const char* type() const { return "LaneData"; }
virtual inline int ExactNumBottomBlobs() const { return 0; }
virtual inline int MinTopBlobs() const { return 1; }

protected:
virtual void load_batch(Batch<Dtype>* batch);

DataReader<AnnotatedDatum> reader_;
bool has_anno_type_;
AnnotatedDatum_AnnotationType anno_type_;
vector<BatchSampler> batch_samplers_;
string label_map_file_;
int yolo_data_type_;
float yolo_data_jitter_;
bool train_diffcult_;
int iters_;
int policy_num_ ;
bool single_class_; //for yolo segmentation
YoloSegLabel label_map_;
int seg_label_maxima_;
int seg_scales_;
int seg_resize_width_;
int seg_resize_height_;
};

} // namespace caffe

#endif // CAFFE_DATA_LAYER_HPP_
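
This header only declares the new layer. A matching source file (presumably src/caffe/layers/lane_data_layer.cpp, not shown in this excerpt) would define DataLayerSetUp and load_batch and register the class with Caffe's layer factory so that type: "LaneData" in a prototxt resolves to it. A minimal sketch of that registration, assuming the standard Caffe macros:

// Hypothetical tail of lane_data_layer.cpp: instantiate the float/double
// templates and register the layer under the name "LaneData" (the factory
// macro appends "Layer" to form the class name).
INSTANTIATE_CLASS(LaneDataLayer);
REGISTER_LAYER_CLASS(LaneData);
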
3 changes: 2 additions & 1 deletion include/caffe/layers/yolo_seg_layer.hpp
@@ -29,7 +29,7 @@ class YoloSegLayer : public LossLayer<Dtype> {
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual inline const char* type() const { return "YoloSeg"; }
inline int ExactNumBottomBlobs() const { return 3; } // bottom[2] give the weighting of each classes
virtual inline int ExactNumBottomBlobs() const { return 3; }
protected:

virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
@@ -47,6 +47,7 @@ class YoloSegLayer : public LossLayer<Dtype> {
protected:
Blob<Dtype> diff_; // cached for backward pass
Blob<Dtype> swap_; // cached for backward pass
bool enable_weighting_;
bool use_logic_gradient_;
bool use_hardsigmoid_;
float object_scale_;
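
The header changes above are the hook for this commit's weighting balance: ExactNumBottomBlobs() is now virtual and fixed at three bottoms, with bottom[2] expected to carry a per-class weight, and a new enable_weighting_ flag gates whether that weight is used. The loss code itself (yolo_seg_layer.cpp) is not among the hunks shown here, but the idea is to scale the positive (foreground) and negative (background) terms of the per-pixel sigmoid loss separately so that sparse foreground pixels are not swamped by background. A minimal, self-contained sketch of that balancing, assuming bottom[2] holds one weight per class; all names are illustrative rather than taken from the repository:

#include <cmath>
#include <cstddef>

// Class-balanced binary cross-entropy over one sigmoid segmentation map.
// 'pred' holds sigmoid outputs, 'label' holds 0/1 ground truth,
// 'class_weight' plays the role of the value read from bottom[2], and
// pos_scale / neg_scale set the positive/negative balance.
float WeightedSegLoss(const float* pred, const float* label, std::size_t n,
                      float class_weight, float pos_scale, float neg_scale) {
  const float eps = 1e-6f;
  float loss = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    if (label[i] > 0.5f) {
      loss -= class_weight * pos_scale * std::log(pred[i] + eps);          // positive term
    } else {
      loss -= class_weight * neg_scale * std::log(1.0f - pred[i] + eps);   // negative term
    }
  }
  return loss / static_cast<float>(n);
}

When enable_weighting_ is disabled, class_weight would simply default to 1 and only the positive/negative scales apply.
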
2 changes: 1 addition & 1 deletion models/cityscapes/mobilenet_yolov3_train.prototxt
@@ -3395,7 +3395,7 @@ layer {
decay_mult: 0
}
convolution_param {
num_output: 3 # channel = class number
num_output: 19 # channel = class number
kernel_size: 1
pad: 0
stride: 1
74 changes: 43 additions & 31 deletions src/caffe/data_transformer.cpp
@@ -56,7 +56,8 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
Dtype* transformed_data,
NormalizedBBox* crop_bbox,
bool* do_mirror) {
//LOG(INFO) << "test";


const string& data = datum.data();
const int datum_channels = datum.channels();
const int datum_height = datum.height();
@@ -151,6 +152,7 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
Dtype* transformed_data) {

NormalizedBBox crop_bbox;
bool do_mirror;
Transform(datum, transformed_data, &crop_bbox, &do_mirror);
@@ -174,6 +176,7 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
} else {
cv_img = DecodeDatumToCVMatNative(datum);
}

// Transform the cv::image into blob.
return Transform(cv_img, transformed_blob, crop_bbox, do_mirror, policy_num);
#else
@@ -214,11 +217,11 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
}

template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
Blob<Dtype>* transformed_blob) {
void DataTransformer<Dtype>::Transform(const Datum& datum,Blob<Dtype>* transformed_blob, int policy_num) {
NormalizedBBox crop_bbox;
bool do_mirror;
Transform(datum, transformed_blob, &crop_bbox, &do_mirror);
Transform(datum, transformed_blob, &crop_bbox, &do_mirror,policy_num);
// entry point 1
}

template<typename Dtype>
@@ -229,7 +232,7 @@ void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
const int channels = transformed_blob->channels();
const int height = transformed_blob->height();
const int width = transformed_blob->width();
LOG(INFO) << "test";

CHECK_GT(datum_num, 0) << "There is no datum to add";
CHECK_LE(datum_num, num) <<
"The size of datum_vector must be no greater than transformed_blob->num()";
@@ -427,18 +430,21 @@ void DataTransformer<Dtype>::CropImage(const Datum& datum,
template<typename Dtype>
void DataTransformer<Dtype>::CropImage(const AnnotatedDatum& anno_datum,
const NormalizedBBox& bbox,
AnnotatedDatum* cropped_anno_datum) {
AnnotatedDatum* cropped_anno_datum , bool has_anno) {
// Crop the datum.
CropImage(anno_datum.datum(), bbox, cropped_anno_datum->mutable_datum());
cropped_anno_datum->set_type(anno_datum.type());
if(has_anno) {
cropped_anno_datum->set_type(anno_datum.type());

// Transform the annotation according to crop_bbox.
const bool do_resize = false;
const bool do_mirror = false;
NormalizedBBox crop_bbox;
ClipBBox(bbox, &crop_bbox);
TransformAnnotation(anno_datum, do_resize, crop_bbox, do_mirror,
cropped_anno_datum->mutable_annotation_group());
}

// Transform the annotation according to crop_bbox.
const bool do_resize = false;
const bool do_mirror = false;
NormalizedBBox crop_bbox;
ClipBBox(bbox, &crop_bbox);
TransformAnnotation(anno_datum, do_resize, crop_bbox, do_mirror,
cropped_anno_datum->mutable_annotation_group());
}

template<typename Dtype>
@@ -607,31 +613,34 @@ template<typename Dtype>
void DataTransformer<Dtype>::Transform2(const std::vector<cv::Mat> cv_imgs,
Blob<Dtype>* transformed_blob,
bool preserve_pixel_vals) {


// Check dimensions.
const int channels = transformed_blob->channels();
const int height = transformed_blob->height();
const int width = transformed_blob->width();
const int num = transformed_blob->num();
//LOG(INFO) << img_channels;
//CHECK_EQ(channels, img_channels);


//CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";
const int crop_size = param_.crop_size();
const float scale = 1/255.0;
const bool do_mirror = param_.mirror() && Rand(2);

//LOG(INFO) << scale << ","<< mean_values_[0] << ","<< mean_values_[1];
Dtype* transformed_data = transformed_blob->mutable_cpu_data();
for (int i=0;i<cv_imgs.size();i++) {
//LOG(INFO)<<i;
cv::Mat cv_img = cv_imgs[i];
const int img_channels = cv_img.channels();
const int img_height = cv_img.rows;
const int img_width = cv_img.cols;

// Check dimensions.
const int channels = transformed_blob->channels();
const int height = transformed_blob->height();
const int width = transformed_blob->width();
const int num = transformed_blob->num();
//LOG(INFO) << img_channels;
//CHECK_EQ(channels, img_channels);
CHECK_LE(height, img_height);
CHECK_LE(width, img_width);
CHECK_GE(num, 1);

//CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";
const int crop_size = param_.crop_size();
const float scale = 1/255.0;
const bool do_mirror = param_.mirror() && Rand(2);
CHECK_GT(img_channels, 0);
//LOG(INFO) << scale << ","<< mean_values_[0] << ","<< mean_values_[1];
Dtype* transformed_data = transformed_blob->mutable_cpu_data();
int top_index;
//LOG(INFO) << do_mirror;
int maxima = 0;
Expand All @@ -650,7 +659,8 @@ void DataTransformer<Dtype>::Transform2(const std::vector<cv::Mat> cv_imgs,
//LOG(INFO) << top_index;
// int top_index = (c * height + h) * width + w;
Dtype pixel = static_cast<Dtype>(ptr[img_index++]);

//if(pixel>0)
// LOG(INFO) << pixel;
transformed_data[top_index] = pixel * scale;
//LOG(INFO) << transformed_data[top_index];
if(top_index>maxima)
Expand Down Expand Up @@ -709,11 +719,12 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
crop_h = crop_size;
crop_w = crop_size;
}

cv::Mat cv_resized_image, cv_noised_image, cv_cropped_image;
if (param_.resize_param_size()) {
cv_resized_image = ApplyResize(cv_img, param_.resize_param(policy_num));
/*char buf[1000];
sprintf(buf, "input/input_%05d.jpg",iter_count++);
if (*do_mirror) {
cv::flip(cv_resized_image, cv_resized_image, 1);
Expand All @@ -730,6 +741,7 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
}
int img_height = cv_noised_image.rows;
int img_width = cv_noised_image.cols;

CHECK_GE(img_height, crop_h);
CHECK_GE(img_width, crop_w);
//LOG(INFO)<<img_width<<","<<img_height;
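
Taken together, the transformer changes serve the new data layer in two ways: CropImage can now crop an AnnotatedDatum that carries no box annotations (for example a segmentation label image) by passing has_anno = false, and Transform accepts a policy_num index that ends up in ApplyResize(cv_img, param_.resize_param(policy_num)), selecting which resize policy to apply. A hedged sketch of how a caller such as LaneDataLayer::load_batch (not shown in this excerpt) might use the new CropImage signature; the helper and its variable names are illustrative, not code from the commit:

#include "caffe/data_transformer.hpp"

// Illustrative helper: crop an image/label pair with the same sampled bbox,
// skipping the annotation transform for the label map, which has no boxes.
template <typename Dtype>
void CropPair(caffe::DataTransformer<Dtype>* transformer,
              const caffe::AnnotatedDatum& image_datum,
              const caffe::AnnotatedDatum& label_datum,
              const caffe::NormalizedBBox& sampled_bbox,
              caffe::AnnotatedDatum* cropped_image,
              caffe::AnnotatedDatum* cropped_label) {
  // Image datum: keep and transform its annotations (default has_anno = true).
  transformer->CropImage(image_datum, sampled_bbox, cropped_image);
  // Label-map datum: crop the pixels only.
  transformer->CropImage(label_datum, sampled_bbox, cropped_label, false);
}

The extra Transform(datum, blob, policy_num) overload is used in the same spirit: a data layer can forward its policy_num_ member so that the resize_param entry matching the current phase or scale is selected per call.
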