From 1f4b2ed40a2c7275f08a2e4615d4df79c994dff0 Mon Sep 17 00:00:00 2001 From: rasbt Date: Sun, 27 Feb 2022 20:05:28 -0600 Subject: [PATCH] install torchmetrics --- templates/pytorch_lightning/submit_command.sh | 5 +++-- templates/pytorch_lightning/tune_classification_basic.py | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/templates/pytorch_lightning/submit_command.sh b/templates/pytorch_lightning/submit_command.sh index 605c6d0..7935a6d 100644 --- a/templates/pytorch_lightning/submit_command.sh +++ b/templates/pytorch_lightning/submit_command.sh @@ -3,5 +3,6 @@ grid run \ --framework lightning \ --gpus 2 \ tune_classification_basic.py \ ---learning_rate "uniform(1e-5, 1e-1, 5)" \ ---batch_size "[64, 128, 256]" +--learning_rate "[0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05)" \ +--batch_size "[64, 128, 256]"\ +--epochs 20 diff --git a/templates/pytorch_lightning/tune_classification_basic.py b/templates/pytorch_lightning/tune_classification_basic.py index a80fcc2..4c3dc7f 100644 --- a/templates/pytorch_lightning/tune_classification_basic.py +++ b/templates/pytorch_lightning/tune_classification_basic.py @@ -10,7 +10,6 @@ import torch from torch.utils.data import DataLoader -import torchmetrics from torchvision import transforms from torchvision import datasets from torch.utils.data.dataset import random_split @@ -22,6 +21,10 @@ def install(package): install("torchmetrics") + +import torchmetrics + + # Argparse helper parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)