
mehrdad-yousefi/pytorch-mutual-information

 
 


pytorch-mutual-information

Batch computation of mutual information and joint histograms (histogram2d) in PyTorch

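This implementation uses kernel density estimation (KDE) with a Gaussian kernel to compute histograms and joint histograms. For the multivariate case we use a diagonal bandwidth matrix, which lets the multivariate kernel decompose into the product of the univariate kernels. From Wikipedia,

$$K_{\mathbf{H}}(\mathbf{x}) = |\mathbf{H}|^{-1/2}\, K\!\left(\mathbf{H}^{-1/2}\mathbf{x}\right) = \prod_{j=1}^{d} \frac{1}{h_j}\, \varphi\!\left(\frac{x_j}{h_j}\right),$$

where $K$ is the standard $d$-variate Gaussian density, $\varphi$ is the standard univariate Gaussian density, and the bandwidth matrix is $\mathbf{H} = \operatorname{diag}(h_1^2, \dots, h_d^2)$.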
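The factorization above can be checked numerically. This is a minimal sketch assuming d = 2 and illustrative bandwidths h1 = 0.4 and h2 = 0.7 (not values taken from the repository); the helper names gaussian_1d and gaussian_2d_diag are hypothetical. It evaluates the bivariate Gaussian directly from H and compares it with the product of two univariate kernels.

```python
import math
import torch

def gaussian_1d(u: torch.Tensor, h: float) -> torch.Tensor:
    """Univariate Gaussian kernel with bandwidth h: (1/h) * phi(u / h)."""
    return torch.exp(-0.5 * (u / h) ** 2) / (h * math.sqrt(2 * math.pi))

def gaussian_2d_diag(x: torch.Tensor, h1: float, h2: float) -> torch.Tensor:
    """Bivariate Gaussian kernel with H = diag(h1^2, h2^2), evaluated directly.

    x has shape (..., 2).
    """
    H = torch.diag(torch.tensor([h1 ** 2, h2 ** 2], dtype=x.dtype))
    H_inv = torch.linalg.inv(H)
    # Quadratic form x^T H^{-1} x for every point
    quad = torch.einsum('...i,ij,...j->...', x, H_inv, x)
    # (2*pi)^{d/2} * |H|^{1/2} with d = 2
    norm = 2 * math.pi * torch.sqrt(torch.linalg.det(H))
    return torch.exp(-0.5 * quad) / norm

x = torch.randn(1000, 2, dtype=torch.float64)
h1, h2 = 0.4, 0.7
direct = gaussian_2d_diag(x, h1, h2)
product = gaussian_1d(x[:, 0], h1) * gaussian_1d(x[:, 1], h2)
print(torch.allclose(direct, product))  # True
```

Because the product form never builds a d x d matrix, each dimension can be handled with cheap elementwise operations, which is what makes a batched joint histogram practical.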

Example usage


Setup

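import torch
from PIL import Image
from torchvision import transforms
# MutualInformation, histogram and histogram2d are provided by this repository

device = 'cuda:0'

# Load both images as single-channel (grayscale) PIL images
img1 = Image.open('grad1.jpg').convert('L')
img2 = Image.open('grad.jpg').convert('L')

# ToTensor() maps 8-bit intensities to [0, 1]; unsqueeze adds a batch dim -> (1, 1, H, W)
img1 = transforms.ToTensor()(img1).unsqueeze(dim=0).to(device)
img2 = transforms.ToTensor()(img2).unsqueeze(dim=0).to(device)

# Batch of two pairs: (img2, img1) are different images, (img2, img2) are the same image
input1 = torch.cat([img2, img2])
input2 = torch.cat([img1, img2])

B, C, H, W = input1.shape   # e.g. (2, 1, 300, 300) for two 300x300 grayscale images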

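Histogram usage (the bin centres span 0 to 255, so the [0, 1] pixel values are rescaled to that range; sigma is the Gaussian kernel bandwidth):

sigma = 0.4  # same bandwidth as in the MutualInformation example below
# view(B, H*W) assumes single-channel images (C == 1); bins must be on the same device as the input
hist = histogram(input1.view(B, H*W) * 255, torch.linspace(0, 255, 256, device=device), sigma)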
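The kernel density histogram described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the repository's histogram function: the name soft_histogram is hypothetical, bins are treated as bin centres, the Gaussian normalizing constant is dropped, and each row is renormalized to sum to 1.

```python
import torch

def soft_histogram(x: torch.Tensor, bins: torch.Tensor, sigma: float) -> torch.Tensor:
    """Differentiable histogram via Gaussian kernel density estimation (sketch).

    x:    (B, N) samples, one row per batch element
    bins: (K,) bin centres
    returns (B, K) with each row summing to 1
    """
    # Offset of every sample from every bin centre: (B, N, K)
    residuals = x.unsqueeze(-1) - bins
    # Unnormalized Gaussian kernel; the constant cancels in the final normalization
    weights = torch.exp(-0.5 * (residuals / sigma) ** 2)
    # Average over samples -> density estimate evaluated at each bin centre: (B, K)
    pdf = weights.mean(dim=1)
    # Normalize each row into a probability mass function (eps avoids 0/0)
    return pdf / (pdf.sum(dim=1, keepdim=True) + 1e-10)


# Usage: three samples, two at intensity 10 and one at 200
x = torch.tensor([[10.0, 10.0, 200.0]])
bins = torch.linspace(0, 255, 256)
h = soft_histogram(x, bins, sigma=0.4)
print(h.shape)            # torch.Size([1, 256])
print(h.argmax().item())  # 10
```

With the image inputs above, x would be input1.view(B, H*W) * 255. The intermediate tensor has B x N x K elements, so a 300x300 image with 256 bins needs about 23 million floats per batch element; processing the pixels in chunks is one way to bound memory.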

Joint histogram (histogram2d) usage:

joint_hist = histogram2d(input1.view(B, H*W) * 255, input2.view(B, H*W) * 255, torch.linspace(0, 255, 256, device=device), sigma)
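Because the bandwidth matrix is diagonal, the joint kernel is the product of two univariate kernels, so a joint histogram can be sketched as a batched matrix product. This is a hypothetical helper (soft_histogram2d) written under the same assumptions as a plain KDE histogram (shared bin centres, unnormalized kernel, per-slice renormalization); it is not necessarily how the repository's histogram2d is implemented.

```python
import torch

def soft_histogram2d(x1: torch.Tensor, x2: torch.Tensor,
                     bins: torch.Tensor, sigma: float) -> torch.Tensor:
    """Differentiable joint histogram from a product of 1-D Gaussian kernels (sketch).

    x1, x2: (B, N) paired samples, e.g. flattened pixel intensities of two images
    bins:   (K,) bin centres shared by both variables
    returns (B, K, K), each batch slice summing to 1
    """
    # Per-variable kernel weights of each sample at each bin centre: (B, N, K)
    k1 = torch.exp(-0.5 * ((x1.unsqueeze(-1) - bins) / sigma) ** 2)
    k2 = torch.exp(-0.5 * ((x2.unsqueeze(-1) - bins) / sigma) ** 2)
    # Diagonal bandwidth => the 2-D kernel is k1[..., i] * k2[..., j];
    # averaging that product over the N samples is a batched matrix product
    joint = torch.einsum('bni,bnj->bij', k1, k2) / x1.shape[1]
    # Normalize each (K, K) slice into a joint probability mass function
    return joint / (joint.sum(dim=(1, 2), keepdim=True) + 1e-10)


# Usage: a random batch paired with itself
torch.manual_seed(0)
x = torch.rand(2, 500) * 255
joint = soft_histogram2d(x, x, torch.linspace(0, 255, 256), sigma=0.4)
print(joint.shape)  # torch.Size([2, 256, 256])
```

The einsum contracts over the sample axis directly, which avoids materializing a (B, N, K, K) tensor of per-sample joint kernels.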

Mutual Information (of images)

MI = MutualInformation(num_bins=256, sigma=0.4, normalize=True).to(device)
score = MI(input1, input2)
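Given a joint histogram, mutual information follows from the standard identity I(X;Y) = H(X) + H(Y) - H(X,Y), with the marginals obtained by summing the joint histogram along each axis. The sketch below is a hypothetical helper (mutual_information), not the repository's MutualInformation module; the normalized score it returns, 2 I(X;Y) / (H(X) + H(Y)), is one common choice and is an assumption, since normalize=True may use a different definition.

```python
import torch

def mutual_information(p_joint: torch.Tensor, eps: float = 1e-10):
    """I(X;Y) in nats, plus a normalized variant, for joint PMFs of shape (B, K, K)."""
    p1 = p_joint.sum(dim=2)  # marginal of X: (B, K)
    p2 = p_joint.sum(dim=1)  # marginal of Y: (B, K)
    # Shannon entropies; eps keeps log() finite for empty bins
    h1 = -(p1 * torch.log(p1 + eps)).sum(dim=1)
    h2 = -(p2 * torch.log(p2 + eps)).sum(dim=1)
    h12 = -(p_joint * torch.log(p_joint + eps)).sum(dim=(1, 2))
    mi = h1 + h2 - h12
    # Assumed normalization 2*I/(H(X)+H(Y)), in [0, 1]; the repository's
    # normalize=True option may use a different definition
    nmi = 2 * mi / (h1 + h2 + eps)
    return mi, nmi


# Usage: identical variables (diagonal joint) vs independent variables (outer product)
p = torch.tensor([0.25, 0.25, 0.5])
q = torch.tensor([0.1, 0.6, 0.3])
mi_same, nmi_same = mutual_information(torch.diag(p).unsqueeze(0))
mi_indep, _ = mutual_information(torch.outer(p, q).unsqueeze(0))
print(nmi_same)  # close to 1
print(mi_indep)  # close to 0
```

This mirrors the image example: a pair of identical images should score near the maximum, while unrelated images score lower.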

Results


Histogram

Joint Histogram
