fix(exec): reimplement fast_dev_run with ckpt
Closes #15.
tbung committed Apr 24, 2023
1 parent 7d8d9d1 commit 16aa3a0
Showing 2 changed files with 36 additions and 7 deletions.
32 changes: 28 additions & 4 deletions fd_shifts/exec.py
@@ -80,20 +80,36 @@ def train(
         else 1
     )
 
+    limit_batches: float | int = 1.0
+    num_epochs = cf.trainer.num_epochs
+    val_every_n_epoch = cf.trainer.val_every_n_epoch
+
+    if isinstance(cf.trainer.fast_dev_run, bool):
+        limit_batches = 1 if cf.trainer.fast_dev_run else 1.0
+        num_epochs = 1 if cf.trainer.fast_dev_run else num_epochs
+        max_steps = 1 if cf.trainer.fast_dev_run else max_steps
+        val_every_n_epoch = 1 if cf.trainer.fast_dev_run else val_every_n_epoch
+    elif isinstance(cf.trainer.fast_dev_run, int):
+        limit_batches = cf.trainer.fast_dev_run
+        max_steps = cf.trainer.fast_dev_run
+        val_every_n_epoch = 1
+        num_epochs = 1
 
     trainer = pl.Trainer(
         accelerator="auto",
         devices="auto",
         logger=[tb_logger, csv_logger],
-        max_epochs=cf.trainer.num_epochs,
+        max_epochs=num_epochs,
+        max_steps=max_steps,
         callbacks=[progress] + get_callbacks(cf),
         resume_from_checkpoint=resume_ckpt_path,
         benchmark=cf.trainer.benchmark,
-        check_val_every_n_epoch=cf.trainer.val_every_n_epoch,
-        fast_dev_run=cf.trainer.fast_dev_run,
+        check_val_every_n_epoch=val_every_n_epoch,
         num_sanity_val_steps=5,
         deterministic=train_deterministic_flag,
-        limit_val_batches=0 if cf.trainer.do_val is False else 1.0,
+        limit_train_batches=limit_batches,
+        limit_val_batches=0 if cf.trainer.do_val is False else limit_batches,
+        limit_test_batches=limit_batches,
         gradient_clip_val=1,
         accumulate_grad_batches=accumulate_grad_batches,
     )
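
The hunk above drops Lightning's built-in fast_dev_run flag, which disables loggers and checkpoint callbacks, and instead translates the configured value into explicit Trainer limits, so checkpointing keeps working during smoke-test runs. A minimal sketch of the mapping, assuming a hypothetical helper (fast_dev_run_limits is not part of fd_shifts):

    def fast_dev_run_limits(fast_dev_run: bool | int) -> dict[str, float | int]:
        """Translate a fast_dev_run setting into explicit Trainer limits."""
        if fast_dev_run is False:
            return {}  # full run, no limits
        # True behaves like a single batch; an int n runs n batches/steps
        n = 1 if fast_dev_run is True else fast_dev_run
        return {
            "limit_train_batches": n,
            "limit_val_batches": n,
            "limit_test_batches": n,
            "max_steps": n,
            "max_epochs": 1,
            "check_val_every_n_epoch": 1,
        }

    # usage sketch: pl.Trainer(..., **fast_dev_run_limits(cf.trainer.fast_dev_run))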
@@ -138,11 +154,19 @@ def test(cf: configs.Config, progress: RichProgressBar = RichProgressBar()) -> None:
     if not os.path.exists(cf.test.dir):
         os.makedirs(cf.test.dir)
 
+    limit_batches: float | int = 1.0
+
+    if isinstance(cf.trainer.fast_dev_run, bool):
+        limit_batches = 1 if cf.trainer.fast_dev_run else 1.0
+    elif isinstance(cf.trainer.fast_dev_run, int):
+        limit_batches = cf.trainer.fast_dev_run
+
     trainer = pl.Trainer(
         accelerator="auto",
         devices="auto",
         logger=False,
         callbacks=[progress] + get_callbacks(cf),
+        limit_test_batches=limit_batches,
         replace_sampler_ddp=False,
     )
     trainer.test(model=module, datamodule=datamodule)
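
One subtlety in the isinstance dispatch used in both functions: bool is a subclass of int in Python, so a second plain `if isinstance(..., int)` would also match True and False and overwrite the boolean branch (with fast_dev_run=False the limits would become False == 0, disabling training entirely). The boolean check therefore runs first and the integer branch is an elif:

    # bool is a subclass of int, so the order of the checks matters
    assert isinstance(True, bool) and isinstance(True, int)
    assert isinstance(5, int) and not isinstance(5, bool)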
11 changes: 8 additions & 3 deletions fd_shifts/tests/test_reproducible.py
@@ -9,19 +9,24 @@
 from fd_shifts.tests.utils import mock_env_if_missing
 
 
-def _update_overrides(overrides: dict[str, Any]) -> dict[str, Any]:
+def _update_overrides_fast(overrides: dict[str, Any]) -> dict[str, Any]:
     # HACK: This is highly machine dependent!
     max_batch_size = 16
 
+    overrides["trainer.fast_dev_run"] = 5
     accum = overrides.get("trainer.batch_size", 128) // max_batch_size
     overrides["trainer.batch_size"] = max_batch_size
     overrides["trainer.accumulate_grad_batches"] = accum
+
+    # HACK: Have to disable these because they do not handle limited batches
+    overrides["eval.query_studies.noise_study"] = []
+    overrides["eval.query_studies.in_class_study"] = []
+    overrides["eval.query_studies.new_class_study"] = []
     return overrides
 
 
 @pytest.mark.slow
-def test_integration(mock_env_if_missing):
+def test_small_heuristic_run(mock_env_if_missing):
     # TODO: Test multiple with fixture
     name = "fd-shifts/cifar100_paper_sweep/confidnet_bbvgg13_do0_run1_rew2.2"
     # TODO: Also run some form of inference. Maybe generate outputs on main branch instead of using full experiments?
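
The batch-size cap keeps memory use bounded on small GPUs while gradient accumulation preserves the effective batch size. A worked example with the fallback batch size of 128 read from the overrides:

    # cap the per-step batch size while keeping the effective batch size
    max_batch_size = 16
    accum = 128 // max_batch_size       # 8 accumulation steps
    effective = max_batch_size * accum  # 16 * 8 == 128, gradients unchanged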
@@ -41,7 +46,7 @@ def test_integration(mock_env_if_missing):
         f"{str(experiment.to_path()).replace('/', '_').replace('.','_')}"
     )
 
-    overrides = _update_overrides(experiment.overrides())
+    overrides = _update_overrides_fast(experiment.overrides())
 
     cmd = BASH_BASE_COMMAND.format(
         overrides=" ".join(f"{k}={v}" for k, v in overrides.items()),