PyTorch version of regularized segmentation loss

build python extension module

The implementation of the DenseCRF loss depends on fast bilateral filtering, which is implemented in C++. Use SWIG to wrap the C++ code for Python, then build the Python module for bilateral filtering.

cd wrapper/bilateralfilter
swig -python -c++ bilateralfilter.i
python3 setup.py build

denseCRF loss in pytorch

The source code for the denseCRF loss layer is DenseCRFLoss.py. Declare such a loss layer as follows:

losslayer=DenseCRFLoss(weight=weight, sigma_rgb=sigma_rgb, sigma_xy=sigma_xy, scale_factor=scale_factor)

Here we specify the loss weight, the Gaussian kernel bandwidths (for RGB and XY), and an optional scale_factor used to downscale the output segmentation so that the forward and backward passes of the DenseCRF loss run faster.
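To see what the two bandwidths control, consider the pairwise affinity of a dense Gaussian-kernel CRF: each pair of pixels is weighted by a bilateral kernel over color (sigma_rgb) and position (sigma_xy). The helper below is illustrative only and not part of the released code:

```python
import numpy as np

def gaussian_affinity(rgb_i, rgb_j, xy_i, xy_j, sigma_rgb=15.0, sigma_xy=100.0):
    """Bilateral (appearance + location) affinity between two pixels.
    Hypothetical helper for illustration; the released layer computes these
    affinities implicitly via fast bilateral filtering."""
    d_rgb = np.sum((np.asarray(rgb_i, float) - np.asarray(rgb_j, float)) ** 2)
    d_xy = np.sum((np.asarray(xy_i, float) - np.asarray(xy_j, float)) ** 2)
    return np.exp(-d_rgb / (2 * sigma_rgb**2) - d_xy / (2 * sigma_xy**2))

# identical pixels at the same location: affinity is exactly 1
w_same = gaussian_affinity([128, 128, 128], [128, 128, 128], [0, 0], [0, 0])
# a distant pixel with a different color: affinity is positive but far smaller
w_far = gaussian_affinity([128, 128, 128], [200, 40, 90], [0, 0], [300, 300])
```

Smaller bandwidths make the kernel more selective (only nearby, similar-colored pixels attract each other); larger bandwidths smooth over longer ranges.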

The inputs to the denseCRF loss layer are the image (with values in the range [0, 255]), the segmentation (the softmax output), and a binary tensor specifying the region of interest for the regularized loss (e.g., excluding padded regions).

losslayer(image, segmentation, region_of_interest)
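Setting the fast filtering aside, the relaxed DenseCRF loss itself is simple to state: for each class k it sums S_k^T W (1 - S_k) over the region of interest, where W holds the pairwise Gaussian affinities. The sketch below computes it densely for a tiny image; the function name and the explicit O(N²) affinity matrix are illustrative assumptions, not the released implementation (which never materializes W):

```python
import torch

def dense_crf_loss_naive(image, seg, roi, sigma_rgb=15.0, sigma_xy=100.0, weight=2e-9):
    """Naive O(N^2) sketch of the relaxed DenseCRF loss, restricted to the ROI.
    image: (3, H, W) floats in [0, 255]; seg: (K, H, W) softmax output;
    roi: (H, W) binary mask."""
    _, H, W = image.shape
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
    feat_xy = torch.stack([ys, xs], dim=-1).reshape(-1, 2).float()
    feat_rgb = image.reshape(3, -1).t()
    # explicit pairwise bilateral affinities (the released code uses fast filtering)
    d_xy = torch.cdist(feat_xy, feat_xy) ** 2
    d_rgb = torch.cdist(feat_rgb, feat_rgb) ** 2
    Wmat = torch.exp(-d_rgb / (2 * sigma_rgb**2) - d_xy / (2 * sigma_xy**2))
    Wmat.fill_diagonal_(0)  # no self-affinity
    mask = roi.reshape(-1).float()
    S = seg.reshape(seg.shape[0], -1) * mask  # zero out pixels outside the ROI
    return weight * sum(Sk @ Wmat @ (mask - Sk) for Sk in S)

# smoke test on a 4x4 image with 2 classes
torch.manual_seed(0)
loss = dense_crf_loss_naive(torch.rand(3, 4, 4) * 255,
                            torch.softmax(torch.rand(2, 4, 4), dim=0),
                            torch.ones(4, 4))
```

The loss is small for confident, spatially coherent segmentations (each S_k near 0 or 1 within color-homogeneous regions) and grows when similar, nearby pixels are assigned different labels.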

how to run the code

To train with the DenseCRF loss, use the following example script. The weight of the DenseCRF loss is 2e-9. The bandwidths of the Gaussian kernels are 15 for RGB and 100 for XY. Optionally, the output segmentation is downscaled by 0.5 (--rloss-scale).

python3 train_withdensecrfloss.py --backbone mobilenet --lr 0.007 --workers 6 --epochs 60 \
--batch-size 12 --checkname deeplab-mobilenet --eval-interval 2 --dataset pascal --save-interval 2 \
--densecrfloss 2e-9 --rloss-scale 0.5 --sigma-rgb 15 --sigma-xy 100

(Set the path of the dataset in mypath.py. For example, the path for pascal should have three subdirectories called "JPEGImages", "SegmentationClassAug", and "pascal_2012_scribble", containing the RGB images, the ground truth, and the scribbles respectively.)
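A quick sanity check of the dataset layout gives a clearer error than a failed data loader. This helper is a hypothetical addition (the repo itself just reads the path from mypath.py):

```python
import tempfile
from pathlib import Path

# subdirectories expected under the pascal dataset root, per the note above
EXPECTED = ("JPEGImages", "SegmentationClassAug", "pascal_2012_scribble")

def missing_pascal_dirs(root):
    """Return the expected subdirectories that are missing under `root`."""
    root = Path(root)
    return [d for d in EXPECTED if not (root / d).is_dir()]

# demo on a temporary directory with one subdirectory missing
with tempfile.TemporaryDirectory() as tmp:
    for d in EXPECTED[:2]:
        (Path(tmp) / d).mkdir()
    missing = missing_pascal_dirs(tmp)
```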

results

network backbone | (partial) Cross Entropy Loss, weak supervision (~3% pixels labeled) | w/ DenseCRF Loss, weak supervision (~3% pixels labeled) | full supervision
mobilenet        | 65.8% (1.05 s/it)                                                   | 69.4% (1.66 s/it)                                       | 72.1% (1.05 s/it)

Table 1: mIoU on the PASCAL VOC 2012 val set. We also report training time for the different losses (seconds/iteration, batch size 12, GTX 1080Ti, AMD FX-6300 3.5 GHz).

The trained PyTorch models are released here.

acknowledgement

The code here is built on pytorch-deeplab-xception. We also utilized the efficient C++ implementation of the permutohedral lattice from CRF-as-RNN. Fangyu Liu of the University of Waterloo helped tremendously in releasing this PyTorch version.