pytorch-mutual-information

Batch computation of mutual information and histogram2d in PyTorch

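This implementation uses kernel density estimation with a Gaussian kernel to calculate histograms and joint histograms. We use a diagonal bandwidth matrix for the multivariate case, which allows us to decompose the multivariate kernel into a product of univariate kernels. From Wikipedia, the kernel density estimate from $n$ samples $\mathbf{x}_1, \dots, \mathbf{x}_n$ is

$$\hat{f}_{\mathbf{H}}(\mathbf{x}) = \frac{1}{n}\sum_{i=1}^{n} K_{\mathbf{H}}(\mathbf{x} - \mathbf{x}_i), \qquad K_{\mathbf{H}}(\mathbf{x}) = |\mathbf{H}|^{-1/2} K\left(\mathbf{H}^{-1/2}\mathbf{x}\right),$$

where the bandwidth matrix $\mathbf{H} = \operatorname{diag}(\sigma_1^2, \dots, \sigma_d^2)$ is diagonal, so that for the Gaussian kernel

$$K_{\mathbf{H}}(\mathbf{x}) = \prod_{j=1}^{d} \frac{1}{\sqrt{2\pi}\,\sigma_j} \exp\left(-\frac{x_j^2}{2\sigma_j^2}\right).$$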
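As a quick numerical check of the decomposition described above, the sketch below evaluates a 2-D Gaussian kernel with a diagonal bandwidth matrix directly and compares it with the product of two univariate Gaussian kernels. The helper names `gaussian_1d` and `gaussian_kernel_H`, the bandwidths, and the evaluation point are illustrative assumptions, not part of this repository.

```python
import math
import torch

def gaussian_1d(u, sigma):
    # Univariate Gaussian kernel with bandwidth sigma: exp(-u^2 / (2 sigma^2)) / (sqrt(2 pi) sigma)
    return torch.exp(-0.5 * (u / sigma) ** 2) / (math.sqrt(2 * math.pi) * sigma)

def gaussian_kernel_H(x, H):
    # Multivariate Gaussian kernel K_H(x) = |H|^(-1/2) K(H^(-1/2) x),
    # written out as exp(-x^T H^-1 x / 2) / sqrt((2 pi)^d det H)
    d = x.shape[-1]
    quad = x @ torch.linalg.inv(H) @ x
    return torch.exp(-0.5 * quad) / torch.sqrt((2 * math.pi) ** d * torch.linalg.det(H))

# Illustrative per-dimension bandwidths and a 2-D evaluation point
sigma = torch.tensor([0.4, 1.5], dtype=torch.float64)
H = torch.diag(sigma ** 2)  # diagonal bandwidth matrix
x = torch.tensor([0.3, -0.7], dtype=torch.float64)

joint = gaussian_kernel_H(x, H)
product = gaussian_1d(x[0], sigma[0]) * gaussian_1d(x[1], sigma[1])
print(torch.allclose(joint, product))  # True
```

With a non-diagonal `H` the quadratic form mixes coordinates and the product form no longer holds, which is why the diagonal choice matters for the joint histogram.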

Example usage


Setup

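```python
import torch
from PIL import Image
from torchvision import transforms

# histogram, histogram2d and MutualInformation are provided by this repository.

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

img1 = Image.open('grad1.jpg').convert('L')
img2 = Image.open('grad.jpg').convert('L')

# ToTensor() maps 8-bit grayscale to floats in [0, 1]; unsqueeze adds a batch dimension
img1 = transforms.ToTensor()(img1).unsqueeze(dim=0).to(device)
img2 = transforms.ToTensor()(img2).unsqueeze(dim=0).to(device)

# Batch element 0 pairs different images (img2, img1); element 1 pairs the same image (img2, img2)
input1 = torch.cat([img2, img2])
input2 = torch.cat([img1, img2])

B, C, H, W = input1.shape   # shape: (2, 1, 300, 300)
```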

Histogram usage:

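```python
sigma = 0.4  # Gaussian kernel bandwidth; bins are created on the same device as the inputs
hist = histogram(input1.view(B, H*W), torch.linspace(0, 255, 256, device=device), sigma)
```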
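A minimal sketch of what a Gaussian-KDE histogram with this call shape might compute, assuming `x` is `(B, N)`, `bins` holds bin centres, and `sigma` is the kernel standard deviation. `soft_histogram` is a hypothetical stand-in, not the repository's `histogram`. Because `transforms.ToTensor()` returns values in [0, 1], the sample data here is scaled to the 0-255 range of the bins.

```python
import torch

def soft_histogram(x, bins, sigma):
    """Gaussian-KDE histogram (hypothetical helper, not the repository's function).

    x:     (B, N) samples per batch element
    bins:  (K,) bin centres
    sigma: Gaussian kernel bandwidth
    returns (B, K), each row summing to 1
    """
    # Residual between every sample and every bin centre: (B, N, K)
    residuals = x.unsqueeze(-1) - bins.view(1, 1, -1)
    # Unnormalised Gaussian kernel value for each sample/bin pair
    kernel = torch.exp(-0.5 * (residuals / sigma) ** 2)
    # Average over samples, then renormalise so each histogram sums to 1
    hist = kernel.mean(dim=1)
    return hist / (hist.sum(dim=1, keepdim=True) + 1e-10)

# Usage on random data in the 0-255 range: batch of 2, 1000 samples each
x = torch.rand(2, 1000) * 255
bins = torch.linspace(0, 255, 256)
hist = soft_histogram(x, bins, sigma=0.4)
print(hist.shape)  # torch.Size([2, 256])
```

Unlike `torch.histc`, every sample contributes smoothly to nearby bins, so the result is differentiable with respect to `x`. The cost is an intermediate of size B·N·K: for two 300×300 images and 256 bins that is about 46 million floats (roughly 184 MB in float32).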

Histogram 2D usage:

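```python
# Joint histogram over corresponding pixel pairs; bins on the same device as the inputs
hist = histogram2d(input1.view(B, H*W), input2.view(B, H*W), torch.linspace(0, 255, 256, device=device), sigma)
```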
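Because the bandwidth matrix is diagonal, a 2-D kernel value is the product of two 1-D kernel values, so a joint histogram can be built from the per-image kernel matrices with one batched contraction. This is a sketch under that assumption; `gaussian_kernel_values` and `soft_histogram2d` are hypothetical names, not the repository's `histogram2d`.

```python
import torch

def gaussian_kernel_values(x, bins, sigma):
    # (B, N) samples vs (K,) bin centres -> (B, N, K) unnormalised Gaussian kernel values
    residuals = x.unsqueeze(-1) - bins.view(1, 1, -1)
    return torch.exp(-0.5 * (residuals / sigma) ** 2)

def soft_histogram2d(x1, x2, bins, sigma):
    """Joint Gaussian-KDE histogram (hypothetical helper).

    With a diagonal bandwidth the 2-D kernel is the product of two 1-D kernels,
    so the joint histogram is a sum over samples of outer products.
    returns (B, K, K), each joint histogram summing to 1
    """
    k1 = gaussian_kernel_values(x1, bins, sigma)  # (B, N, K)
    k2 = gaussian_kernel_values(x2, bins, sigma)  # (B, N, K)
    # joint[b, i, j] = sum_n k1[b, n, i] * k2[b, n, j]
    joint = torch.einsum('bni,bnj->bij', k1, k2)
    return joint / (joint.sum(dim=(1, 2), keepdim=True) + 1e-10)

# Usage: batch of 2, 500 paired samples each, 64 bins over 0-255
x1 = torch.rand(2, 500) * 255
x2 = torch.rand(2, 500) * 255
bins = torch.linspace(0, 255, 64)
joint = soft_histogram2d(x1, x2, bins, sigma=2.0)
print(joint.shape)  # torch.Size([2, 64, 64])
```

The `einsum` is equivalent to the batched matrix product `k1.transpose(1, 2) @ k2`, so the joint histogram reuses the same kernel matrices as the two marginal histograms.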

Mutual Information (of images)

```python
MI = MutualInformation(num_bins=256, sigma=0.4, normalize=True).to(device)
score = MI(input1, input2)  # one score per image pair in the batch
```
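Given a joint histogram, mutual information follows from the entropy identity $I(X;Y) = H(X) + H(Y) - H(X,Y)$. The sketch below computes it in bits for a batch of joint histograms, taking the marginals as row and column sums of the joint. The normalised form $2I/(H(X)+H(Y))$ is an assumption about what `normalize=True` refers to, and `mutual_information` is a hypothetical helper rather than the internals of the `MutualInformation` module.

```python
import torch

def mutual_information(joint, normalize=True, eps=1e-10):
    """Mutual information (bits) from a batch of joint histograms.

    joint: (B, K, K), each slice summing to 1
    Assumed normalisation: 2 * I / (H(X) + H(Y)), which lies in [0, 1].
    """
    p_x = joint.sum(dim=2)  # marginal of X: (B, K)
    p_y = joint.sum(dim=1)  # marginal of Y: (B, K)
    h_x = -(p_x * torch.log2(p_x + eps)).sum(dim=1)
    h_y = -(p_y * torch.log2(p_y + eps)).sum(dim=1)
    h_xy = -(joint * torch.log2(joint + eps)).sum(dim=(1, 2))
    mi = h_x + h_y - h_xy
    if normalize:
        mi = 2 * mi / (h_x + h_y + eps)
    return mi

K = 8
p = torch.full((K,), 1.0 / K)
same = torch.diag(p).unsqueeze(0)      # X == Y, uniform over 8 values
indep = torch.outer(p, p).unsqueeze(0)  # X and Y independent
print(mutual_information(same, normalize=False))  # approximately 3 bits
print(mutual_information(same))                   # approximately 1
print(mutual_information(indep))                  # approximately 0
```

Identical variables give a normalised score of 1 and independent ones give 0. This mirrors the setup above, where one batch element pairs an image with itself and the other pairs two different images.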

Results


[Figure: Histogram]

[Figure: Joint Histogram]