diff --git a/loop.py b/loop.py index b3888eb..2cdbb43 100644 --- a/loop.py +++ b/loop.py @@ -1,4 +1,5 @@ import torch +import matplotlib.pyplot as plt class Loop: @@ -9,6 +10,7 @@ def __init__(self, model, train_loader, test_loader, loss_fn, optimizer, device) self.loss_fn = loss_fn self.optimizer = optimizer self.device = device + self.test_acc = [] def train(self, epoch): self.model.train() @@ -44,3 +46,11 @@ def test(self, epoch): print('Test Epoch:{} Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n' .format(epoch, test_loss, correct, len(self.test_loader.dataset), 100. * correct / len(self.test_loader.dataset))) + self.test_acc.append(100. * correct / len(self.test_loader.dataset)) + + def show(self): + x = range(1, self.test_acc.__len__() + 1) + plt.figure() + plt.title('Test Acc') + plt.plot(x, self.test_acc) + plt.savefig("./result.jpg") diff --git a/main.py b/main.py index df77c1e..631125f 100644 --- a/main.py +++ b/main.py @@ -26,3 +26,4 @@ loop.train(epoch) loop.test(epoch) visualize_stn(model=model, test_loader=test_loader, idx=epoch) # 可视化展示STN前后的图,结果保存在visual/文件夹下 + loop.show() # 绘制Test Acc变化曲线,保存到result.jpg diff --git a/requirements.txt b/requirements.txt index b28c0e7..74cf7a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ torch==1.10.0 torchvision==0.11.1 numpy=1.19.5 +matplotlib=3.3.4