t-SNE with OpenCV and sklearn #2

Open · wants to merge 5 commits into master
93 changes: 93 additions & 0 deletions TSNE/animals_dataset.py
@@ -0,0 +1,93 @@
from os import path, listdir
import torch
from torchvision import transforms
import random

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


colors_per_class = {
    'dog': [254, 202, 87],
    'horse': [255, 107, 107],
    'elephant': [10, 189, 227],
    'butterfly': [255, 159, 243],
    'chicken': [16, 172, 132],
    'cat': [128, 80, 128],
    'cow': [87, 101, 116],
    'sheep': [52, 31, 151],
    'spider': [0, 0, 0],
    'squirrel': [100, 100, 255],
}


# Processes the Animals10 dataset: https://www.kaggle.com/alessiocorrado99/animals10
class AnimalsDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, num_images=1000):
        # the dataset ships with Italian folder names; map them to English labels
        translation = {'cane': 'dog',
                       'cavallo': 'horse',
                       'elefante': 'elephant',
                       'farfalla': 'butterfly',
                       'gallina': 'chicken',
                       'gatto': 'cat',
                       'mucca': 'cow',
                       'pecora': 'sheep',
                       'ragno': 'spider',
                       'scoiattolo': 'squirrel'}

        self.classes = list(translation.values())

        if not path.exists(data_path):
            raise FileNotFoundError(data_path + ' does not exist!')

        self.data = []

        folders = listdir(data_path)
        for folder in folders:
            label = translation[folder]

            full_path = path.join(data_path, folder)
            images = listdir(full_path)

            current_data = [(path.join(full_path, image), label) for image in images]
            self.data += current_data

        num_images = min(num_images, len(self.data))
        self.data = random.sample(self.data, num_images)  # only use num_images images

        # We use the transforms described in the official PyTorch ResNet inference example:
        # https://pytorch.org/hub/pytorch_vision_resnet/
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path, label = self.data[index]

        image = Image.open(image_path)

        try:
            image = self.transform(image)  # some images in the dataset cannot be processed - we'll skip them
        except Exception:
            return None

        dict_data = {
            'image': image,
            'label': label,
            'image_path': image_path,
        }
        return dict_data


# Skips empty samples in a batch
def collate_skip_empty(batch):
    batch = [sample for sample in batch if sample is not None]  # drop samples that failed to load
    return torch.utils.data.dataloader.default_collate(batch)
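
For reference, a minimal sketch of how this dataset and collate function are meant to plug into a PyTorch data loader. The 'data/raw-img' path is an assumption based on the script default used later in tsne.py:

    import torch
    from animals_dataset import AnimalsDataset, collate_skip_empty

    dataset = AnimalsDataset('data/raw-img', num_images=100)
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=collate_skip_empty)
    batch = next(iter(loader))
    print(batch['image'].shape)  # torch.Size([16, 3, 224, 224]) if no samples in the batch were skipped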
6 changes: 6 additions & 0 deletions TSNE/requirements.txt
@@ -0,0 +1,6 @@
torch==1.4
torchvision==0.5.0
scikit-learn==0.22.2.post1
opencv-python>=3.4.1.15
matplotlib==3.2.1
tqdm==4.44.1
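
The pinned versions match the PyTorch 1.4 era this code was written against; the environment can be set up with pip install -r TSNE/requirements.txt.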
37 changes: 37 additions & 0 deletions TSNE/resnet.py
@@ -0,0 +1,37 @@
import torch
from torchvision import models
from torch.hub import load_state_dict_from_url


# Define the architecture by modifying ResNet.
# The original code is here:
# https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
class ResNet101(models.ResNet):
    def __init__(self, num_classes=1000, pretrained=True, **kwargs):
        # Start with the standard resnet101 defined here:
        # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
        super().__init__(block=models.resnet.Bottleneck, layers=[3, 4, 23, 3], num_classes=num_classes, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(models.resnet.model_urls['resnet101'], progress=True)
            self.load_state_dict(state_dict)

    # Reimplement the forward pass, replacing the following code:
    # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py#L197-L213
    def _forward_impl(self, x):
        # Standard ResNet forward pass through the convolutional backbone
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Notice there is no forward pass through the original classifier:
        # we stop at the pooled features and use them as the image embedding.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x
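
As a quick sanity check of this feature extractor, a minimal sketch; the input here is random noise, used only to demonstrate the output shape (for resnet101, the pooled feature vector is 2048-dimensional):

    import torch
    from resnet import ResNet101

    model = ResNet101(pretrained=True)
    model.eval()
    with torch.no_grad():
        features = model(torch.randn(1, 3, 224, 224))
    print(features.shape)  # torch.Size([1, 2048])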
212 changes: 212 additions & 0 deletions TSNE/tsne.py
@@ -0,0 +1,212 @@
import argparse
from tqdm import tqdm
import cv2
import torch
import random
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from animals_dataset import AnimalsDataset, collate_skip_empty, colors_per_class
from resnet import ResNet101


def fix_random_seeds():
    seed = 10
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)


def get_features(dataset, batch_size, num_images):
    # move the input and model to GPU for speed, if available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # initialize our implementation of ResNet
    model = ResNet101(pretrained=True)
    model.eval()
    model.to(device)

    # read the dataset and initialize the data loader
    dataset = AnimalsDataset(dataset, num_images)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_skip_empty, shuffle=True)

    # we'll store the features as a NumPy array of size num_images x feature_size
    features = None

    # we'll also store the image labels and paths to visualize them later
    labels = []
    image_paths = []

    for batch in tqdm(dataloader, desc='Running the model inference'):
        images = batch['image'].to(device)
        labels += batch['label']
        image_paths += batch['image_path']

        with torch.no_grad():
            output = model(images)

        current_features = output.cpu().numpy()
        if features is not None:
            features = np.concatenate((features, current_features))
        else:
            features = current_features

    return features, labels, image_paths


# scale and move the coordinates so they fit the [0; 1] range
def scale_to_01_range(x):
    # compute the distribution range
    value_range = np.max(x) - np.min(x)

    # move the distribution so that it starts from zero
    # by subtracting the minimal value from all its values
    starts_from_zero = x - np.min(x)

    # make the distribution fit [0; 1] by dividing by its range
    return starts_from_zero / value_range


def scale_image(image, max_image_size):
    image_height, image_width, _ = image.shape

    scale = max(1, image_width / max_image_size, image_height / max_image_size)
    image_width = int(image_width / scale)
    image_height = int(image_height / scale)

    image = cv2.resize(image, (image_width, image_height))
    return image


def draw_rectangle_by_class(image, label):
    image_height, image_width, _ = image.shape

    # get the color corresponding to the image class
    color = colors_per_class[label]
    image = cv2.rectangle(image, (0, 0), (image_width - 1, image_height - 1), color=color, thickness=5)

    return image


def compute_plot_coordinates(image, x, y, image_centers_area_size, offset):
    image_height, image_width, _ = image.shape

    # compute the image center coordinates on the plot
    center_x = int(image_centers_area_size * x) + offset

    # in matplotlib, the y axis points upward;
    # to get the same orientation here, we mirror the y coordinate
    center_y = int(image_centers_area_size * (1 - y)) + offset

    # knowing the image center, compute the coordinates of the top-left and bottom-right corners
    tl_x = center_x - int(image_width / 2)
    tl_y = center_y - int(image_height / 2)

    br_x = tl_x + image_width
    br_y = tl_y + image_height

    return tl_x, tl_y, br_x, br_y


def visualize_tsne_images(tx, ty, images, labels, plot_size=1000, max_image_size=100):
    # we'll put the image centers in the central area of the plot
    # and use offsets to make sure the images fit the plot
    offset = max_image_size // 2
    image_centers_area_size = plot_size - 2 * offset

    tsne_plot = 255 * np.ones((plot_size, plot_size, 3), np.uint8)

    # now we'll put a small copy of every image at its corresponding t-SNE coordinate
    for image_path, label, x, y in tqdm(
            zip(images, labels, tx, ty),
            desc='Building the t-SNE plot',
            total=len(images)
    ):
        image = cv2.imread(image_path)

        # scale the image down so it fits on the plot
        image = scale_image(image, max_image_size)

        # draw a rectangle with a color corresponding to the image class
        image = draw_rectangle_by_class(image, label)

        # compute the coordinates of the image on the scaled plot visualization
        tl_x, tl_y, br_x, br_y = compute_plot_coordinates(image, x, y, image_centers_area_size, offset)

        # put the image at its t-SNE coordinates using NumPy subarray indexing
        tsne_plot[tl_y:br_y, tl_x:br_x, :] = image

    cv2.imshow('t-SNE', tsne_plot)
    cv2.waitKey()


def visualize_tsne_points(tx, ty, labels):
    # initialize a matplotlib plot
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # for every class, we'll add a separate scatter plot
    for label in colors_per_class:
        # find the samples of the current class in the data
        indices = [i for i, l in enumerate(labels) if l == label]

        # extract the coordinates of the points of this class only
        current_tx = np.take(tx, indices)
        current_ty = np.take(ty, indices)

        # convert the class color to matplotlib format:
        # BGR -> RGB, divide by 255, convert to a numpy array
        color = np.array([colors_per_class[label][::-1]], dtype=float) / 255

        # add a scatter plot with the corresponding color and label
        ax.scatter(current_tx, current_ty, c=color, label=label)

    # build a legend using the labels we set previously
    ax.legend(loc='best')

    # finally, show the plot
    plt.show()


def visualize_tsne(tsne, images, labels, plot_size=1000, max_image_size=100):
    # extract the x and y coordinates representing the positions of the images on the t-SNE plot
    tx = tsne[:, 0]
    ty = tsne[:, 1]

    # scale and move the coordinates so they fit the [0; 1] range
    tx = scale_to_01_range(tx)
    ty = scale_to_01_range(ty)

    # visualize the plot: samples as colored points
    visualize_tsne_points(tx, ty, labels)

    # visualize the plot: samples as images
    visualize_tsne_images(tx, ty, images, labels, plot_size=plot_size, max_image_size=max_image_size)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--path', type=str, default='data/raw-img')
    parser.add_argument('--batch', type=int, default=64)
    parser.add_argument('--num_images', type=int, default=500)
    args = parser.parse_args()

    fix_random_seeds()

    features, labels, image_paths = get_features(
        dataset=args.path,
        batch_size=args.batch,
        num_images=args.num_images
    )

    tsne = TSNE(n_components=2).fit_transform(features)

    visualize_tsne(tsne, image_paths, labels)


if __name__ == '__main__':
    main()
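
Assuming the Animals10 archive is unpacked so that the class folders live under data/raw-img, both visualizations can be produced with, for example:

    python tsne.py --path data/raw-img --batch 64 --num_images 500

(the values shown are the script defaults). Note that TSNE(n_components=2) otherwise uses scikit-learn's defaults, such as perplexity=30; adjusting the perplexity is the usual knob for changing how tightly the clusters group together.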