Remove old interface and deprecate the arguments
JQGoh committed Dec 11, 2024
1 parent c0a7eb8 commit 3924876
Showing 73 changed files with 75 additions and 1,076 deletions.
88 changes: 38 additions & 50 deletions nbs/common.base_model.ipynb
@@ -27,14 +27,24 @@
"execution_count": null,
"id": "1c7c2ba5-19ee-421e-9252-7224b03f5201",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda3/envs/neuralforecast/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"2024-12-11 17:06:11,409\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
"2024-12-11 17:06:11,467\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n"
]
}
],
"source": [
"#| export\n",
"import inspect\n",
"import random\n",
"import warnings\n",
"from contextlib import contextmanager\n",
"from copy import deepcopy\n",
"from dataclasses import dataclass\n",
"\n",
"import fsspec\n",
@@ -121,10 +131,6 @@
" random_seed,\n",
" loss,\n",
" valid_loss,\n",
" optimizer,\n",
" optimizer_kwargs,\n",
" lr_scheduler,\n",
" lr_scheduler_kwargs,\n",
" futr_exog_list,\n",
" hist_exog_list,\n",
" stat_exog_list,\n",
@@ -150,18 +156,6 @@
" self.train_trajectories = []\n",
" self.valid_trajectories = []\n",
"\n",
" # Optimization\n",
" if optimizer is not None and not issubclass(optimizer, torch.optim.Optimizer):\n",
" raise TypeError(\"optimizer is not a valid subclass of torch.optim.Optimizer\")\n",
" self.optimizer = optimizer\n",
" self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}\n",
"\n",
" # lr scheduler\n",
" if lr_scheduler is not None and not issubclass(lr_scheduler, torch.optim.lr_scheduler.LRScheduler):\n",
" raise TypeError(\"lr_scheduler is not a valid subclass of torch.optim.lr_scheduler.LRScheduler\")\n",
" self.lr_scheduler = lr_scheduler\n",
" self.lr_scheduler_kwargs = lr_scheduler_kwargs if lr_scheduler_kwargs is not None else {}\n",
"\n",
" # customized by set_configure_optimizers()\n",
" self.config_optimizers = None\n",
"\n",
@@ -412,41 +406,19 @@
"\n",
" def configure_optimizers(self):\n",
" if self.config_optimizers is not None:\n",
" # return the customized optimizer settings if specified\n",
" return self.config_optimizers\n",
" \n",
" if self.optimizer:\n",
" optimizer_signature = inspect.signature(self.optimizer)\n",
" optimizer_kwargs = deepcopy(self.optimizer_kwargs)\n",
" if 'lr' in optimizer_signature.parameters:\n",
" if 'lr' in optimizer_kwargs:\n",
" warnings.warn(\"ignoring learning rate passed in optimizer_kwargs, using the model's learning rate\")\n",
" optimizer_kwargs['lr'] = self.learning_rate\n",
" optimizer = self.optimizer(params=self.parameters(), **optimizer_kwargs)\n",
" else:\n",
" if self.optimizer_kwargs:\n",
" warnings.warn(\n",
" \"ignoring optimizer_kwargs as the optimizer is not specified\"\n",
" )\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" \n",
" lr_scheduler = {'frequency': 1, 'interval': 'step'}\n",
" if self.lr_scheduler:\n",
" lr_scheduler_signature = inspect.signature(self.lr_scheduler)\n",
" lr_scheduler_kwargs = deepcopy(self.lr_scheduler_kwargs)\n",
" if 'optimizer' in lr_scheduler_signature.parameters:\n",
" if 'optimizer' in lr_scheduler_kwargs:\n",
" warnings.warn(\"ignoring optimizer passed in lr_scheduler_kwargs, using the model's optimizer\")\n",
" del lr_scheduler_kwargs['optimizer']\n",
" lr_scheduler['scheduler'] = self.lr_scheduler(optimizer=optimizer, **lr_scheduler_kwargs)\n",
" else:\n",
" if self.lr_scheduler_kwargs:\n",
" warnings.warn(\n",
" \"ignoring lr_scheduler_kwargs as the lr_scheduler is not specified\"\n",
" ) \n",
" lr_scheduler['scheduler'] = torch.optim.lr_scheduler.StepLR(\n",
" # default choice\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
" scheduler = {\n",
" \"scheduler\": torch.optim.lr_scheduler.StepLR(\n",
" optimizer=optimizer, step_size=self.lr_decay_steps, gamma=0.5\n",
" )\n",
" return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}\n",
" ),\n",
" \"frequency\": 1,\n",
" \"interval\": \"step\",\n",
" }\n",
" return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler}\n",
"\n",
" def set_configure_optimizers(\n",
" self, \n",
@@ -528,6 +500,22 @@
" model.load_state_dict(content[\"state_dict\"], strict=True)\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "077ea025",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b36e87a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
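With the optimizer, optimizer_kwargs, lr_scheduler, and lr_scheduler_kwargs arguments gone, configure_optimizers no longer inspects user-supplied classes: unless set_configure_optimizers() has stored a custom configuration in self.config_optimizers, it falls back to Adam plus a step-wise halving schedule, as the hunk above shows. The sketch below reproduces that default outside the class for illustration only; the toy module and the learning_rate and lr_decay_steps values are stand-ins for the corresponding BaseModel attributes, not the library's code.

import torch

# Hypothetical stand-ins for BaseModel attributes, for illustration only.
module = torch.nn.Linear(8, 1)
learning_rate = 1e-3
lr_decay_steps = 100

# Default choice after this commit: Adam over the model parameters.
optimizer = torch.optim.Adam(module.parameters(), lr=learning_rate)

# StepLR halves the learning rate every `lr_decay_steps` optimizer steps;
# "interval": "step" tells PyTorch Lightning to advance the scheduler per step.
lr_scheduler = {
    "scheduler": torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer, step_size=lr_decay_steps, gamma=0.5
    ),
    "frequency": 1,
    "interval": "step",
}

config = {"optimizer": optimizer, "lr_scheduler": lr_scheduler}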
10 changes: 1 addition & 9 deletions nbs/common.base_multivariate.ipynb
@@ -105,20 +105,12 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs, \n",
" valid_loss=valid_loss, \n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
8 changes: 0 additions & 8 deletions nbs/common.base_recurrent.ipynb
@@ -111,20 +111,12 @@
" drop_last_loader=False,\n",
" random_seed=1, \n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
8 changes: 0 additions & 8 deletions nbs/common.base_windows.ipynb
@@ -115,20 +115,12 @@
" drop_last_loader=False,\n",
" random_seed=1,\n",
" alias=None,\n",
" optimizer=None,\n",
" optimizer_kwargs=None,\n",
" lr_scheduler=None,\n",
" lr_scheduler_kwargs=None,\n",
" dataloader_kwargs=None,\n",
" **trainer_kwargs):\n",
" super().__init__(\n",
" random_seed=random_seed,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" optimizer=optimizer,\n",
" optimizer_kwargs=optimizer_kwargs,\n",
" lr_scheduler=lr_scheduler,\n",
" lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
" futr_exog_list=futr_exog_list,\n",
" hist_exog_list=hist_exog_list,\n",
" stat_exog_list=stat_exog_list,\n",
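The multivariate, recurrent, and windows base classes follow the same pattern: the four optimizer-related keyword arguments are dropped from their signatures and from the super().__init__() call, so learning-rate handling now lives entirely in BaseModel.configure_optimizers (or in a custom configuration registered through set_configure_optimizers()). A minimal sketch of the simplified constructor chaining, using hypothetical class names and only a few of the real arguments:

class BaseModelSketch:
    # Hypothetical stand-in for BaseModel after this commit: no optimizer,
    # optimizer_kwargs, lr_scheduler, or lr_scheduler_kwargs parameters.
    def __init__(self, loss=None, valid_loss=None, random_seed=1,
                 dataloader_kwargs=None, **trainer_kwargs):
        self.loss = loss
        self.valid_loss = valid_loss
        self.random_seed = random_seed
        self.dataloader_kwargs = dataloader_kwargs
        self.trainer_kwargs = trainer_kwargs
        # Set by set_configure_optimizers(); None means "use the default".
        self.config_optimizers = None


class BaseWindowsSketch(BaseModelSketch):
    # Subclasses no longer accept or forward the removed keyword arguments;
    # the real classes pass many more arguments (exogenous lists, scalers, ...).
    def __init__(self, loss=None, valid_loss=None, random_seed=1,
                 dataloader_kwargs=None, **trainer_kwargs):
        super().__init__(
            loss=loss,
            valid_loss=valid_loss,
            random_seed=random_seed,
            dataloader_kwargs=dataloader_kwargs,
            **trainer_kwargs,
        )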
