-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 1dabb15
Showing
6 changed files
with
228 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from torchvision import transforms | ||
from torchvision.datasets import MNIST | ||
from torch.utils.data import DataLoader | ||
|
||
|
||
def get_loader(batch_size=64, num_workers=0):
    """Build the MNIST train and test DataLoaders.

    Args:
        batch_size: samples per batch, used for both loaders.
        num_workers: worker subprocesses for data loading.

    Returns:
        (train_loader, test_loader) tuple.
    """
    # Shared pipeline: tensor conversion + MNIST per-channel normalization.
    # (0.1307, 0.3081) are the MNIST training-set mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Training set: shuffled every epoch.
    train_loader = DataLoader(
        dataset=MNIST(root='.', train=True, download=True, transform=transform),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )

    # Test set: download=True added so evaluation also works when the
    # dataset was never fetched before; order kept deterministic.
    test_loader = DataLoader(
        dataset=MNIST(root='.', train=False, download=True, transform=transform),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import torch | ||
|
||
|
||
class Loop:
    """Bundles one model's training and evaluation loops.

    Holds the model, data loaders, loss function, optimizer and device, and
    exposes ``train``/``test`` methods that each run a single epoch.
    """

    def __init__(self, model, train_loader, test_loader, loss_fn, optimizer, device):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        # loss_fn must accept a `reduction` kwarg (e.g. F.nll_loss).
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = device

    def train(self, epoch):
        """Run one training epoch, logging progress every 500 batches."""
        self.model.train()
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.loss_fn(output, target)
            loss.backward()
            self.optimizer.step()
            if batch_idx % 500 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))

    def test(self, epoch):
        """Evaluate on the test set; print average loss and accuracy."""
        with torch.no_grad():
            self.model.eval()
            test_loss = 0
            correct = 0
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)

                # Sum (not average) each batch's loss so that dividing by the
                # dataset size below yields the true per-sample average.
                # `size_average=False` is deprecated/removed in modern PyTorch;
                # `reduction='sum'` is the supported equivalent.
                test_loss += self.loss_fn(output, target, reduction='sum').item()
                # Index of the max log-probability = predicted class.
                pred = output.max(1, keepdim=True)[1]
                correct += pred.eq(target.view_as(pred)).sum().item()

            test_loss /= len(self.test_loader.dataset)
            print('Test Epoch:{} Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
                  .format(epoch, test_loss, correct, len(self.test_loader.dataset),
                          100. * correct / len(self.test_loader.dataset)))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import warnings | ||
import torch | ||
from torch import optim | ||
from torch.nn import functional as F | ||
|
||
from net import SpatialTransformerNet # 定义模型结构 | ||
from visual import visualize_stn # 定义可视化代码 | ||
from dataset import get_loader # 定义数据集加载 | ||
from loop import Loop # 定义train和test代码段 | ||
from utils import random_seed # 设定随机种子 | ||
|
||
random_seed(0)  # fix all RNG seeds so runs are reproducible
warnings.filterwarnings("ignore")  # silence framework deprecation warnings


# Prefer the GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SpatialTransformerNet().to(device)  # instantiate the STN classifier
train_loader, test_loader = get_loader(batch_size=128, num_workers=0)  # MNIST loaders
optimizer = optim.SGD(model.parameters(), lr=0.01)  # plain SGD optimizer

if __name__ == "__main__":
    num_epochs = 20
    loop = Loop(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        loss_fn=F.nll_loss,
        optimizer=optimizer,
        device=device,
    )
    for epoch in range(1, num_epochs + 1):
        loop.train(epoch)
        loop.test(epoch)
        # Save before/after STN image grids under the visual/ folder.
        visualize_stn(model=model, test_loader=test_loader, idx=epoch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import torch | ||
from torch import nn | ||
from torch.nn import functional as F | ||
|
||
class SpatialTransformerNet(nn.Module):
    """CNN classifier for MNIST with a Spatial Transformer (STN) front-end.

    The STN predicts a 2x3 affine matrix from the input image and resamples
    the image with it ("straightening" it) before the standard conv/fc
    classification stack.
    """

    def __init__(self):
        super(SpatialTransformerNet, self).__init__()
        # Classification backbone.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization network.
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix.
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the regressor to the identity transform (x' = x, y' = y)
        # so training starts from "no warp".
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        """Predict an affine matrix from x and resample x with it.

        :param x: input image batch, shape (N, 1, 28, 28)
        :return: affine-transformed ("straightened") batch, same shape
        """
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)  # per-sample 2x3 transform matrix
        # align_corners=False matches the PyTorch >= 1.3 default and silences
        # the deprecation warning emitted when it is left unspecified.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)

        return x

    def forward(self, x):  # (N, 1, 28, 28)
        # Transform the input first. Commenting out this line drops the STN
        # module, i.e. feeds the raw image straight into the classifier below.
        x = self.stn(x)

        # Perform the usual forward pass.
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import random | ||
import numpy as np | ||
import torch | ||
|
||
def random_seed(seed):
    """Seed every RNG used by the project so runs are reproducible.

    Covers Python's `random`, NumPy, and PyTorch (CPU plus all CUDA
    devices), and forces cuDNN into its deterministic mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True  # deterministic conv algorithms
    torch.backends.cudnn.benchmark = False     # disable the (non-deterministic) auto-tuner
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
import torch | ||
import torchvision | ||
|
||
import numpy as np | ||
from matplotlib import pyplot as plt | ||
|
||
# Module-level device selection: prefer CUDA when available.
# NOTE(review): duplicates the selection done in the training script —
# presumably kept so visualize_stn can run standalone; confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
|
||
def convert_image_np(inp):
    """Undo ImageNet-style normalization and return an HWC numpy image.

    NOTE(review): the mean/std below are the ImageNet statistics, while the
    data loaders normalize with MNIST's (0.1307, 0.3081) — presumably fine
    for visualization only; confirm if exact de-normalization matters.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # CHW tensor -> HWC numpy array.
    img = inp.numpy().transpose((1, 2, 0))
    img = img * std + mean      # de-normalize per channel
    return np.clip(img, 0, 1)   # keep values displayable in [0, 1]
|
||
|
||
# We want to visualize the output of the spatial transformers layer | ||
# after the training, we visualize a batch of input images and | ||
# the corresponding transformed batch using STN. | ||
|
||
def visualize_stn(model, test_loader, idx):
    """Save a side-by-side grid of raw vs. STN-transformed test images.

    Writes the figure to ``visual/iter_<idx>``, creating the directory on
    demand.

    Args:
        model: network exposing an ``stn(x)`` method.
        test_loader: loader whose first batch is visualized.
        idx: index used in the output file name (typically the epoch).
    """
    with torch.no_grad():
        # Grab one batch of test images (batch_size images).
        data = next(iter(test_loader))[0].to(device)

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side.
        fig, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')

        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')
        plt.ioff()

        # The original saved to "Visual/..." but its bare-except fallback
        # created "visual" and retried "Visual/..." — which fails again on
        # case-sensitive filesystems. Create the directory up front and use
        # one consistent lower-case path instead.
        os.makedirs("visual", exist_ok=True)
        plt.savefig(f"visual/iter_{idx}")
        plt.close(fig)  # release the figure so one isn't leaked per epoch