From 323e612b3a5b31f4eceeb9cc755b61cd6320d9c7 Mon Sep 17 00:00:00 2001 From: Yao Xu <52527761+zjh199683@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:14:06 +0000 Subject: [PATCH] fix(examples) Fix `quickstart-pytorch` GPU RuntimeError (#4386) --- examples/quickstart-pytorch/pytorchexample/task.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/quickstart-pytorch/pytorchexample/task.py b/examples/quickstart-pytorch/pytorchexample/task.py index 8e0808871616..d115c9f1a469 100644 --- a/examples/quickstart-pytorch/pytorchexample/task.py +++ b/examples/quickstart-pytorch/pytorchexample/task.py @@ -100,6 +100,7 @@ def train(net, trainloader, valloader, epochs, learning_rate, device): def test(net, testloader, device): """Validate the model on the test set.""" + net.to(device) # move model to GPU if available criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad():