Skip to content

Commit

Permalink
Upgrade to newest version of pytorch (#1070)
Browse files Browse the repository at this point in the history
* Update to newest version of pytorch:
* A couple of interface changes when creating the resnet model, in the Trainer and the model summary
* Updated the yaml files to remove pytorch versions
* Bumped version
  • Loading branch information
DocGarbanzo authored Dec 15, 2022
1 parent f8dc720 commit 0b74a78
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 25 deletions.
21 changes: 10 additions & 11 deletions donkeycar/parts/pytorch/torch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pathlib import Path
import torch
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import summarize
from donkeycar.parts.pytorch.torch_data import TorchTubDataModule
from donkeycar.parts.pytorch.torch_utils import get_model_by_type

Expand All @@ -16,21 +17,17 @@ def train(cfg, tub_paths, model_output_path, model_type, checkpoint_path=None):
if is_torch_model:
model = f'{model_name}.ckpt'
else:
print("Unrecognized model file extension for model_output_path: '{}'. Please use the '.ckpt' extension.".format(
model_output_path))

print(f"Unrecognized model file extension for model_output_path: '"
f"{model_output_path}'. Please use the '.ckpt' extension.")

if not model_type:
model_type = cfg.DEFAULT_MODEL_TYPE

tubs = tub_paths.split(',')
tub_paths = [os.path.expanduser(tub) for tub in tubs]
output_path = os.path.expanduser(model_output_path)

output_dir = Path(model_output_path).parent

output_dir = str(Path(model_output_path).parent)
model = get_model_by_type(model_type, cfg, checkpoint_path=checkpoint_path)

if torch.cuda.is_available():
print('Using CUDA')
gpus = -1
Expand All @@ -40,15 +37,17 @@ def train(cfg, tub_paths, model_output_path, model_type, checkpoint_path=None):

logger = None
if cfg.VERBOSE_TRAIN:
print("Tensorboard logging started. Run `tensorboard --logdir ./tb_logs` in a new terminal")
print("Tensorboard logging started. Run `tensorboard --logdir "
"./tb_logs` in a new terminal")
from pytorch_lightning.loggers import TensorBoardLogger

# Create Tensorboard logger
logger = TensorBoardLogger('tb_logs', name=model_name)

weights_summary = 'full' if cfg.PRINT_MODEL_SUMMARY else 'top'
trainer = pl.Trainer(gpus=gpus, logger=logger, progress_bar_refresh_rate=30,
max_epochs=cfg.MAX_EPOCHS, default_root_dir=output_dir, weights_summary=weights_summary)
if cfg.PRINT_MODEL_SUMMARY:
summarize(model)
trainer = pl.Trainer(gpus=gpus, logger=logger, max_epochs=cfg.MAX_EPOCHS,
default_root_dir=output_dir)

data_module = TorchTubDataModule(cfg, tub_paths)
trainer.fit(model, data_module)
Expand Down
3 changes: 1 addition & 2 deletions donkeycar/tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ def test_training_pipeline(config: Config, model_type: str, car_dir: str) \
gpus = 0

# Overfit the data
trainer = pl.Trainer(gpus=gpus, overfit_batches=2,
progress_bar_refresh_rate=30, max_epochs=30)
trainer = pl.Trainer(gpus=gpus, overfit_batches=2, max_epochs=30)
trainer.fit(model, data_module)
final_loss = model.loss_history[-1]
assert final_loss < 0.35, \
Expand Down
6 changes: 3 additions & 3 deletions install/envs/mac.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ channels:

dependencies:
- python=3.7
- numpy=1.19
- h5py
- pillow
- opencv
Expand All @@ -26,11 +27,10 @@ dependencies:
- PrettyTable
- pyfiglet
- mypy
- pytorch=1.7.1
- torchvision
- pytorch
- torchvision=0.12
- torchaudio
- pytorch-lightning
- numpy
- psutil
- kivy=2.0.0
- plotly
Expand Down
6 changes: 3 additions & 3 deletions install/envs/ubuntu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ channels:

dependencies:
- python=3.7
- numpy=1.19
- h5py
- pillow
- opencv
Expand All @@ -27,11 +28,10 @@ dependencies:
- PrettyTable
- pyfiglet
- mypy
- pytorch=1.7.1
- torchvision
- pytorch
- torchvision=0.12
- torchaudio
- pytorch-lightning
- numpy
- psutil
- kivy=2.0.0
- plotly
Expand Down
6 changes: 3 additions & 3 deletions install/envs/windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ channels:

dependencies:
- python=3.7
- numpy=1.19
- h5py
- pillow
- opencv
Expand All @@ -27,11 +28,10 @@ dependencies:
- PrettyTable
- pyfiglet
- mypy
- pytorch=1.7.1
- torchvision
- pytorch
- torchvision=0.12
- torchaudio
- pytorch-lightning
- numpy
- kivy=2.0.0
- plotly
- pyyaml
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def package_files(directory, strip_leading):
long_description = fh.read()

setup(name='donkeycar',
version="4.4.1-main",
version="4.4.2-main",
long_description=long_description,
description='Self driving library for python.',
url='https://github.com/autorope/donkeycar',
Expand Down Expand Up @@ -89,8 +89,8 @@ def package_files(directory, strip_leading):
'ci': ['codecov'],
'tf': ['tensorflow==2.2.0'],
'torch': [
'pytorch>=1.7.1',
'torchvision',
'pytorch',
'torchvision==0.12',
'torchaudio',
'fastai'
],
Expand Down

0 comments on commit 0b74a78

Please sign in to comment.