-
Notifications
You must be signed in to change notification settings - Fork 9
/
app_hpo.py
42 lines (30 loc) · 1.33 KB
/
app_hpo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os.path as ops
import optuna
from lightning.app import LightningFlow, CloudCompute, LightningApp
from lightning_hpo import BaseObjective, Optimizer
from quick_start.components import ImageServeGradio, PyTorchLightningScript
class HPOPyTorchLightningScript(PyTorchLightningScript, BaseObjective):
    """Objective wrapper that declares the hyper-parameter search space for HPO."""

    @staticmethod
    def distributions():
        # Sample the learning rate log-uniformly over [1e-4, 1e-1].
        search_space = {
            "model.lr": optuna.distributions.LogUniformDistribution(0.0001, 0.1),
        }
        return search_space
class TrainDeploy(LightningFlow):
    """Root flow: run a hyper-parameter sweep, then serve the best checkpoint."""

    def __init__(self):
        super().__init__()
        script = ops.join(ops.dirname(__file__), "./train_script.py")
        # Optuna-driven sweep over the training script (4 trials).
        self.train_work = Optimizer(
            script_path=script,
            script_args=["--trainer.max_epochs=5"],
            objective_cls=HPOPyTorchLightningScript,
            n_trials=4,
        )
        # Gradio demo served on a CPU machine.
        self.serve_work = ImageServeGradio(CloudCompute("cpu"))

    def run(self):
        # 1. Drive the hyper-parameter optimization sweep.
        self.train_work.run()
        # 2. As soon as a best checkpoint is available, deploy the demo.
        best = self.train_work.best_model_path
        if best:
            self.serve_work.run(best)

    def configure_layout(self):
        # Two UI tabs: the HPO visualisation plot and the interactive demo.
        return [
            {"name": "Model training", "content": self.train_work.hi_plot},
            {"name": "Interactive demo", "content": self.serve_work},
        ]
app = LightningApp(TrainDeploy())