From 8388f01dfefa73bfdedc765fe16a8e944c5d4b5b Mon Sep 17 00:00:00 2001 From: jafermarq Date: Fri, 31 Jan 2025 18:37:29 +0000 Subject: [PATCH] to device embedded --- examples/embedded-devices/embeddedexample/task.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/embedded-devices/embeddedexample/task.py b/examples/embedded-devices/embeddedexample/task.py index f08441c0426a..6f89a551ee19 100644 --- a/examples/embedded-devices/embeddedexample/task.py +++ b/examples/embedded-devices/embeddedexample/task.py @@ -84,6 +84,8 @@ def train(net, trainloader, valloader, epochs, learning_rate, device): def test(net, testloader, device): """Validate the model on the test set.""" + net.to(device) # move model to GPU if available + net.eval() criterion = torch.nn.CrossEntropyLoss() correct, loss = 0, 0.0 with torch.no_grad():