diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..62a9bd9
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,27 @@
+from torchvision import transforms
+from torchvision.datasets import MNIST
+from torch.utils.data import DataLoader
+
+
+def get_loader(batch_size=64, num_workers=0):
+    # Shared preprocessing: convert to tensor, then normalize with the
+    # MNIST per-channel mean and std.
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    # Training dataset
+    train_loader = DataLoader(
+        dataset=MNIST(root='.', train=True, download=True, transform=transform),
+        batch_size=batch_size,
+        shuffle=True,
+        num_workers=num_workers)
+
+    # Test dataset
+    test_loader = DataLoader(
+        dataset=MNIST(root='.', train=False, download=True, transform=transform),
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=num_workers)
+    return train_loader, test_loader
diff --git a/loop.py b/loop.py
new file mode 100644
index 0000000..b3888eb
--- /dev/null
+++ b/loop.py
@@ -0,0 +1,47 @@
+import torch
+
+
+class Loop:
+    def __init__(self, model, train_loader, test_loader, loss_fn, optimizer, device):
+        self.model = model
+        self.train_loader = train_loader
+        self.test_loader = test_loader
+        self.loss_fn = loss_fn
+        self.optimizer = optimizer
+        self.device = device
+
+    def train(self, epoch):
+        self.model.train()
+        for batch_idx, (data, target) in enumerate(self.train_loader):
+            data, target = data.to(self.device), target.to(self.device)
+
+            self.optimizer.zero_grad()
+            output = self.model(data)
+            loss = self.loss_fn(output, target)
+            loss.backward()
+            self.optimizer.step()
+            if batch_idx % 500 == 0:
+                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                    epoch, batch_idx * len(data), len(self.train_loader.dataset),
+                    100. * batch_idx / len(self.train_loader), loss.item()))
+
+    def test(self, epoch):
+        with torch.no_grad():
+            self.model.eval()
+            test_loss = 0
+            correct = 0
+            for data, target in self.test_loader:
+                data, target = data.to(self.device), target.to(self.device)
+                output = self.model(data)
+
+                # Sum up the per-sample batch loss; size_average is
+                # deprecated, so use reduction='sum' instead.
+                test_loss += self.loss_fn(output, target, reduction='sum').item()
+                # Get the index of the max log-probability
+                pred = output.max(1, keepdim=True)[1]
+                correct += pred.eq(target.view_as(pred)).sum().item()
+
+            test_loss /= len(self.test_loader.dataset)
+            print('Test Epoch: {} Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
+                  .format(epoch, test_loss, correct, len(self.test_loader.dataset),
+                          100. * correct / len(self.test_loader.dataset)))
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..df77c1e
--- /dev/null
+++ b/main.py
@@ -0,0 +1,29 @@
+import warnings
+import torch
+from torch import optim
+from torch.nn import functional as F
+
+from net import SpatialTransformerNet  # model architecture
+from visual import visualize_stn       # visualization code
+from dataset import get_loader         # dataset loading
+from loop import Loop                  # train and test loops
+from utils import random_seed          # random-seed setup
+
+random_seed(0)
+warnings.filterwarnings("ignore")
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+model = SpatialTransformerNet().to(device)                             # instantiate the model
+train_loader, test_loader = get_loader(batch_size=128, num_workers=0)  # build the data loaders
+optimizer = optim.SGD(model.parameters(), lr=0.01)                     # set up the optimizer
+
+if __name__ == "__main__":
+    epochs = 20
+    loop = Loop(model=model, train_loader=train_loader, test_loader=test_loader,
+                loss_fn=F.nll_loss, optimizer=optimizer, device=device)
+    for epoch in range(1, epochs + 1):
+        loop.train(epoch)
+        loop.test(epoch)
+        visualize_stn(model=model, test_loader=test_loader, idx=epoch)  # plot images before/after the STN; results are saved under visual/
diff --git a/net.py b/net.py
new file mode 100644
index 0000000..41b5506
--- /dev/null
+++ b/net.py
@@ -0,0 +1,66 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+class SpatialTransformerNet(nn.Module):
+    def __init__(self):
+        super(SpatialTransformerNet, self).__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
+        self.conv2_drop = nn.Dropout2d()
+        self.fc1 = nn.Linear(320, 50)
+        self.fc2 = nn.Linear(50, 10)
+
+        # Spatial transformer localization network
+        self.localization = nn.Sequential(
+            nn.Conv2d(1, 8, kernel_size=7),
+            nn.MaxPool2d(2, stride=2),
+            nn.ReLU(True),
+            nn.Conv2d(8, 10, kernel_size=5),
+            nn.MaxPool2d(2, stride=2),
+            nn.ReLU(True)
+        )
+
+        # Regressor for the 3 * 2 affine matrix
+        self.fc_loc = nn.Sequential(
+            nn.Linear(10 * 3 * 3, 32),
+            nn.ReLU(True),
+            nn.Linear(32, 3 * 2)
+        )
+
+        # Initialize the weights/bias with the identity transformation,
+        # i.e. the affine matrix that maps x -> x and y -> y.
+        self.fc_loc[2].weight.data.zero_()
+        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
+
+    def stn(self, x):
+        """
+        Spatial transformer network forward function: the localization
+        convolutions and fully connected layers regress the affine
+        transformation matrix theta of shape (2, 3) from the input image.
+        :param x: input image, shape=(1, 28, 28)
+        :return: the affinely transformed ("straightened") image, shape=(1, 28, 28)
+        """
+        xs = self.localization(x)
+        xs = xs.view(-1, 10 * 3 * 3)
+        theta = self.fc_loc(xs)
+        theta = theta.view(-1, 2, 3)  # transformation matrix
+        grid = F.affine_grid(theta, x.size(), align_corners=False)
+        x = F.grid_sample(x, grid, align_corners=False)
+
+        return x
+
+    def forward(self, x):  # 1, 28, 28
+        # Transform the input. Commenting out this line drops the STN
+        # module and feeds the raw image straight into the classifier below.
+        x = self.stn(x)
+
+        # Perform the usual forward pass
+        x = F.relu(F.max_pool2d(self.conv1(x), 2))
+        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
+        x = x.view(-1, 320)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, training=self.training)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..491ecdc
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,13 @@
+import random
+import numpy as np
+import torch
+
+
+def random_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True  # use deterministic operations
+    torch.backends.cudnn.benchmark = False     # disable the cuDNN auto-tuner
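Note on the test loss in loop.py: the deprecated size_average=False argument was replaced with reduction='sum', so per-sample losses are summed across batches and only divided by the dataset size at the end. A minimal sanity check of that equivalence (not part of the diff; names are illustrative):

    import torch
    from torch.nn import functional as F

    log_probs = F.log_softmax(torch.randn(4, 10), dim=1)  # fake batch of log-probabilities
    targets = torch.tensor([0, 1, 2, 3])
    summed = F.nll_loss(log_probs, targets, reduction='sum')
    mean = F.nll_loss(log_probs, targets)  # default reduction='mean'
    assert torch.allclose(summed, mean * 4)  # sum == mean * batch_size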
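As the comment in net.py notes, fc_loc is initialized to the identity affine matrix, so before any training the STN leaves its input unchanged. A minimal sketch (not part of the diff) checking that property of affine_grid/grid_sample directly:

    import torch
    from torch.nn import functional as F

    x = torch.randn(1, 1, 28, 28)  # dummy single-channel image
    theta = torch.tensor([[[1., 0., 0.],
                           [0., 1., 0.]]])  # identity transform, shape (1, 2, 3)
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    warped = F.grid_sample(x, grid, align_corners=False)
    assert torch.allclose(warped, x, atol=1e-5)  # identity warp reproduces the input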
diff --git a/visual.py b/visual.py
new file mode 100644
index 0000000..a45d2ee
--- /dev/null
+++ b/visual.py
@@ -0,0 +1,50 @@
+import os
+import torch
+import torchvision
+
+import numpy as np
+from matplotlib import pyplot as plt
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def convert_image_np(inp):
+    """Convert a Tensor to a numpy image."""
+    inp = inp.numpy().transpose((1, 2, 0))
+    # Undo the MNIST normalization applied in dataset.py
+    mean = np.array([0.1307])
+    std = np.array([0.3081])
+    inp = std * inp + mean
+    inp = np.clip(inp, 0, 1)
+    return inp
+
+
+# We want to visualize the output of the spatial transformer layer:
+# after training, we plot a batch of input images next to the
+# corresponding batch transformed by the STN.
+
+def visualize_stn(model, test_loader, idx):
+    with torch.no_grad():
+        # Get a batch of test data (batch_size images)
+        data = next(iter(test_loader))[0].to(device)
+
+        input_tensor = data.cpu()
+        transformed_input_tensor = model.stn(data).cpu()
+
+        in_grid = convert_image_np(
+            torchvision.utils.make_grid(input_tensor))
+
+        out_grid = convert_image_np(
+            torchvision.utils.make_grid(transformed_input_tensor))
+
+        # Plot the results side by side
+        f, axarr = plt.subplots(1, 2)
+        axarr[0].imshow(in_grid)
+        axarr[0].set_title('Dataset Images')
+
+        axarr[1].imshow(out_grid)
+        axarr[1].set_title('Transformed Images')
+        plt.ioff()
+        os.makedirs("visual", exist_ok=True)  # create the output folder on first use
+        plt.savefig(f"visual/iter_{idx}.png")
+        plt.close(f)
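To reproduce: running main.py trains the classifier with the STN for 20 epochs and, after each epoch, writes a side-by-side figure of the original and STN-transformed test digits to the visual/ folder:

    python main.py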