Skip to content

Latest commit

 

History

History
58 lines (48 loc) · 3.01 KB

README.md

File metadata and controls

58 lines (48 loc) · 3.01 KB

Pytorch version of regularized segmentation loss

build python extension module

The implementation of DenseCRF loss depends on fast bilateral filtering, which is provided in C++. Use SWIG to wrap C++ for python and then build the python module of bilateral filtering.

cd wrapper/bilateralfilter
swig -python -c++ bilateralfilter.i
python3 setup.py build

denseCRF loss in pytorch

The source code for the denseCRF loss layer is DenseCRFLoss.py. Declare such a loss layer as follows:

losslayer=DenseCRFLoss(weight=weight, sigma_rgb=sigma_rgb, sigma_xy=sigma_xy, scale_factor=scale_factor)

Here we specify loss weight, Gaussian kernel bandwidth (for RGB and XY), and an optional scale_factor (used to downscale output segmentation so that forward and backward for DenseCRF loss is faster).

The input to the denseCRF loss layer includes image (in the range of [0-255]), segmentation (output of softmax) and a binary tensor specifying region of interest for the regularized loss (e.g. not interested for padded region).

losslayer(image,segmentation,region_of_interest)

how to run the code

To train with densecrf loss, use the following example script. The weight of densecrf loss is 2e-9. The bandwidths of Gaussian kernels are 15 and 100 for RGB and XY respectively. Optionally, the output segmentation is downscaled by 0.5 (rloss-scale).

python3 train_withdensecrfloss.py --backbone mobilenet --lr 0.007 --workers 6 --epochs 60 
--batch-size 12  --checkname deeplab-mobilenet --eval-interval 2 --dataset pascal --save-interval 2 
--densecrfloss 2e-9 --rloss-scale 0.5 --sigma-rgb 15 --sigma-xy 100

(set the path of dataset in mypath.py. For example, the path for pascal should have three subdirectories called "JPEGImages", "SegmentationClassAug", and "pascal_2012_scribble" containing RGB images, groundtruth, and scribbles respectively.)

results

network backbone weak supervision (~3% pixels labeled) full supervision
(partial) Cross Entropy Loss w/ DenseCRF Loss
mobilenet 65.8% (1.05 s/it) 69.4% (1.66 s/it) 72.1% (1.05 s/it)

Table 1: mIOU on PASCAL VOC2012 val set. We report training time for different losses (seconds/iteration, batch_size 12, GTX 1080Ti, AMD FX-6300 3.5GHz).

The trained pytorch models are released here.

acknowledgement

The code here is built on pytorch-deeplab-xception. We alto utilized the efficient c++ implementation of permutohedral lattice from CRF-as-RNN. Fangyu Liu from the University of Waterloo helped tremendously in releasing this pytorch version.