Skip to content

Commit

Permalink
新增Test Acc曲线
Browse files Browse the repository at this point in the history
  • Loading branch information
thgpddl committed Sep 22, 2022
1 parent 6b895b8 commit 3a5f5a2
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
10 changes: 10 additions & 0 deletions loop.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import matplotlib.pyplot as plt


class Loop:
Expand All @@ -9,6 +10,7 @@ def __init__(self, model, train_loader, test_loader, loss_fn, optimizer, device)
self.loss_fn = loss_fn
self.optimizer = optimizer
self.device = device
self.test_acc = []

def train(self, epoch):
self.model.train()
Expand Down Expand Up @@ -44,3 +46,11 @@ def test(self, epoch):
print('Test Epoch:{} Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
.format(epoch, test_loss, correct, len(self.test_loader.dataset),
100. * correct / len(self.test_loader.dataset)))
self.test_acc.append(100. * correct / len(self.test_loader.dataset))

def show(self):
x = range(1, self.test_acc.__len__() + 1)
plt.figure()
plt.title('Test Acc')
plt.plot(x, self.test_acc)
plt.savefig("./result.jpg")
1 change: 1 addition & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
loop.train(epoch)
loop.test(epoch)
visualize_stn(model=model, test_loader=test_loader, idx=epoch) # 可视化展示STN前后的图,结果保存在visual/文件夹下
loop.show() # 绘制Test Acc变化曲线,保存到result.jpg
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
torch==1.10.0
torchvision==0.11.1
numpy=1.19.5
matplotlib=3.3.4

0 comments on commit 3a5f5a2

Please sign in to comment.