diff --git a/DICECriterion.lua b/DICECriterion.lua
new file mode 100644
index 000000000..1aa0cc48a
--- /dev/null
+++ b/DICECriterion.lua
@@ -0,0 +1,102 @@
+--[[
+ Computes the Sorensen-dice coefficient of similarity given two samples.
+ The quotient of similarity is defined as:
+
+ Q = 2 * (X n Y)
+ -------------------
+ sum_i(X) + sum_i(Y)
+ where X and Y are the two samples;
+ (X n Y) denote the intersection where the elements of X and Y are equal.
+
+ Author: Olalekan Ogunmolu, July 2016
+ patlekano@gmail.com
+]]
+
+local DICECriterion, parent = torch.class('nn.DICECriterion', 'nn.Criterion')
+
+local eps = 1
+
+function DICECriterion:_init(weights)
+ parent._init(self)
+
+ if weights then
+ assert(weights:dim() == 1, "weights input should be 1-D Tensor")
+ self.weights = weights
+ end
+
+end
+
+function DICECriterion:updateOutput(input, target)
+
+ assert(input:nElement() == target:nElement(), "input and target size mismatch")
+
+ local weights = self.weights
+
+ local numerator, denom, common, output
+
+ if weights ~= nil and target:dim() ~= 1 then
+ weights = self.weights:view(1, target:size(2)):expandAs(target)
+ end
+
+ -- compute 2 * (X intersection Y)
+ common = torch.eq(input, target) --find logical equivalence between both
+ numerator = torch.sum(common)
+ numerator = numerator * 2
+
+ -- compute denominator: sum_i(X) + sum_i(Y)
+ denom = input:nElement() + target:nElement() + eps
+
+ output = numerator/denom
+
+ self.output = -output
+
+ return self.output
+end
+
+function DICECriterion:updateGradInput(input, target)
+ --[[
+ 2 * sum_i(X) * sum_i(Y)
+ Gradient = ---------------------------------
+ sum_i(X)*(sum_i(X) + sum_i(Y))^2
+ ]]
+
+ assert(input:nElement() == target:nElement(), "inputs and target size mismatch")
+ self.buffer = self.buffer or input.new()
+
+ local buffer = self.buffer
+ local weights = self.weights
+ local gradInput = self.gradInput
+
+ if weights ~= nil and target:dim() ~= 1 then
+ weights = self.weights:view(1, target:size(2)):expandAs(target)
+ end
+
+ buffer:resizeAs(input)
+ buffer:zero()
+
+ -- compute sum_i(X) + sum_i(Y) + eps
+ buffer:add(input:nElement()):add(target:nElement()):add(eps)
+ -- compute (sum_i(X) + sum_i(Y) + eps )^2 + eps
+ buffer:cmul(buffer):add(eps)
+ -- compute sum_i(X)*(sum_i(X) + sum_i(Y) + eps )^2 + eps
+ buffer:mul(input:nElement())
+
+ gradInput:resizeAs(input)
+ gradInput:zero()
+
+ -- compute 2 * sum_i(X) * sum_i(Y)
+ gradInput:add(input:nElement()):mul(target:nElement()):mul(2)
+
+ -- compute quotient
+ gradInput:cdiv(buffer)
+
+ if weights ~= nil then
+ gradInput:cmul(weights)
+ end
+
+ if self.sizeAverage then
+ gradInput:div(target:nElement())
+ end
+
+ return gradInput
+end
diff --git a/doc/criterion.md b/doc/criterion.md
index 270edb928..0cc8671f7 100644
--- a/doc/criterion.md
+++ b/doc/criterion.md
@@ -8,7 +8,8 @@ target, they compute a gradient according to a given loss function.
* [`BCECriterion`](#nn.BCECriterion): binary cross-entropy for [`Sigmoid`](transfer.md#nn.Sigmoid) (two-class version of [`ClassNLLCriterion`](#nn.ClassNLLCriterion));
* [`ClassNLLCriterion`](#nn.ClassNLLCriterion): negative log-likelihood for [`LogSoftMax`](transfer.md#nn.LogSoftMax) (multi-class);
* [`CrossEntropyCriterion`](#nn.CrossEntropyCriterion): combines [`LogSoftMax`](transfer.md#nn.LogSoftMax) and [`ClassNLLCriterion`](#nn.ClassNLLCriterion);
- * [`ClassSimplexCriterion`](#nn.ClassSimplexCriterion): A simplex embedding criterion for classification.
+ * [`ClassSimplexCriterion`](#nn.ClassSimplexCriterion): A simplex embedding criterion for classification;
+ * [`DICECriterion`](criterion.md#nn.DICECriterion): A criterion for comparing the similarity of two samples;
* [`MarginCriterion`](#nn.MarginCriterion): two class margin-based loss;
* [`SoftMarginCriterion`](#nn.SoftMarginCriterion): two class softmargin-based loss;
* [`MultiMarginCriterion`](#nn.MultiMarginCriterion): multi-class margin-based loss;
@@ -211,6 +212,55 @@ end
This criterion also provides two helper functions `getPredictions(input)` and `getTopPrediction(input)` that return the raw predictions and the top prediction index respectively, given an input sample.
+
+## DICECriterion ##
+
+```lua
+criterion = nn.DICECriterion()
+```
+
+The [Sørensen–Dice index](https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient) measures the degree of similarity between two sample sets. Geiven targets `X` and `Y` in two sample datasets, the quotient of similarity is calculated as
+
+```lua
+ Q = 2 * (X n Y)
+ -------------------
+ sum_i(X) + sum_i(Y)
+```
+
+where X and Y are the two samples.
+X n Y denote the intersection where the elements of X and Y are equal.
+The resulting quotient is a measure of the similarity between the two samples.
+It ranges between 0 and 1. If it is 1, the two images are perfectly similar. Otherwise,
+they are perfectly dissimilar.
+
+The input tensor and output tensor are expected to be of the same size when calling [`forward(input, target)`](#nn.CriterionForward) and [`backward(input, target)`](#nn.CriterionBackward).
+
+Example
+-------
+```lua
+require 'torch'
+require 'nn'
+
+local dice = nn.DICECriterion
+
+inputs = torch.FloatTensor(1, 5)
+preds = torch.FloatTensor(1, 5)
+
+inputs = torch.range(1, 5)
+preds = inputs:clone()
+
+loss = dice:forward(preds, inputs)
+df_do = dice:backward(preds, inputs)
+
+print('loss', loss)
+```
+
+Prints
+
+```bash
+ loss -0.9999999999999
+```
+
## DistKLDivCriterion ##
diff --git a/init.lua b/init.lua
index 70027a18c..caefd0d35 100644
--- a/init.lua
+++ b/init.lua
@@ -150,6 +150,7 @@ require('nn.MapTable')
require('nn.Criterion')
require('nn.MSECriterion')
+require('nn.DICECriterion')
require('nn.SmoothL1Criterion')
require('nn.MarginCriterion')
require('nn.SoftMarginCriterion')
diff --git a/test.lua b/test.lua
index 0b57626a8..71d425f82 100644
--- a/test.lua
+++ b/test.lua
@@ -1339,6 +1339,13 @@ function nntest.MSECriterion()
criterionJacobianTest(cri, input, target)
end
+function nntest.DICECriterion()
+ local input = torch.rand(10)
+ local target = input:clone():add(torch.rand(10))
+ local cri = nn.DICECriterion()
+ criterionJacobianTest(cri, input, target)
+end
+
function nntest.ClassSimplexCriterion()
local nClasses = torch.random(3,15)
local input = torch.rand(nClasses)