diff --git a/main.py b/main.py index 631125f..79e9799 100644 --- a/main.py +++ b/main.py @@ -20,7 +20,7 @@ optimizer = optim.SGD(model.parameters(), lr=0.01) # 设定优化器 if __name__ == "__main__": - epoch = 20 + epoch = 10 loop = Loop(model=model, train_loader=train_loader, test_loader=test_loader, loss_fn=F.nll_loss, optimizer=optimizer, device=device) for epoch in range(1, epoch + 1): loop.train(epoch)