From 983af98e355f228c5e590ef544594790de609e4f Mon Sep 17 00:00:00 2001 From: Anna Foix Date: Sat, 31 Aug 2024 14:08:52 +0100 Subject: [PATCH 1/7] first test with simclr --- notebooks/simclr_example.ipynb | 48 ++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 notebooks/simclr_example.ipynb diff --git a/notebooks/simclr_example.ipynb b/notebooks/simclr_example.ipynb new file mode 100644 index 0000000..154d4b5 --- /dev/null +++ b/notebooks/simclr_example.ipynb @@ -0,0 +1,48 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "6b572c20-a9ea-4f68-a14c-360f4ae96be6", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import shutil, time, os, requests, random, copy\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torchvision import datasets, transforms, models\n", + "\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "from sklearn.manifold import TSNE" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:embed_time]", + "language": "python", + "name": "conda-env-embed_time-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From da9b1a9dbf232f67b874bf76498603802f6fd645 Mon Sep 17 00:00:00 2001 From: Anna Foix Date: Sat, 31 Aug 2024 14:38:29 +0100 Subject: [PATCH 2/7] first serve of the dataloader --- scripts/navigate_worms.py | 28 ++++++++++++++++++++++++++++ src/datasets | 1 + 2 files changed, 29 insertions(+) create mode 100644 scripts/navigate_worms.py create mode 160000 src/datasets diff --git a/scripts/navigate_worms.py b/scripts/navigate_worms.py new file mode 100644 index 0000000..42d6d7f --- /dev/null +++ b/scripts/navigate_worms.py @@ -0,0 +1,28 @@ +import torch +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor +from torchvision.datasets import ImageFolder +import matplotlib.pyplot as plt + +# Transforms +data_transform_train = transforms.Compose([ + transforms.RandomRotation(30), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) + + +# Bring the dataset +dataset = torchvision.datasets.ImageFolder(root='/nfs/research/uhlmann/afoix/datasets/image_datasets/BBBC010_v1_foreground_eachworm/', transform=data_transform_train) + +# Split datatset +train, val, test = torch.utils.data.random_split(dataset, [0.6, 0.2, 0.2]) + +# Create data datatloader +trainLoader = torch.utils.data.DataLoader(train, batch_size=batch_size, + num_workers=num_workers, drop_last=True, shuffle=True) +valLoader = torch.utils.data.DataLoader(val, batch_size=batch_size, + num_workers=num_workers, drop_last=True) +testLoader = torch.utils.data.DataLoader(test, batch_size=batch_size, + num_workers=num_workers, drop_last=True) + diff --git a/src/datasets b/src/datasets new file mode 160000 index 0000000..bce9aa9 --- /dev/null +++ b/src/datasets @@ -0,0 +1 @@ +Subproject commit bce9aa9b5db5495ff431b1114a9e2dda30d27af0 From f83bd7f01de85e21ac88e10e1c5c868db19a7142 Mon Sep 17 00:00:00 2001 From: Anna Foix Date: Sat, 31 Aug 2024 14:58:11 +0100 Subject: [PATCH 3/7] working state dataloader --- scripts/navigate_worms.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/scripts/navigate_worms.py b/scripts/navigate_worms.py index 42d6d7f..5b77398 100644 --- a/scripts/navigate_worms.py +++ b/scripts/navigate_worms.py @@ -2,23 +2,26 @@ from torch.utils.data import Dataset from torchvision.transforms import ToTensor from torchvision.datasets import ImageFolder +from torchvision.transforms import v2 import matplotlib.pyplot as plt # Transforms -data_transform_train = transforms.Compose([ - transforms.RandomRotation(30), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) +data_transform_train = v2.Compose([ + v2.RandomRotation(30), + v2.RandomHorizontalFlip(), + v2.ToTensor(), + v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) # Bring the dataset -dataset = torchvision.datasets.ImageFolder(root='/nfs/research/uhlmann/afoix/datasets/image_datasets/BBBC010_v1_foreground_eachworm/', transform=data_transform_train) +dataset = ImageFolder(root='/nfs/research/uhlmann/afoix/datasets/image_datasets/bbbc010/BBBC010_v1_foreground_eachworm/', transform=data_transform_train) # Split datatset train, val, test = torch.utils.data.random_split(dataset, [0.6, 0.2, 0.2]) # Create data datatloader +batch_size = 8 +num_workers = 4 trainLoader = torch.utils.data.DataLoader(train, batch_size=batch_size, num_workers=num_workers, drop_last=True, shuffle=True) valLoader = torch.utils.data.DataLoader(val, batch_size=batch_size, @@ -26,3 +29,5 @@ testLoader = torch.utils.data.DataLoader(test, batch_size=batch_size, num_workers=num_workers, drop_last=True) + +print(trainLoader) From 83076050721a68564901157385f93c66dc549882 Mon Sep 17 00:00:00 2001 From: Anna Foix Date: Sat, 31 Aug 2024 21:21:47 +0100 Subject: [PATCH 4/7] VAE_resnet18 model to test --- src/model_VAE_resnet18.py | 157 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 src/model_VAE_resnet18.py diff --git a/src/model_VAE_resnet18.py b/src/model_VAE_resnet18.py new file mode 100644 index 0000000..c81a1fb --- /dev/null +++ b/src/model_VAE_resnet18.py @@ -0,0 +1,157 @@ +import torch +from torch import nn, optim +import torch.nn.functional as F + +class ResizeConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'): + super.__init__() + self.scale_factor = scale_factor + self.mode = mode + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1) + + def forward(self, x): + F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) + x = self.conv(x) + return x + +class BasicBlockEnc(nn.Module): + def __init__(self, in_planes, stride=1): + planes = in_planes * stride + self.conv1 = nn.Conv2d(in_planes, planes, kernel=3, strides=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel=3, strides=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + if strides == 1: + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes) + ) + + def forward(self, x): + out = torch.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = torch.relu(out) + return out + +class BasicBlockDec(nn.Module): + def __init__(self, in_planes, stride=1): + super().__init__() + planes = int(in_planes/stride) + + self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(in_planes) + # self.bn1 could have been placed here, + # but that messes up the order of the layers when printing the class + + if stride == 1: + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.shortcut = nn.Sequential() + else: + self.conv1 = ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride) + self.bn1 = nn.BatchNorm2d(planes) + self.shortcut = nn.Sequential( + ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride), + nn.BatchNorm2d(planes) + ) + + def foward(self, x): + out = torch.relu(self.bn2(self.conv2(x))) + out = self.bn1(self.conv1(out)) + out += self.shortcut(x) + out = torch.relu(out) + return out + + +class Resnet18Enc(nn.Module): + + def __init__(self, num_Block=[2, 2, 2, 2], z_dim=10, nc=3): + super().__init__() + self.in_planes = 64 + self.z_dim = z_dim + self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._makelayer(BasicBlockEnc, 64, num_Block[0], stride=1) + self.layer2 = self._makelayer(BasicBlockEnc, 128, num_Block[1], stride=2) + self.layer3 = self._makelayer(BasicBlockEnc, 256, num_Block[2], stride=2) + self.layer4 = self._makelayer(BasicBlockEnc, 512, num_Block[3], stride=2) + self.linear = nn.Linear(512, 2 * z_dim) + + def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride): + strides = [stride] + [1]*(num_Blocks-1) + layers = [] + for stride in strides: + layers += [BasicBlockEnc(self.in_planes, stride)] + self.in_planes = planes + return nn.Sequential(*layers) + + def forward(self, x): + x = torch.relu(self.bn1(self.conv1(x))) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = F.adaptive_avg_pool2d(x, 1) + x = x.view(x.size(0), -1) + x = self.linear(x) + mu = x[:, :self.z_dim] + logvar = x[:, self.z_dim:] + return mu, logvar + +class Resnet18Dec(nn.Module): + + def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3): + super().__init__() + self.in_planes = 512 + + self.linear = nn.Linear(z_dim, 512) + + self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2) + self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2) + self.layer2 = self._make_layer(BasicBlockDec, 64, num_Blocks[1], stride=2) + self.layer1 = self._make_layer(BasicBlockDec, 64, num_Blocks[0], stride=1) + self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2) + + def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride): + strides = [stride] + [1]*(num_Blocks-1) + layers = [] + for stride in reversed(strides): + layers += [BasicBlockDec(self.in_planes, stride)] + self.in_planes = planes + return nn.Sequential(*layers) + + def forward(self, z): + x = self.linear(z) + x = x.view(z.size(0), 512, 1, 1) + x = F.interpolate(x, scale_factor=4) + x = self.layer4(x) + x = self.layer3(x) + x = self.layer2(x) + x = self.layer1(x) + x = torch.sigmoid(self.conv1(x)) + x = x.view(x.size(0), 3, 64, 64) + return x + + +class VAE(nn.Module) + + def __init__(self, z_dim): + super().__init__() + self.encoder = Resnet18Enc(z_dim=z_dim) + self.decoder = Resnet18Dec(z_dim=z_dim) + + def foward(self, x): + mean, logvar = self.encoder(x) + z = self.reparameterize(mean, logvar) + x = self.decoder(z) + return x, z + + @staticmethod + def reparameterize(mean, logvar): + std = torch.exp(logvar / 2) # in log-space, squareroot is divide by two + epsilon = torch.rand_like(std) + return epsilon * std + mean From beb28297e379809a756fccaf70c2a4d5c41d8eec Mon Sep 17 00:00:00 2001 From: Anna Foix Date: Sun, 1 Sep 2024 01:48:35 +0100 Subject: [PATCH 5/7] fix VAEResnet --- src/embed_time/model_VAE_resnet18.py | 161 +++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 src/embed_time/model_VAE_resnet18.py diff --git a/src/embed_time/model_VAE_resnet18.py b/src/embed_time/model_VAE_resnet18.py new file mode 100644 index 0000000..a5a9045 --- /dev/null +++ b/src/embed_time/model_VAE_resnet18.py @@ -0,0 +1,161 @@ +import torch +from torch import nn, optim +import torch.nn.functional as F + +class ResizeConv2d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='nearest'): + super().__init__() + self.scale_factor = scale_factor + self.mode = mode + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1) + + def forward(self, x): + F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) + x = self.conv(x) + return x + +class BasicBlockEnc(nn.Module): + + def __init__(self, in_planes, stride=1): + super().__init__() + + planes = in_planes*stride + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + if stride == 1: + self.shortcut = nn.Sequential() + else: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes) + ) + + def forward(self, x): + out = torch.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + out = torch.relu(out) + return out + +class BasicBlockDec(nn.Module): + def __init__(self, in_planes, stride=1): + super().__init__() + planes = int(in_planes/stride) + + self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(in_planes) + # self.bn1 could have been placed here, + # but that messes up the order of the layers when printing the class + + if stride == 1: + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.shortcut = nn.Sequential() + else: + self.conv1 = ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride) + self.bn1 = nn.BatchNorm2d(planes) + self.shortcut = nn.Sequential( + ResizeConv2d(in_planes, planes, kernel_size=3, scale_factor=stride), + nn.BatchNorm2d(planes) + ) + + def foward(self, x): + out = torch.relu(self.bn2(self.conv2(x))) + out = self.bn1(self.conv1(out)) + out += self.shortcut(x) + out = torch.relu(out) + return out + + +class ResNet18Enc(nn.Module): + + def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3): + super().__init__() + self.in_planes = 64 + self.z_dim = z_dim + self.conv1 = nn.Conv2d(nc, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(BasicBlockEnc, 64, num_Blocks[0], stride=1) + self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2) + self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2) + self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2) + self.linear = nn.Linear(512, 2 * z_dim) + + def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride): + strides = [stride] + [1]*(num_Blocks-1) + layers = [] + for stride in strides: + layers += [BasicBlockEnc(self.in_planes, stride)] + self.in_planes = planes + return nn.Sequential(*layers) + + def forward(self, x): + x = torch.relu(self.bn1(self.conv1(x))) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = F.adaptive_avg_pool2d(x, 1) + x = x.view(x.size(0), -1) + x = self.linear(x) + mu = x[:, :self.z_dim] + logvar = x[:, self.z_dim:] + return mu, logvar + +class ResNet18Dec(nn.Module): + + def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3): + super().__init__() + self.in_planes = 512 + + self.linear = nn.Linear(z_dim, 512) + + self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2) + self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2) + self.layer2 = self._make_layer(BasicBlockDec, 64, num_Blocks[1], stride=2) + self.layer1 = self._make_layer(BasicBlockDec, 64, num_Blocks[0], stride=1) + self.conv1 = ResizeConv2d(64, nc, kernel_size=3, scale_factor=2) + + def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride): + strides = [stride] + [1]*(num_Blocks-1) + layers = [] + for stride in reversed(strides): + layers += [BasicBlockDec(self.in_planes, stride)] + self.in_planes = planes + return nn.Sequential(*layers) + + def forward(self, z): + x = self.linear(z) + x = x.view(z.size(0), 512, 1, 1) + x = F.interpolate(x, scale_factor=4) + x = self.layer4(x) + x = self.layer3(x) + x = self.layer2(x) + x = self.layer1(x) + x = torch.sigmoid(self.conv1(x)) + x = x.view(x.size(0), 3, 64, 64) + return x + + +class VAEResNet18(nn.Module): + + def __init__(self, z_dim): + super().__init__() + self.encoder = ResNet18Enc(z_dim=z_dim) + self.decoder = ResNet18Dec(z_dim=z_dim) + + def foward(self, x): + mean, logvar = self.encoder(x) + z = self.reparameterize(mean, logvar) + x = self.decoder(z) + return x, z + + @staticmethod + def reparameterize(mean, logvar): + std = torch.exp(logvar / 2) # in log-space, squareroot is divide by two + epsilon = torch.rand_like(std) + return epsilon * std + mean From daa344cb6783cd20ccbc328ffbdfeb8228bffd8d Mon Sep 17 00:00:00 2001 From: Ben Salmon Date: Sun, 1 Sep 2024 02:00:45 +0000 Subject: [PATCH 6/7] no bugs, yet to train --- src/embed_time/model_VAE_resnet18.py | 31 ++++++++++++---------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/src/embed_time/model_VAE_resnet18.py b/src/embed_time/model_VAE_resnet18.py index a5a9045..e3244e0 100644 --- a/src/embed_time/model_VAE_resnet18.py +++ b/src/embed_time/model_VAE_resnet18.py @@ -7,10 +7,10 @@ def __init__(self, in_channels, out_channels, kernel_size, scale_factor, mode='n super().__init__() self.scale_factor = scale_factor self.mode = mode - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size//2) def forward(self, x): - F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) + x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode) x = self.conv(x) return x @@ -27,7 +27,7 @@ def __init__(self, in_planes, stride=1): self.bn2 = nn.BatchNorm2d(planes) if stride == 1: - self.shortcut = nn.Sequential() + self.shortcut = nn.Identity() else: self.shortcut = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), @@ -63,7 +63,7 @@ def __init__(self, in_planes, stride=1): nn.BatchNorm2d(planes) ) - def foward(self, x): + def forward(self, x): out = torch.relu(self.bn2(self.conv2(x))) out = self.bn1(self.conv1(out)) out += self.shortcut(x) @@ -83,7 +83,7 @@ def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3): self.layer2 = self._make_layer(BasicBlockEnc, 128, num_Blocks[1], stride=2) self.layer3 = self._make_layer(BasicBlockEnc, 256, num_Blocks[2], stride=2) self.layer4 = self._make_layer(BasicBlockEnc, 512, num_Blocks[3], stride=2) - self.linear = nn.Linear(512, 2 * z_dim) + self.linear = nn.Conv2d(512, 2 * z_dim, kernel_size=1) def _make_layer(self, BasicBlockEnc, planes, num_Blocks, stride): strides = [stride] + [1]*(num_Blocks-1) @@ -99,11 +99,8 @@ def forward(self, x): x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) - x = F.adaptive_avg_pool2d(x, 1) - x = x.view(x.size(0), -1) x = self.linear(x) - mu = x[:, :self.z_dim] - logvar = x[:, self.z_dim:] + mu, logvar = torch.chunk(x, 2, dim=1) return mu, logvar class ResNet18Dec(nn.Module): @@ -111,8 +108,9 @@ class ResNet18Dec(nn.Module): def __init__(self, num_Blocks=[2,2,2,2], z_dim=10, nc=3): super().__init__() self.in_planes = 512 + self.nc = nc - self.linear = nn.Linear(z_dim, 512) + self.linear = nn.Conv2d(z_dim, 512, kernel_size=1) self.layer4 = self._make_layer(BasicBlockDec, 256, num_Blocks[3], stride=2) self.layer3 = self._make_layer(BasicBlockDec, 128, num_Blocks[2], stride=2) @@ -130,25 +128,22 @@ def _make_layer(self, BasicBlockDec, planes, num_Blocks, stride): def forward(self, z): x = self.linear(z) - x = x.view(z.size(0), 512, 1, 1) - x = F.interpolate(x, scale_factor=4) x = self.layer4(x) x = self.layer3(x) x = self.layer2(x) x = self.layer1(x) x = torch.sigmoid(self.conv1(x)) - x = x.view(x.size(0), 3, 64, 64) return x class VAEResNet18(nn.Module): - def __init__(self, z_dim): + def __init__(self, nc, z_dim): super().__init__() - self.encoder = ResNet18Enc(z_dim=z_dim) - self.decoder = ResNet18Dec(z_dim=z_dim) + self.encoder = ResNet18Enc(nc=nc, z_dim=z_dim) + self.decoder = ResNet18Dec(nc=nc, z_dim=z_dim) - def foward(self, x): + def forward(self, x): mean, logvar = self.encoder(x) z = self.reparameterize(mean, logvar) x = self.decoder(z) @@ -157,5 +152,5 @@ def foward(self, x): @staticmethod def reparameterize(mean, logvar): std = torch.exp(logvar / 2) # in log-space, squareroot is divide by two - epsilon = torch.rand_like(std) + epsilon = torch.randn_like(std) return epsilon * std + mean From 7674ab61a144e54803504aa3c487635596073063 Mon Sep 17 00:00:00 2001 From: Ben Salmon Date: Sun, 1 Sep 2024 13:46:08 +0000 Subject: [PATCH 7/7] remove in place operations --- src/embed_time/model_VAE_resnet18.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/embed_time/model_VAE_resnet18.py b/src/embed_time/model_VAE_resnet18.py index e3244e0..12ee85f 100644 --- a/src/embed_time/model_VAE_resnet18.py +++ b/src/embed_time/model_VAE_resnet18.py @@ -37,7 +37,7 @@ def __init__(self, in_planes, stride=1): def forward(self, x): out = torch.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) - out += self.shortcut(x) + out = out + self.shortcut(x) out = torch.relu(out) return out @@ -66,7 +66,7 @@ def __init__(self, in_planes, stride=1): def forward(self, x): out = torch.relu(self.bn2(self.conv2(x))) out = self.bn1(self.conv1(out)) - out += self.shortcut(x) + out = out + self.shortcut(x) out = torch.relu(out) return out