Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
thgpddl committed Sep 22, 2022
0 parents commit 1dabb15
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 0 deletions.
27 changes: 27 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader


def get_loader(batch_size=64, num_workers=0):
    """Build the MNIST train/test DataLoaders with standard normalization.

    :param batch_size: samples per batch for both loaders
    :param num_workers: worker processes used by each DataLoader
    :return: (train_loader, test_loader) tuple
    """
    # Shared preprocessing: convert to tensor, then normalize with the
    # canonical MNIST statistics (mean=0.1307, std=0.3081).
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_set = MNIST(root='.', train=True, download=True, transform=mnist_transform)
    test_set = MNIST(root='.', train=False, transform=mnist_transform)

    # Shuffle only the training data; evaluation order is irrelevant.
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size,
                              shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(dataset=test_set, batch_size=batch_size,
                             shuffle=False, num_workers=num_workers)
    return train_loader, test_loader
46 changes: 46 additions & 0 deletions loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import torch


class Loop:
    """Bundles the train/test epoch logic for one model/optimizer pair."""

    def __init__(self, model, train_loader, test_loader, loss_fn, optimizer, device):
        """
        :param model: network to train and evaluate
        :param train_loader: DataLoader yielding (data, target) training batches
        :param test_loader: DataLoader yielding (data, target) test batches
        :param loss_fn: loss function; must accept a `reduction` keyword
            (e.g. ``F.nll_loss``) so `test` can sum per-sample losses
        :param optimizer: optimizer stepping ``model.parameters()``
        :param device: torch.device the batches are moved to
        """
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = device

    def train(self, epoch):
        """Run one training epoch, logging progress every 500 batches."""
        self.model.train()
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss_fn(output, target)
            loss.backward()
            self.optimizer.step()
            if batch_idx % 500 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))

    def test(self, epoch):
        """Evaluate on the test set and print mean loss and accuracy."""
        with torch.no_grad():
            self.model.eval()
            test_loss = 0
            correct = 0
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)

                # Sum (not average) each batch's loss so dividing by the
                # dataset size below yields the true per-sample mean.
                # NOTE: the legacy `size_average=False` kwarg was removed from
                # the loss API; `reduction='sum'` is its replacement.
                test_loss += self.loss_fn(output, target, reduction='sum').item()
                # Index of the max log-probability is the predicted class.
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()

            test_loss /= len(self.test_loader.dataset)
            print('Test Epoch:{} Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
                  .format(epoch, test_loss, correct, len(self.test_loader.dataset),
                          100. * correct / len(self.test_loader.dataset)))
28 changes: 28 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import warnings
import torch
from torch import optim
from torch.nn import functional as F

from net import SpatialTransformerNet # 定义模型结构
from visual import visualize_stn # 定义可视化代码
from dataset import get_loader # 定义数据集加载
from loop import Loop # 定义train和test代码段
from utils import random_seed # 设定随机种子

random_seed(0)  # fix all RNGs (python / numpy / torch) for reproducibility
warnings.filterwarnings("ignore")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    # Build everything inside the __main__ guard so importing this module
    # (or spawning DataLoader workers on platforms that re-import __main__)
    # does not re-run the training setup.
    model = SpatialTransformerNet().to(device)                   # instantiate the model
    train_loader, test_loader = get_loader(batch_size=128, num_workers=0)  # data loaders
    optimizer = optim.SGD(model.parameters(), lr=0.01)           # plain SGD optimizer

    num_epochs = 20  # renamed from `epoch`, which the loop variable shadowed
    loop = Loop(model=model, train_loader=train_loader, test_loader=test_loader,
                loss_fn=F.nll_loss, optimizer=optimizer, device=device)
    for epoch in range(1, num_epochs + 1):
        loop.train(epoch)
        loop.test(epoch)
        # Visualize the images before/after the STN; results are saved
        # under the visual/ folder.
        visualize_stn(model=model, test_loader=test_loader, idx=epoch)
