Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initialized buffer/gradInputs memory and added tests for dice #913

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions DICECriterion.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
--[[
Computes the Sorensen-dice coefficient of similarity given two samples.
The quotient of similarity is defined as:

Q = 2 * (X n Y)
-------------------
sum_i(X) + sum_i(Y)
where X and Y are the two samples;
(X n Y) denote the intersection where the elements of X and Y are equal.

Author: Olalekan Ogunmolu, July 2016
[email protected]
]]

local DICECriterion, parent = torch.class('nn.DICECriterion', 'nn.Criterion')

local eps = 1

function DICECriterion:_init(weights)
parent._init(self)

if weights then
assert(weights:dim() == 1, "weights input should be 1-D Tensor")
self.weights = weights
end

end

function DICECriterion:updateOutput(input, target)

assert(input:nElement() == target:nElement(), "input and target size mismatch")

local weights = self.weights

local numerator, denom, common, output

if weights ~= nil and target:dim() ~= 1 then
weights = self.weights:view(1, target:size(2)):expandAs(target)
end

-- compute 2 * (X intersection Y)
common = torch.eq(input, target) --find logical equivalence between both
numerator = torch.sum(common)
numerator = numerator * 2

-- compute denominator: sum_i(X) + sum_i(Y)
denom = input:nElement() + target:nElement() + eps

output = numerator/denom

self.output = -output

return self.output
end

function DICECriterion:updateGradInput(input, target)
--[[
2 * sum_i(X) * sum_i(Y)
Gradient = ---------------------------------
sum_i(X)*(sum_i(X) + sum_i(Y))^2
]]

assert(input:nElement() == target:nElement(), "inputs and target size mismatch")
self.buffer = self.buffer or input.new()

local buffer = self.buffer
local weights = self.weights
local gradInput = self.gradInput

if weights ~= nil and target:dim() ~= 1 then
weights = self.weights:view(1, target:size(2)):expandAs(target)
end

buffer:resizeAs(input)
buffer:zero()

-- compute sum_i(X) + sum_i(Y) + eps
buffer:add(input:nElement()):add(target:nElement()):add(eps)
-- compute (sum_i(X) + sum_i(Y) + eps )^2 + eps
buffer:cmul(buffer):add(eps)
-- compute sum_i(X)*(sum_i(X) + sum_i(Y) + eps )^2 + eps
buffer:mul(input:nElement())

gradInput:resizeAs(input)
gradInput:zero()

-- compute 2 * sum_i(X) * sum_i(Y)
gradInput:add(input:nElement()):mul(target:nElement()):mul(2)

-- compute quotient
gradInput:cdiv(buffer)

if weights ~= nil then
gradInput:cmul(weights)
end

if self.sizeAverage then
gradInput:div(target:nElement())
end

return gradInput
end
52 changes: 51 additions & 1 deletion doc/criterion.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ target, they compute a gradient according to a given loss function.
* [`BCECriterion`](#nn.BCECriterion): binary cross-entropy for [`Sigmoid`](transfer.md#nn.Sigmoid) (two-class version of [`ClassNLLCriterion`](#nn.ClassNLLCriterion));
* [`ClassNLLCriterion`](#nn.ClassNLLCriterion): negative log-likelihood for [`LogSoftMax`](transfer.md#nn.LogSoftMax) (multi-class);
* [`CrossEntropyCriterion`](#nn.CrossEntropyCriterion): combines [`LogSoftMax`](transfer.md#nn.LogSoftMax) and [`ClassNLLCriterion`](#nn.ClassNLLCriterion);
* [`ClassSimplexCriterion`](#nn.ClassSimplexCriterion): A simplex embedding criterion for classification.
* [`ClassSimplexCriterion`](#nn.ClassSimplexCriterion): A simplex embedding criterion for classification;
* [`DICECriterion`](criterion.md#nn.DICECriterion): A criterion for comparing the similarity of two samples;
* [`MarginCriterion`](#nn.MarginCriterion): two class margin-based loss;
* [`SoftMarginCriterion`](#nn.SoftMarginCriterion): two class softmargin-based loss;
* [`MultiMarginCriterion`](#nn.MultiMarginCriterion): multi-class margin-based loss;
Expand Down Expand Up @@ -211,6 +212,55 @@ end

This criterion also provides two helper functions `getPredictions(input)` and `getTopPrediction(input)` that return the raw predictions and the top prediction index respectively, given an input sample.

<a name="nn.DICECriterion"></a>
## DICECriterion ##

```lua
criterion = nn.DICECriterion()
```

The [Sørensen–Dice index](https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient) measures the degree of similarity between two sample sets. Geiven targets `X` and `Y` in two sample datasets, the quotient of similarity is calculated as

```lua
Q = 2 * (X n Y)
-------------------
sum_i(X) + sum_i(Y)
```

where X and Y are the two samples.
X n Y denote the intersection where the elements of X and Y are equal.
The resulting quotient is a measure of the similarity between the two samples.
It ranges between 0 and 1. If it is 1, the two images are perfectly similar. Otherwise,
they are perfectly dissimilar.

The input tensor and output tensor are expected to be of the same size when calling [`forward(input, target)`](#nn.CriterionForward) and [`backward(input, target)`](#nn.CriterionBackward).

Example
-------
```lua
require 'torch'
require 'nn'

local dice = nn.DICECriterion

inputs = torch.FloatTensor(1, 5)
preds = torch.FloatTensor(1, 5)

inputs = torch.range(1, 5)
preds = inputs:clone()

loss = dice:forward(preds, inputs)
df_do = dice:backward(preds, inputs)

print('loss', loss)
```

Prints

```bash
loss -0.9999999999999
```

<a name="nn.DistKLDivCriterion"></a>
## DistKLDivCriterion ##

Expand Down
1 change: 1 addition & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ require('nn.MapTable')

require('nn.Criterion')
require('nn.MSECriterion')
require('nn.DICECriterion')
require('nn.SmoothL1Criterion')
require('nn.MarginCriterion')
require('nn.SoftMarginCriterion')
Expand Down
7 changes: 7 additions & 0 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1339,6 +1339,13 @@ function nntest.MSECriterion()
criterionJacobianTest(cri, input, target)
end

function nntest.DICECriterion()
local input = torch.rand(10)
local target = input:clone():add(torch.rand(10))
local cri = nn.DICECriterion()
criterionJacobianTest(cri, input, target)
end

function nntest.ClassSimplexCriterion()
local nClasses = torch.random(3,15)
local input = torch.rand(nClasses)
Expand Down