Merge pull request #1 from Cryaaa/anna_contrastive
Check the ResNet18 model
afoix authored Sep 1, 2024
2 parents a9680c5 + 7674ab6 commit bf6442b
Showing 5 changed files with 395 additions and 0 deletions.
48 changes: 48 additions & 0 deletions notebooks/simclr_example.ipynb
@@ -0,0 +1,48 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "6b572c20-a9ea-4f68-a14c-360f4ae96be6",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import shutil, time, os, requests, random, copy\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torchvision import datasets, transforms, models\n",
"\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"from sklearn.manifold import TSNE"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:embed_time]",
"language": "python",
"name": "conda-env-embed_time-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
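
Note: the notebook so far only collects imports. As a hedged sketch of the intended next cell (hypothetical, not part of the commit), the imported TSNE would typically be applied to learned embeddings along these lines:

embeddings = np.random.rand(100, 128)  # placeholder for model embeddings (hypothetical)
coords = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(embeddings)
plt.scatter(coords[:, 0], coords[:, 1], s=5)
plt.title("t-SNE of embeddings")
plt.show()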
33 changes: 33 additions & 0 deletions scripts/navigate_worms.py
@@ -0,0 +1,33 @@
import torch
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2

# Training transforms: resize so images can be batched, augment, convert to
# float tensors, and normalise with ImageNet statistics
data_transform_train = v2.Compose([
    v2.Resize((224, 224)),  # assumption: a fixed size is needed so batches collate; 224 is a guess
    v2.RandomRotation(30),
    v2.RandomHorizontalFlip(),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),  # v2 replacement for the deprecated v2.ToTensor()
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])


# Load the dataset
dataset = ImageFolder(root='/nfs/research/uhlmann/afoix/datasets/image_datasets/bbbc010/BBBC010_v1_foreground_eachworm/', transform=data_transform_train)

# Split the dataset into train/val/test
train, val, test = torch.utils.data.random_split(dataset, [0.6, 0.2, 0.2])

# Create the dataloaders
batch_size = 8
num_workers = 4
trainLoader = torch.utils.data.DataLoader(train, batch_size=batch_size,
num_workers=num_workers, drop_last=True, shuffle=True)
valLoader = torch.utils.data.DataLoader(val, batch_size=batch_size,
num_workers=num_workers, drop_last=True)
testLoader = torch.utils.data.DataLoader(test, batch_size=batch_size,
num_workers=num_workers, drop_last=True)


print(trainLoader)
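
# Hedged sanity check (a sketch, not part of the commit): pull one batch and
# inspect its shape; with the Resize above, batches should collate cleanly.
for images, labels in trainLoader:
    print(images.shape, labels.shape)  # expected: torch.Size([8, 3, 224, 224]) and torch.Size([8])
    break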
1 change: 1 addition & 0 deletions src/datasets
Submodule datasets added at bce9aa
156 changes: 156 additions & 0 deletions src/embed_time/model_VAE_resnet18.py
@@ -0,0 +1,156 @@
import torch
from torch import nn, optim
import torch.nn.functional as F

class ResizeConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'):
super().__init__()
self.scale_factor = scale_factor
self.mode = mode
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size//2)

def forward(self, x):
x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
x = self.conv(x)
return x

class BasicBlockEnc(nn.Module):

def __init__(self, in_planes, stride=1):
super().__init__()

planes = in_planes*stride

self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)

if stride == 1:
self.shortcut = nn.Identity()
else:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes)
)

def forward(self, x):
out = torch.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out = out + self.shortcut(x)
out = torch.relu(out)
return out

class BasicBlockDec(nn.Module):
def __init__(self, in_planes, stride=1):
super().__init__()
planes = int(in_planes/stride)

self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(in_planes)
# self.bn1 could have been placed here,
# but that messes up the order of the layers when printing the class

if stride == 1:
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
else:
self.conv1 = ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride)
self.bn1 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential(
ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride),
nn.BatchNorm2d(planes)
)

def forward(self, x):
out = torch.relu(self.bn2(self.conv2(x)))
out = self.bn1(self.conv1(out))
out = out + self.shortcut(x)
out = torch.relu(out)
return out


class ResNet18Enc(nn.Module):

def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
super().__init__()
self.in_planes = 64
self.z_dim = z_dim
self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(BasicBlockEnc, 64, num_Blocks[0], stride=1)
self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2)
self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2)
self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2)
self.linear = nn.Conv2d(512, 2 * z_dim, kernel_size=1)

def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride):
strides = [stride] + [1]*(num_Blocks-1)
layers = []
for stride in strides:
layers += [BasicBlockEnc(self.in_planes, stride)]
self.in_planes = planes
return nn.Sequential(*layers)

def forward(self, x):
x = torch.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.linear(x)
mu, logvar = torch.chunk(x, 2, dim=1)
return mu, logvar

class ResNet18Dec(nn.Module):

def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
super().__init__()
self.in_planes = 512
self.nc = nc

self.linear = nn.Conv2d(z_dim, 512, kernel_size=1)

self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2)
self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2)
self.layer2 = self._make_layer(BasicBlockDec, 64, num_Blocks[1], stride=2)
self.layer1 = self._make_layer(BasicBlockDec, 64, num_Blocks[0], stride=1)
self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)

def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride):
strides = [stride] + [1]*(num_Blocks-1)
layers = []
for stride in reversed(strides):
layers += [BasicBlockDec(self.in_planes, stride)]
self.in_planes = planes
return nn.Sequential(*layers)

def forward(self, z):
x = self.linear(z)
x = self.layer4(x)
x = self.layer3(x)
x = self.layer2(x)
x = self.layer1(x)
x = torch.sigmoid(self.conv1(x))
return x


class VAEResNet18(nn.Module):

def __init__(self, nc, z_dim):
super().__init__()
self.encoder = ResNet18Enc(nc=nc, z_dim=z_dim)
self.decoder = ResNet18Dec(nc=nc, z_dim=z_dim)

def forward(self, x):
mean, logvar = self.encoder(x)
z = self.reparameterize(mean, logvar)
x = self.decoder(z)
return x, z

@staticmethod
def reparameterize(mean, logvar):
        std = torch.exp(logvar / 2)  # halving in log-space takes the square root, so this is the std
epsilon = torch.randn_like(std)
return epsilon * std + mean
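
Because the "linear" heads in this file are 1x1 convolutions, the latent z keeps a spatial grid rather than being a flat vector. A minimal smoke test (a sketch, not part of the commit, assuming 64x64 RGB inputs):

model = VAEResNet18(nc=3, z_dim=10)
x = torch.randn(4, 3, 64, 64)  # hypothetical batch
recon, z = model(x)
print(recon.shape)  # torch.Size([4, 3, 64, 64]) -- reconstruction at input resolution
print(z.shape)      # torch.Size([4, 10, 4, 4]) -- per-location latent samples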
157 changes: 157 additions & 0 deletions src/model_VAE_resnet18.py
@@ -0,0 +1,157 @@
import torch
from torch import nn, optim
import torch.nn.functional as F

class ResizeConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'):
        super().__init__()  # super needs parentheses to bind the instance
        self.scale_factor = scale_factor
        self.mode = mode
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size//2)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)  # keep the upsampled result
        x = self.conv(x)
        return x

class BasicBlockEnc(nn.Module):
    def __init__(self, in_planes, stride=1):
        super().__init__()  # was missing, so nn.Module was never initialised
        planes = in_planes * stride
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

def forward(self, x):
out = torch.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = torch.relu(out)
return out

class BasicBlockDec(nn.Module):
def __init__(self, in_planes, stride=1):
super().__init__()
planes = int(in_planes/stride)

self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(in_planes)
# self.bn1 could have been placed here,
# but that messes up the order of the layers when printing the class

if stride == 1:
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
else:
self.conv1 = ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride)
self.bn1 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential(
ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride),
nn.BatchNorm2d(planes)
)

    def forward(self, x):
out = torch.relu(self.bn2(self.conv2(x)))
out = self.bn1(self.conv1(out))
out += self.shortcut(x)
out = torch.relu(out)
return out


class Resnet18Enc(nn.Module):

    def __init__(self, num_Blocks=[2, 2, 2, 2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 64
        self.z_dim = z_dim
        self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(BasicBlockEnc, 64, num_Blocks[0], stride=1)
        self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2)
        self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2)
        self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2)
        self.linear = nn.Linear(512, 2 * z_dim)

def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride):
strides = [stride] + [1]*(num_Blocks-1)
layers = []
for stride in strides:
layers += [BasicBlockEnc(self.in_planes, stride)]
self.in_planes = planes
return nn.Sequential(*layers)

def forward(self, x):
x = torch.relu(self.bn1(self.conv1(x)))
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = F.adaptive_avg_pool2d(x, 1)
x = x.view(x.size(0), -1)
x = self.linear(x)
mu = x[:, :self.z_dim]
logvar = x[:, self.z_dim:]
return mu, logvar

class Resnet18Dec(nn.Module):

    def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3):
        super().__init__()
        self.in_planes = 512
        self.nc = nc

        self.linear = nn.Linear(z_dim, 512)

self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2)
self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2)
self.layer2 = self._make_layer(BasicBlockDec, 64, num_Blocks[1], stride=2)
self.layer1 = self._make_layer(BasicBlockDec, 64, num_Blocks[0], stride=1)
self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2)

def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride):
strides = [stride] + [1]*(num_Blocks-1)
layers = []
for stride in reversed(strides):
layers += [BasicBlockDec(self.in_planes, stride)]
self.in_planes = planes
return nn.Sequential(*layers)

def forward(self, z):
x = self.linear(z)
x = x.view(z.size(0), 512, 1, 1)
x = F.interpolate(x, scale_factor=4)
x = self.layer4(x)
x = self.layer3(x)
x = self.layer2(x)
x = self.layer1(x)
x = torch.sigmoid(self.conv1(x))
        x = x.view(x.size(0), self.nc, 64, 64)  # use nc rather than hard-coding 3 channels
return x


class VAE(nn.Module):

def __init__(self, z_dim):
super().__init__()
self.encoder = Resnet18Enc(z_dim=z_dim)
self.decoder = Resnet18Dec(z_dim=z_dim)

    def forward(self, x):
mean, logvar = self.encoder(x)
z = self.reparameterize(mean, logvar)
x = self.decoder(z)
return x, z

@staticmethod
def reparameterize(mean, logvar):
        std = torch.exp(logvar / 2)  # halving in log-space takes the square root, so this is the std
        epsilon = torch.randn_like(std)  # standard-normal noise (rand_like would sample uniform noise)
return epsilon * std + mean