64 changes: 64 additions & 0 deletions net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
from torch import nn
from torch.nn import functional as F

class SpatialTransformerNet(nn.Module):
    """MNIST classifier with a Spatial Transformer Network (STN) front end.

    The STN regresses a 2x3 affine matrix from the input image and warps the
    image with it before it reaches the classification CNN, letting the
    network learn to "straighten" its own inputs.
    """

    def __init__(self):
        super(SpatialTransformerNet, self).__init__()
        # Classification backbone: two conv blocks followed by two FC layers.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Localization network: extracts the features from which the affine
        # transform is regressed.
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )

        # Regressor producing the 6 entries of the 2x3 affine matrix.
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2),
        )

        # Start from the identity transform (x -> x, y -> y) so training
        # begins with an unmodified image.
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def stn(self, x):
        """Warp `x` with an affine transform regressed from `x` itself.

        :param x: input batch — assumes shape (N, 1, 28, 28), which makes the
            localization output flatten to 10*3*3 features
        :return: the affinely transformed ("straightened") batch, same shape
        """
        features = self.localization(x)
        features = features.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(features).view(-1, 2, 3)  # one 2x3 matrix per image
        sampling_grid = F.affine_grid(theta, x.size())
        return F.grid_sample(x, sampling_grid)

    def forward(self, x):
        """Classify a batch of (N, 1, 28, 28) images; returns log-probabilities."""
        # Rectify the input first; removing this line turns the model into a
        # plain CNN classifier without the STN module.
        x = self.stn(x)

        # Standard small CNN classification pass.
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


12 changes: 12 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import random
import numpy as np
import torch

def random_seed(seed):
    """Seed every RNG the project touches so runs are reproducible.

    Covers Python's `random`, NumPy, and PyTorch (CPU plus all CUDA devices),
    and forces cuDNN into deterministic mode.

    :param seed: integer seed applied to every generator
    """
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True  # deterministic conv algorithms
    torch.backends.cudnn.benchmark = False     # disable non-deterministic autotuning
51 changes: 51 additions & 0 deletions visual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import torch
import torchvision

import numpy as np
from matplotlib import pyplot as plt

# Visualization runs on the GPU when available so model.stn(data) receives
# inputs on the same device as the model (mirrors the device choice in main.py).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def convert_image_np(inp):
    """Convert a normalized CHW Tensor to an HWC numpy image in [0, 1].

    Undoes the normalization applied in dataset.py and clips the result to
    the displayable range.

    :param inp: CPU tensor of shape (C, H, W)
    :return: numpy array of shape (H, W, C) with values in [0, 1]
    """
    inp = inp.numpy().transpose((1, 2, 0))
    # Use the same statistics as dataset.py's transforms.Normalize
    # (MNIST mean/std) — the original used ImageNet statistics here, which
    # de-normalized the images with the wrong mean and std.
    mean = np.array([0.1307])
    std = np.array([0.3081])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp


# We want to visualize the output of the spatial transformers layer
# after the training, we visualize a batch of input images and
# the corresponding transformed batch using STN.

def visualize_stn(model, test_loader, idx):
    """Plot one test batch next to its STN-transformed version.

    Saves the side-by-side figure as ``visual/iter_{idx}``.

    :param model: network exposing an ``stn(x)`` method
    :param test_loader: DataLoader yielding (data, target) batches
    :param idx: index used in the output filename (here: the epoch number)
    """
    with torch.no_grad():
        # Grab one batch of test images.
        data = next(iter(test_loader))[0].to(device)

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side.
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')
        plt.ioff()
        # Create the output directory up front instead of retrying in a bare
        # except. The original saved to "Visual/" but created "visual/", so on
        # case-sensitive filesystems the retry failed every time.
        os.makedirs("visual", exist_ok=True)
        plt.savefig(f"visual/iter_{idx}")
        plt.close(f)  # free the figure so repeated epochs don't accumulate memory

0 comments on commit 1dabb15

Please sign in to comment.