diff --git a/examples/embedded-devices/embeddedexample/task.py b/examples/embedded-devices/embeddedexample/task.py index f08441c0426a..6f89a551ee19 100644 --- a/examples/embedded-devices/embeddedexample/task.py +++ b/examples/embedded-devices/embeddedexample/task.py @@ -84,6 +84,8 @@ def train(net, trainloader, valloader, epochs, learning_rate, device): def test(net, testloader, device): """Validate the model on the test set.""" + net.to(device) # move model to GPU if available + net.eval() criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad():