diff --git a/DICECriterion.lua b/DICECriterion.lua new file mode 100644 index 000000000..1aa0cc48a --- /dev/null +++ b/DICECriterion.lua @@ -0,0 +1,102 @@ +--[[ + Computes the Sorensen-dice coefficient of similarity given two samples. + The quotient of similarity is defined as: + + Q = 2 * (X n Y) + ------------------- + sum_i(X) + sum_i(Y) + where X and Y are the two samples; + (X n Y) denote the intersection where the elements of X and Y are equal. + + Author: Olalekan Ogunmolu, July 2016 + patlekano@gmail.com +]] + +local DICECriterion, parent = torch.class('nn.DICECriterion', 'nn.Criterion') + +local eps = 1 + +function DICECriterion:_init(weights) + parent._init(self) + + if weights then + assert(weights:dim() == 1, "weights input should be 1-D Tensor") + self.weights = weights + end + +end + +function DICECriterion:updateOutput(input, target) + + assert(input:nElement() == target:nElement(), "input and target size mismatch") + + local weights = self.weights + + local numerator, denom, common, output + + if weights ~= nil and target:dim() ~= 1 then + weights = self.weights:view(1, target:size(2)):expandAs(target) + end + + -- compute 2 * (X intersection Y) + common = torch.eq(input, target) --find logical equivalence between both + numerator = torch.sum(common) + numerator = numerator * 2 + + -- compute denominator: sum_i(X) + sum_i(Y) + denom = input:nElement() + target:nElement() + eps + + output = numerator/denom + + self.output = -output + + return self.output +end + +function DICECriterion:updateGradInput(input, target) + --[[ + 2 * sum_i(X) * sum_i(Y) + Gradient = --------------------------------- + sum_i(X)*(sum_i(X) + sum_i(Y))^2 + ]] + + assert(input:nElement() == target:nElement(), "inputs and target size mismatch") + self.buffer = self.buffer or input.new() + + local buffer = self.buffer + local weights = self.weights + local gradInput = self.gradInput + + if weights ~= nil and target:dim() ~= 1 then + weights = self.weights:view(1, target:size(2)):expandAs(target) + end + + buffer:resizeAs(input) + buffer:zero() + + -- compute sum_i(X) + sum_i(Y) + eps + buffer:add(input:nElement()):add(target:nElement()):add(eps) + -- compute (sum_i(X) + sum_i(Y) + eps )^2 + eps + buffer:cmul(buffer):add(eps) + -- compute sum_i(X)*(sum_i(X) + sum_i(Y) + eps )^2 + eps + buffer:mul(input:nElement()) + + gradInput:resizeAs(input) + gradInput:zero() + + -- compute 2 * sum_i(X) * sum_i(Y) + gradInput:add(input:nElement()):mul(target:nElement()):mul(2) + + -- compute quotient + gradInput:cdiv(buffer) + + if weights ~= nil then + gradInput:cmul(weights) + end + + if self.sizeAverage then + gradInput:div(target:nElement()) + end + + return gradInput +end diff --git a/doc/criterion.md b/doc/criterion.md index 270edb928..0cc8671f7 100644 --- a/doc/criterion.md +++ b/doc/criterion.md @@ -8,7 +8,8 @@ target, they compute a gradient according to a given loss function. * [`BCECriterion`](#nn.BCECriterion): binary cross-entropy for [`Sigmoid`](transfer.md#nn.Sigmoid) (two-class version of [`ClassNLLCriterion`](#nn.ClassNLLCriterion)); * [`ClassNLLCriterion`](#nn.ClassNLLCriterion): negative log-likelihood for [`LogSoftMax`](transfer.md#nn.LogSoftMax) (multi-class); * [`CrossEntropyCriterion`](#nn.CrossEntropyCriterion): combines [`LogSoftMax`](transfer.md#nn.LogSoftMax) and [`ClassNLLCriterion`](#nn.ClassNLLCriterion); - * [`ClassSimplexCriterion`](#nn.ClassSimplexCriterion): A simplex embedding criterion for classification. + * [`ClassSimplexCriterion`](#nn.ClassSimplexCriterion): A simplex embedding criterion for classification; + * [`DICECriterion`](criterion.md#nn.DICECriterion): A criterion for comparing the similarity of two samples; * [`MarginCriterion`](#nn.MarginCriterion): two class margin-based loss; * [`SoftMarginCriterion`](#nn.SoftMarginCriterion): two class softmargin-based loss; * [`MultiMarginCriterion`](#nn.MultiMarginCriterion): multi-class margin-based loss; @@ -211,6 +212,55 @@ end This criterion also provides two helper functions `getPredictions(input)` and `getTopPrediction(input)` that return the raw predictions and the top prediction index respectively, given an input sample. + +## DICECriterion ## + +```lua +criterion = nn.DICECriterion() +``` + +The [Sørensen–Dice index](https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient) measures the degree of similarity between two sample sets. Geiven targets `X` and `Y` in two sample datasets, the quotient of similarity is calculated as + +```lua + Q = 2 * (X n Y) + ------------------- + sum_i(X) + sum_i(Y) +``` + +where X and Y are the two samples. +X n Y denote the intersection where the elements of X and Y are equal. +The resulting quotient is a measure of the similarity between the two samples. +It ranges between 0 and 1. If it is 1, the two images are perfectly similar. Otherwise, +they are perfectly dissimilar. + +The input tensor and output tensor are expected to be of the same size when calling [`forward(input, target)`](#nn.CriterionForward) and [`backward(input, target)`](#nn.CriterionBackward). + +Example +------- +```lua +require 'torch' +require 'nn' + +local dice = nn.DICECriterion + +inputs = torch.FloatTensor(1, 5) +preds = torch.FloatTensor(1, 5) + +inputs = torch.range(1, 5) +preds = inputs:clone() + +loss = dice:forward(preds, inputs) +df_do = dice:backward(preds, inputs) + +print('loss', loss) +``` + +Prints + +```bash + loss -0.9999999999999 +``` + ## DistKLDivCriterion ## diff --git a/init.lua b/init.lua index 70027a18c..caefd0d35 100644 --- a/init.lua +++ b/init.lua @@ -150,6 +150,7 @@ require('nn.MapTable') require('nn.Criterion') require('nn.MSECriterion') +require('nn.DICECriterion') require('nn.SmoothL1Criterion') require('nn.MarginCriterion') require('nn.SoftMarginCriterion') diff --git a/test.lua b/test.lua index 0b57626a8..71d425f82 100644 --- a/test.lua +++ b/test.lua @@ -1339,6 +1339,13 @@ function nntest.MSECriterion() criterionJacobianTest(cri, input, target) end +function nntest.DICECriterion() + local input = torch.rand(10) + local target = input:clone():add(torch.rand(10)) + local cri = nn.DICECriterion() + criterionJacobianTest(cri, input, target) +end + function nntest.ClassSimplexCriterion() local nClasses = torch.random(3,15) local input = torch.rand(nClasses)