diff --git a/README.md b/README.md
index a3a5bee6..ec3156f6 100644
--- a/README.md
+++ b/README.md
@@ -61,8 +61,8 @@ Requires Python 3.8 or newer and PyTorch 1.13. Originally developed on Python 3.
 platform=cpu
 # Generate requirements files for specified PyTorch platform
 make torch-${platform}
-# Install the project and core + train dependencies. Options: [train,test,bench,tune]
-pip install -r requirements/core.${platform}.txt -e .[train]
+# Install the project and core + train + test dependencies. Subsets: [train,test,bench,tune]
+pip install -r requirements/core.${platform}.txt -e .[train,test]
 ```
 #### Updating dependency version pins
 ```bash
diff --git a/configs/main.yaml b/configs/main.yaml
index 4f802417..fe890d34 100644
--- a/configs/main.yaml
+++ b/configs/main.yaml
@@ -37,7 +37,8 @@ trainer:
   #max_steps: 169680  # 20 epochs x 8484 steps (for batch size = 384, real data)
   max_epochs: 20
   gradient_clip_val: 20
-  gpus: 2
+  accelerator: gpu
+  devices: 2
 
 ckpt_path: null
 pretrained: null
diff --git a/configs/tune.yaml b/configs/tune.yaml
index e619f864..6d0a22fd 100644
--- a/configs/tune.yaml
+++ b/configs/tune.yaml
@@ -3,7 +3,7 @@ defaults:
   - _self_
 
 trainer:
-  gpus: 1  # tuning with DDP is not yet supported.
+  devices: 1  # tuning with DDP is not yet supported.
 
 tune:
   num_samples: 10
diff --git a/requirements/bench.in b/requirements/bench.in
index 5d74e2ec..9a4199cc 100644
--- a/requirements/bench.in
+++ b/requirements/bench.in
@@ -1,3 +1,4 @@
 -c ${CONSTRAINTS}
-hydra-core~=1.2.0
-fvcore
+
+hydra-core >=1.2.0
+fvcore >=0.1.5.post20220512
diff --git a/requirements/bench.txt b/requirements/bench.txt
index 530b7a99..83d370d3 100644
--- a/requirements/bench.txt
+++ b/requirements/bench.txt
@@ -1,6 +1,6 @@
 antlr4-python3-runtime==4.9.3
 fvcore==0.1.5.post20221221
-hydra-core==1.2.0
+hydra-core==1.3.2
 importlib-resources==5.12.0
 iopath==0.1.10
 numpy==1.24.3
diff --git a/requirements/constraints.txt b/requirements/constraints.txt
index 9ce86bff..272d212e 100644
--- a/requirements/constraints.txt
+++ b/requirements/constraints.txt
@@ -1,7 +1,5 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
 
-absl-py==1.4.0
-    # via tensorboard
 aiohttp==3.8.4
     # via fsspec
 aiosignal==1.3.1
@@ -27,8 +25,6 @@ backcall==0.2.0
     # via ipython
 botorch==0.8.5
     # via ax-platform
-cachetools==5.3.1
-    # via google-auth
 certifi==2023.5.7
     # via requests
 charset-normalizer==3.1.0
@@ -71,21 +67,13 @@ fsspec==2023.5.0
     #   pytorch-lightning
 fvcore==0.1.5.post20221221
     # via -r requirements/bench.in
-google-auth==2.19.0
-    # via
-    #   google-auth-oauthlib
-    #   tensorboard
-google-auth-oauthlib==0.4.6
-    # via tensorboard
 gpytorch==1.10
     # via botorch
 grpcio==1.43.0
-    # via
-    #   ray
-    #   tensorboard
-huggingface-hub==0.14.1
+    # via ray
+huggingface-hub==0.15.1
     # via timm
-hydra-core==1.2.0
+hydra-core==1.3.2
     # via
     #   -r requirements/bench.in
     #   -r requirements/tune.in
@@ -102,9 +90,7 @@ imgaug==0.4.0
     #   -r requirements/train.in
     #   -r requirements/tune.in
 importlib-metadata==6.6.0
-    # via
-    #   jupyter-client
-    #   markdown
+    # via jupyter-client
 importlib-resources==5.12.0
     # via
     #   hydra-core
@@ -142,6 +128,8 @@ kiwisolver==1.4.4
     # via matplotlib
 lazy-loader==0.2
     # via scikit-image
+lightning-utilities==0.8.0
+    # via pytorch-lightning
 linear-operator==0.4.0
     # via
     #   botorch
@@ -150,12 +138,8 @@ lmdb==1.4.1
     # via
     #   -r requirements/test.in
     #   -r requirements/tune.in
-markdown==3.4.3
-    # via tensorboard
 markupsafe==2.1.2
-    # via
-    #   jinja2
-    #   werkzeug
+    # via jinja2
 matplotlib==3.7.1
     # via imgaug
 matplotlib-inline==0.1.6
     # via ipython
@@ -194,13 +178,10 @@ numpy==1.24.3
     #   scikit-learn
     #   scipy
     #   shapely
-    #   tensorboard
     #   tensorboardx
     #   tifffile
     #   torchmetrics
     #   torchvision
-oauthlib==3.2.2
-    # via requests-oauthlib
 omegaconf==2.3.0
     # via hydra-core
 opencv-python==4.7.0.72
@@ -212,6 +193,7 @@ packaging==23.1
     #   huggingface-hub
     #   hydra-core
     #   ipykernel
+    #   lightning-utilities
     #   matplotlib
     #   plotly
     #   pytorch-lightning
@@ -250,11 +232,9 @@ portalocker==2.7.0
     # via iopath
 prompt-toolkit==3.0.38
     # via ipython
-protobuf==3.20.1
+protobuf==3.20.3
     # via
-    #   pytorch-lightning
     #   ray
-    #   tensorboard
     #   tensorboardx
 psutil==5.9.5
     # via ipykernel
@@ -262,14 +242,6 @@ ptyprocess==0.7.0
     # via pexpect
 pure-eval==0.2.2
     # via stack-data
-pyasn1==0.5.0
-    # via
-    #   pyasn1-modules
-    #   rsa
-pyasn1-modules==0.3.0
-    # via google-auth
-pydeprecate==0.3.2
-    # via pytorch-lightning
 pygments==2.15.1
     # via ipython
 pyparsing==3.0.9
@@ -285,7 +257,7 @@ python-dateutil==2.8.2
     #   jupyter-client
     #   matplotlib
     #   pandas
-pytorch-lightning==1.6.5
+pytorch-lightning==1.9.5
     # via -r requirements/core.in
 pytz==2023.3
     # via pandas
@@ -314,13 +286,9 @@ requests==2.31.0
     #   fsspec
     #   huggingface-hub
     #   ray
-    #   requests-oauthlib
-    #   tensorboard
     #   torchvision
-requests-oauthlib==1.3.1
-    # via google-auth-oauthlib
-rsa==4.9
-    # via google-auth
+safetensors==0.3.1
+    # via timm
 scikit-image==0.20.0
     # via imgaug
 scikit-learn==1.2.2
@@ -340,7 +308,6 @@ shapely==2.0.1
 six==1.16.0
     # via
     #   asttokens
-    #   google-auth
     #   grpcio
     #   imgaug
     #   multipledispatch
@@ -353,12 +320,6 @@ tabulate==0.9.0
     #   ray
 tenacity==8.2.2
     # via plotly
-tensorboard==2.11.2
-    # via pytorch-lightning
-tensorboard-data-server==0.6.1
-    # via tensorboard
-tensorboard-plugin-wit==1.8.1
-    # via tensorboard
 tensorboardx==2.6
     # via ray
 termcolor==2.3.0
@@ -367,7 +328,7 @@ threadpoolctl==3.1.0
     # via scikit-learn
 tifffile==2023.4.12
     # via scikit-image
-timm==0.6.13
+timm==0.9.2
     # via -r requirements/core.in
 torch==1.13.1+cpu
     # via
@@ -414,24 +375,19 @@ typing-extensions==4.6.2
     #   huggingface-hub
     #   iopath
     #   ipython
+    #   lightning-utilities
     #   pytorch-lightning
     #   torch
     #   torchmetrics
     #   torchvision
 tzdata==2023.3
     # via pandas
-urllib3==1.26.16
-    # via
-    #   google-auth
-    #   requests
+urllib3==2.0.2
+    # via requests
 virtualenv==20.23.0
     # via ray
 wcwidth==0.2.6
     # via prompt-toolkit
-werkzeug==2.3.4
-    # via tensorboard
-wheel==0.40.0
-    # via tensorboard
 widgetsnbextension==4.0.7
     # via ipywidgets
 yacs==0.1.8
@@ -442,7 +398,3 @@ zipp==3.15.0
     # via
     #   importlib-metadata
     #   importlib-resources
-
-# The following packages are considered to be unsafe in a requirements file:
-setuptools==67.8.0
-    # via tensorboard
diff --git a/requirements/core.in b/requirements/core.in
index 491e1fc8..8be3a85d 100644
--- a/requirements/core.in
+++ b/requirements/core.in
@@ -1,7 +1,8 @@
 -c ${CONSTRAINTS}
+
 torch >=1.10.0, <2.0.0
 torchvision >=0.11.0, <0.15.0
-timm~=0.6.5
-pytorch-lightning~=1.6.5 # TODO: refactor code to separate model from training code.
-nltk # TODO: refactor/reorganize code. This is a train/test dependency.
-PyYAML # TODO: can we move this to train/test?
+timm >=0.6.5
+pytorch-lightning >=1.7.0, <2.0.0 # TODO: refactor code to separate model from training code.
+nltk >=3.7.0 # TODO: refactor/reorganize code. This is a train/test dependency.
+PyYAML >=6.0.0 # TODO: can we move this to train/test?
diff --git a/requirements/core.txt b/requirements/core.txt
index 0591d4c5..cd715fab 100644
--- a/requirements/core.txt
+++ b/requirements/core.txt
@@ -1,57 +1,34 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
 
-absl-py==1.4.0
 aiohttp==3.8.4
 aiosignal==1.3.1
 async-timeout==4.0.2
 attrs==23.1.0
-cachetools==5.3.1
 certifi==2023.5.7
 charset-normalizer==3.1.0
 click==8.0.4
 filelock==3.12.0
 frozenlist==1.3.3
 fsspec[http]==2023.5.0
-google-auth==2.19.0
-google-auth-oauthlib==0.4.6
-grpcio==1.43.0
-huggingface-hub==0.14.1
+huggingface-hub==0.15.1
 idna==3.4
-importlib-metadata==6.6.0
 joblib==1.2.0
-markdown==3.4.3
-markupsafe==2.1.2
+lightning-utilities==0.8.0
 multidict==6.0.4
 nltk==3.8.1
 numpy==1.24.3
-oauthlib==3.2.2
 packaging==23.1
 pillow==9.5.0
-protobuf==3.20.1
-pyasn1==0.5.0
-pyasn1-modules==0.3.0
-pydeprecate==0.3.2
-pytorch-lightning==1.6.5
+pytorch-lightning==1.9.5
 pyyaml==6.0
 regex==2023.5.5
 requests==2.31.0
-requests-oauthlib==1.3.1
-rsa==4.9
-six==1.16.0
-tensorboard==2.11.2
-tensorboard-data-server==0.6.1
-tensorboard-plugin-wit==1.8.1
-timm==0.6.13
+safetensors==0.3.1
+timm==0.9.2
 torch==1.13.1+cpu
 torchmetrics==0.11.4
 torchvision==0.14.1+cpu
 tqdm==4.65.0
 typing-extensions==4.6.2
-urllib3==1.26.16
-werkzeug==2.3.4
-wheel==0.40.0
+urllib3==2.0.2
 yarl==1.9.2
-zipp==3.15.0
-
-# The following packages are considered to be unsafe in a requirements file:
-setuptools==67.8.0
diff --git a/requirements/test.in b/requirements/test.in
index a237489a..b4e843e1 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -1,4 +1,5 @@
 -c ${CONSTRAINTS}
-lmdb
-pillow
-tqdm
+
+lmdb >=1.3.0
+Pillow >=9.2.0
+tqdm >=4.64.0
diff --git a/requirements/train.in b/requirements/train.in
index ca43d32e..41ad6d26 100644
--- a/requirements/train.in
+++ b/requirements/train.in
@@ -1,5 +1,6 @@
 -c ${CONSTRAINTS}
-lmdb
-pillow
-imgaug
-hydra-core~=1.2.0
+
+lmdb >=1.3.0
+Pillow >=9.2.0
+imgaug >=0.4.0
+hydra-core >=1.2.0
diff --git a/requirements/train.txt b/requirements/train.txt
index 5cbc8505..e68109e8 100644
--- a/requirements/train.txt
+++ b/requirements/train.txt
@@ -2,7 +2,7 @@ antlr4-python3-runtime==4.9.3
 contourpy==1.0.7
 cycler==0.11.0
 fonttools==4.39.4
-hydra-core==1.2.0
+hydra-core==1.3.2
 imageio==2.30.0
 imgaug==0.4.0
 importlib-resources==5.12.0
diff --git a/requirements/tune.in b/requirements/tune.in
index e5a71afa..80a0ec07 100644
--- a/requirements/tune.in
+++ b/requirements/tune.in
@@ -1,7 +1,8 @@
 -c ${CONSTRAINTS}
-lmdb
-pillow
-imgaug
-hydra-core~=1.2.0
-ray[tune]~=1.13.0
-ax-platform
+
+lmdb >=1.3.0
+Pillow >=9.2.0
+imgaug >=0.4.0
+hydra-core >=1.2.0
+ray[tune] >=1.13.0, <2.0.0
+ax-platform >=0.2.5.1
diff --git a/requirements/tune.txt b/requirements/tune.txt
index 93be91a6..6f183ab0 100644
--- a/requirements/tune.txt
+++ b/requirements/tune.txt
@@ -20,7 +20,7 @@ fonttools==4.39.4
 frozenlist==1.3.3
 gpytorch==1.10
 grpcio==1.43.0
-hydra-core==1.2.0
+hydra-core==1.3.2
 idna==3.4
 imageio==2.30.0
 imgaug==0.4.0
@@ -61,7 +61,7 @@ pkgutil-resolve-name==1.3.10
 platformdirs==3.5.1
 plotly==5.14.1
 prompt-toolkit==3.0.38
-protobuf==3.20.1
+protobuf==3.20.3
 psutil==5.9.5
 ptyprocess==0.7.0
 pure-eval==0.2.2
@@ -94,7 +94,7 @@ traitlets==5.9.0
 typeguard==2.13.3
 typing-extensions==4.6.2
 tzdata==2023.3
-urllib3==1.26.16
+urllib3==2.0.2
 virtualenv==20.23.0
 wcwidth==0.2.6
 widgetsnbextension==4.0.7
diff --git a/strhub/data/augment.py b/strhub/data/augment.py
index e5c1fb5a..21dc9d29 100644
--- a/strhub/data/augment.py
+++ b/strhub/data/augment.py
@@ -105,7 +105,7 @@ def rand_augment_transform(magnitude=5, num_layers=3):
         'translate_x_pct': 0.10,
         'translate_y_pct': 0.30
     }
-    ra_ops = auto_augment.rand_augment_ops(magnitude, hparams, transforms=_RAND_TRANSFORMS)
+    ra_ops = auto_augment.rand_augment_ops(magnitude, hparams=hparams, transforms=_RAND_TRANSFORMS)
     # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice)
     choice_weights = [1. / len(ra_ops) for _ in range(len(ra_ops))]
     return auto_augment.RandAugment(ra_ops, num_layers, choice_weights)
diff --git a/train.py b/train.py
index d97622b6..2b8e904d 100755
--- a/train.py
+++ b/train.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import math
 from pathlib import Path
 
 from omegaconf import DictConfig, open_dict
@@ -31,6 +31,23 @@ from strhub.models.utils import get_pretrained_weights
 
 
+# Copied from OneCycleLR
+def _annealing_cos(start, end, pct):
+    "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+    cos_out = math.cos(math.pi * pct) + 1
+    return end + (start - end) / 2.0 * cos_out
+
+
+def get_swa_lr_factor(warmup_pct, swa_epoch_start, div_factor=25, final_div_factor=1e4) -> float:
+    """Get the SWA LR factor for the given `swa_epoch_start`. Assumes OneCycleLR Scheduler."""
+    total_steps = 1000  # Can be anything. We use 1000 for convenience.
+    start_step = int(total_steps * warmup_pct) - 1
+    end_step = total_steps - 1
+    step_num = int(total_steps * swa_epoch_start) - 1
+    pct = (step_num - start_step) / (end_step - start_step)
+    return _annealing_cos(1, 1 / (div_factor * final_div_factor), pct)
+
+
 @hydra.main(config_path='configs', config_name='main', version_base='1.2')
 def main(config: DictConfig):
     trainer_strategy = None
@@ -38,19 +55,20 @@ def main(config: DictConfig):
         # Resolve absolute path to data.root_dir
         config.data.root_dir = hydra.utils.to_absolute_path(config.data.root_dir)
         # Special handling for GPU-affected config
-        gpus = config.trainer.get('gpus', 0)
-        if gpus:
+        gpu = config.trainer.get('accelerator') == 'gpu'
+        devices = config.trainer.get('devices', 0)
+        if gpu:
             # Use mixed-precision training
             config.trainer.precision = 16
-        if gpus > 1:
+        if gpu and devices > 1:
             # Use DDP
             config.trainer.strategy = 'ddp'
             # DDP optimizations
             trainer_strategy = DDPStrategy(find_unused_parameters=False, gradient_as_bucket_view=True)
             # Scale steps-based config
-            config.trainer.val_check_interval //= gpus
+            config.trainer.val_check_interval //= devices
             if config.trainer.get('max_steps', -1) > 0:
-                config.trainer.max_steps //= gpus
+                config.trainer.max_steps //= devices
 
     # Special handling for PARseq
     if config.model.get('perm_mirrored', False):
@@ -66,7 +84,9 @@ def main(config: DictConfig):
     checkpoint = ModelCheckpoint(monitor='val_accuracy', mode='max', save_top_k=3, save_last=True,
                                  filename='{epoch}-{step}-{val_accuracy:.4f}-{val_NED:.4f}')
-    swa = StochasticWeightAveraging(swa_epoch_start=0.75)
+    swa_epoch_start = 0.75
+    swa_lr = config.model.lr * get_swa_lr_factor(config.model.warmup_pct, swa_epoch_start)
+    swa = StochasticWeightAveraging(swa_lr, swa_epoch_start)
     cwd = HydraConfig.get().runtime.output_dir if config.ckpt_path is None else \
         str(Path(config.ckpt_path).parents[1].absolute())
     trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=TensorBoardLogger(cwd, '', '.'),
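
A note on the `get_swa_lr_factor` helper introduced in train.py above: it evaluates the same cosine-annealing curve that OneCycleLR follows, at the fraction of training where SWA begins, so `swa_lr = config.model.lr * factor` hands `StochasticWeightAveraging` the learning rate the scheduler would have reached by that point instead of Lightning's default. The sketch below reproduces that computation outside the training script; the `warmup_pct=0.075` and base LR of `7e-4` are assumed values for illustration only and are not taken from this diff.

```python
import math


def swa_lr_factor(warmup_pct, swa_epoch_start, div_factor=25, final_div_factor=1e4):
    """Fraction of the max LR that OneCycleLR would be at when SWA starts (same math as train.py)."""
    total_steps = 1000  # arbitrary resolution; only the ratio matters
    start_step = int(total_steps * warmup_pct) - 1     # end of warmup = start of annealing
    end_step = total_steps - 1                         # end of annealing
    step_num = int(total_steps * swa_epoch_start) - 1  # step at which SWA kicks in
    pct = (step_num - start_step) / (end_step - start_step)
    # Cosine-anneal from 1.0 down to 1 / (div_factor * final_div_factor), as OneCycleLR does.
    cos_out = math.cos(math.pi * pct) + 1
    end = 1 / (div_factor * final_div_factor)
    return end + (1 - end) / 2.0 * cos_out


# Assumed example values (hypothetical, not from the diff): 7.5% warmup, SWA starting at 75% of training.
factor = swa_lr_factor(warmup_pct=0.075, swa_epoch_start=0.75)
swa_lr = 7e-4 * factor  # with an assumed base LR of 7e-4
print(f"factor ~ {factor:.3f}, swa_lr ~ {swa_lr:.2e}")  # factor comes out around 0.17
```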