Skip to content

Commit

Permalink
Support comma-separated device string
Browse files Browse the repository at this point in the history
  • Loading branch information
jamt9000 committed Aug 17, 2024
1 parent 94e0c7b commit a9ade66
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def cli_main():
"--device",
default=None,
type=str,
help="indices of GPUs to enable (default: None)",
help="comma-separated indices of GPUs to enable (default: None)",
)
parser.add_argument(
"--num_workers",
Expand Down Expand Up @@ -209,15 +209,26 @@ def get_instance(module, name, config, *args, **kwargs):
monitor="val_loss",
mode="min",
)

if args.device is None:
devices = "auto"
else:
devices = [int(d.strip()) for d in args.device.split(",")]

trainer = pl.Trainer(
devices=[int(args.device)] if args.device is not None else "auto",
devices=devices,
max_epochs=args.n_epochs,
accumulate_grad_batches=config["accumulate_grad_batches"],
callbacks=[checkpoint_callback],
default_root_dir="saved/" + config["name"],
deterministic=True,
)
trainer.fit(model=model, train_dataloaders=data_loader, val_dataloaders=valid_data_loader, ckpt_path=args.resume)
trainer.fit(
model=model,
train_dataloaders=data_loader,
val_dataloaders=valid_data_loader,
ckpt_path=args.resume,
)


if __name__ == "__main__":
Expand Down

0 comments on commit a9ade66

Please sign in to comment.