diff --git a/ig65m/cli/extract.py b/ig65m/cli/extract.py index bbaf513..900bd69 100755 --- a/ig65m/cli/extract.py +++ b/ig65m/cli/extract.py @@ -67,14 +67,15 @@ def main(args): features = [] - for inputs in tqdm(loader, total=len(dataset) // args.batch_size): - inputs = inputs.to(device) + with torch.no_grad(): + for inputs in tqdm(loader, total=len(dataset) // args.batch_size): + inputs = inputs.to(device) - outputs = model(inputs) - outputs = outputs.data.cpu().numpy() + outputs = model(inputs) + outputs = outputs.data.cpu().numpy() - for output in outputs: - features.append(output) + for output in outputs: + features.append(output) np.save(args.features, np.array(features), allow_pickle=False) print("🍪 Done", file=sys.stderr)