diff --git a/tutorials/early_stopping/early_stopping.ipynb b/tutorials/early_stopping/early_stopping.ipynb index 8447b49c1cc..0ed7a462a7a 100644 --- a/tutorials/early_stopping/early_stopping.ipynb +++ b/tutorials/early_stopping/early_stopping.ipynb @@ -1,666 +1,669 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "id": "12fe3797", - "metadata": {}, - "source": [ - "## Trial-level early stopping in Ax\n", - "\n", - "This tutorial illustrates how to add a trial-level early stopping strategy to an Ax hyper-parameter optimization (HPO) loop. The goal of trial-level early stopping is to monitor the results of expensive evaluations and terminate those that are unlikely to produce promising results, freeing up resources to explore more configurations.\n", - "\n", - "Most of this tutorial is adapted from the [PyTorch Ax Multiobjective NAS Tutorial](https://pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html). The training job is different from the original in that we do not optimize `batch_size` or `epochs`. This was done for illustrative purposes, as each validation curve now has the same number of points. The companion training file `mnist_train_nas.py` has also been altered to log to Tensorboard during training.\n", - "\n", - "NOTE: Although the original NAS tutorial is for a multi-objective problem, this tutorial focuses on a single objective (validation accuracy) problem. Early stopping currently does not support \\\"true\\\" multi-objective stopping, although one can use [logical compositions of early stopping strategies](https://github.com/facebook/Ax/blob/main/ax/early_stopping/strategies/logical.py) to target multiple objectives separately. Early stopping for the multi-objective case is currently a work in progress." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb953f30", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import tempfile\n", - "\n", - "from pathlib import Path\n", - "\n", - "import torchx\n", - "\n", - "from ax.core import Experiment, Objective, ParameterType, RangeParameter, SearchSpace\n", - "from ax.core.optimization_config import OptimizationConfig\n", - "\n", - "from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy\n", - "from ax.metrics.tensorboard import TensorboardCurveMetric\n", - "\n", - "from ax.modelbridge.dispatch_utils import choose_generation_strategy\n", - "\n", - "from ax.runners.torchx import TorchXRunner\n", - "\n", - "from ax.service.scheduler import Scheduler, SchedulerOptions\n", - "from ax.service.utils.report_utils import exp_to_df\n", - "\n", - "from torchx import specs\n", - "from torchx.components import utils\n", - "\n", - "from matplotlib import pyplot as plt\n", - "\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8a7bd328", - "metadata": {}, - "outputs": [], - "source": [ - "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "fe2cf6fe", - "metadata": {}, - "source": [ - "## Defining the TorchX App\n", - "\n", - "Our goal is to optimize the PyTorch Lightning training job defined in\n", - "[mnist_train_nas.py](https://github.com/pytorch/tutorials/tree/master/intermediate_source/mnist_train_nas.py)_.\n", - "To do this using TorchX, we write a helper function that takes in\n", - "the values of the architcture and hyperparameters of the training\n", - "job and creates a [TorchX AppDef](https://pytorch.org/torchx/latest/basics.html)_\n", - "with the appropriate settings.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e21d309", - "metadata": {}, - "outputs": [], - "source": [ - "if SMOKE_TEST:\n", - " epochs = 3\n", - "else:\n", - " epochs = 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b423923c", - "metadata": {}, - "outputs": [], - "source": [ - "def trainer(\n", - " log_path: str,\n", - " hidden_size_1: int,\n", - " hidden_size_2: int,\n", - " learning_rate: float,\n", - " dropout: float,\n", - " trial_idx: int = -1,\n", - ") -> specs.AppDef:\n", - "\n", - " # define the log path so we can pass it to the TorchX AppDef\n", - " if trial_idx >= 0:\n", - " log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()\n", - "\n", - " batch_size = 32\n", - "\n", - " return utils.python(\n", - " # command line args to the training script\n", - " \"--log_path\",\n", - " log_path,\n", - " \"--hidden_size_1\",\n", - " str(hidden_size_1),\n", - " \"--hidden_size_2\",\n", - " str(hidden_size_2),\n", - " \"--learning_rate\",\n", - " str(learning_rate),\n", - " \"--epochs\",\n", - " str(epochs),\n", - " \"--dropout\",\n", - " str(dropout),\n", - " \"--batch_size\",\n", - " str(batch_size),\n", - " # other config options\n", - " name=\"trainer\",\n", - " script=\"tutorials/early_stopping/mnist_train_nas.py\",\n", - " image=torchx.version.TORCHX_IMAGE,\n", - " )" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "65f7011d", - "metadata": {}, - "source": [ - "## Setting up the Runner\n", - "\n", - "Ax’s [Runner](https://ax.dev/api/core.html#ax.core.runner.Runner)\n", - "abstraction allows writing interfaces to various backends.\n", - "Ax already comes with Runner for TorchX, so we just need to\n", - 
"configure it. For the purpose of this tutorial, we run jobs locally\n", - "in a fully asynchronous fashion. In order to launch them on a cluster, you can instead specify a\n", - "different TorchX scheduler and adjust the configuration appropriately.\n", - "For example, if you have a Kubernetes cluster, you just need to change the\n", - "scheduler from ``local_cwd`` to ``kubernetes``.\n", - "\n", - "The training job launched by this runner will log partial results to Tensorboard, which will then be monitored by the early stopping strategy. We will show how this is done using an Ax \n", - "[TensorboardCurveMetric](https://ax.dev/api/metrics.html#module-ax.metrics.tensorboard) below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "998e5835", - "metadata": {}, - "outputs": [], - "source": [ - "# Make a temporary dir to log our results into\n", - "log_dir = tempfile.mkdtemp()\n", - "\n", - "ax_runner = TorchXRunner(\n", - " tracker_base=\"/tmp/\",\n", - " component=trainer,\n", - " # NOTE: To launch this job on a cluster instead of locally you can\n", - " # specify a different scheduler and adjust args appropriately.\n", - " scheduler=\"local_cwd\",\n", - " component_const_params={\"log_path\": log_dir},\n", - " cfg={},\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "2fec7495", - "metadata": {}, - "source": [ - "## Setting up the SearchSpace\n", - "\n", - "First, we define our search space. Ax supports both range parameters\n", - "of type integer and float as well as choice parameters which can have\n", - "non-numerical types such as strings.\n", - "We will tune the hidden sizes, learning rate, and dropout parameters." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cf6f869f", - "metadata": {}, - "outputs": [], - "source": [ - "parameters = [\n", - " # NOTE: In a real-world setting, hidden_size_1 and hidden_size_2\n", - " # should probably be powers of 2, but in our simple example this\n", - " # would mean that num_params can't take on that many values, which\n", - " # in turn makes the Pareto frontier look pretty weird.\n", - " RangeParameter(\n", - " name=\"hidden_size_1\",\n", - " lower=16,\n", - " upper=128,\n", - " parameter_type=ParameterType.INT,\n", - " log_scale=True,\n", - " ),\n", - " RangeParameter(\n", - " name=\"hidden_size_2\",\n", - " lower=16,\n", - " upper=128,\n", - " parameter_type=ParameterType.INT,\n", - " log_scale=True,\n", - " ),\n", - " RangeParameter(\n", - " name=\"learning_rate\",\n", - " lower=1e-4,\n", - " upper=1e-2,\n", - " parameter_type=ParameterType.FLOAT,\n", - " log_scale=True,\n", - " ),\n", - " RangeParameter(\n", - " name=\"dropout\",\n", - " lower=0.0,\n", - " upper=0.5,\n", - " parameter_type=ParameterType.FLOAT,\n", - " ),\n", - "]\n", - "\n", - "search_space = SearchSpace(\n", - " parameters=parameters,\n", - " # NOTE: In practice, it may make sense to add a constraint\n", - " # hidden_size_2 <= hidden_size_1\n", - " parameter_constraints=[],\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "a8005e80", - "metadata": {}, - "source": [ - "## Setting up Metrics\n", - "\n", - "Ax has the concept of a Metric that defines properties of outcomes and how observations are obtained for these outcomes. This allows e.g. 
encodig how data is fetched from some distributed execution backend and post-processed before being passed as input to Ax.\n", - "\n", - "We will optimize the validation accuracy, which is a `TensorboardCurveMetric` that points to the logging directory assigned above. Note that we have set `is_available_while_running`, allowing for the metric to be queried as the trial progresses. This is critical for the early stopping strategy to monitor partial results." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0775a96e", - "metadata": {}, - "outputs": [], - "source": [ - "class MyTensorboardMetric(TensorboardCurveMetric):\n", - "\n", - " # NOTE: We need to tell the new Tensorboard metric how to get the id /\n", - " # file handle for the tensorboard logs from a trial. In this case\n", - " # our convention is to just save a separate file per trial in\n", - " # the pre-specified log dir.\n", - " @classmethod\n", - " def get_ids_from_trials(cls, trials):\n", - " return {\n", - " trial.index: Path(log_dir).joinpath(str(trial.index)).as_posix()\n", - " for trial in trials\n", - " }\n", - "\n", - " # This indicates whether the metric is queryable while the trial is\n", - " # still running. This is required for early stopping to monitor the\n", - " # progress of the running trial.ArithmeticError\n", - " @classmethod\n", - " def is_available_while_running(cls):\n", - " return True" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a5c5a7d0", - "metadata": {}, - "outputs": [], - "source": [ - "val_acc = MyTensorboardMetric(\n", - " name=\"val_acc\",\n", - " curve_name=\"val_acc\",\n", - " lower_is_better=False,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "d4f3ba5d", - "metadata": {}, - "source": [ - "## Setting up the OptimizationConfig\n", - "\n", - "The `OptimizationConfig` specifies the objective for Ax to optimize." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ada66cf3", - "metadata": {}, - "outputs": [], - "source": [ - "opt_config = OptimizationConfig(\n", - " objective=Objective(\n", - " metric=val_acc,\n", - " minimize=False,\n", - " )\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "57aa9cf7", - "metadata": {}, - "source": [ - "## Defining an Early Stopping Strategy\n", - "\n", - "A `PercentileEarlyStoppingStrategy` is a simple method that stops a trial if its performance falls below a certain percentile of other trials at the same step (e.g., when `percentile_threshold` is 50, at a given point in time, if a trial ranks in the bottom 50% of trials, it is stopped). \n", - "- We make use of `normalize_progressions` which normalizes the progression column (e.g. timestamp, epochs, training data used) to be in [0, 1]. This is useful because one doesn't need to know the maximum progression values of the curve (which might be, e.g., the total number of data points in the training dataset).\n", - "- The `min_progression` parameter specifies that trials should only be considered for stopping if the latest progression value is greater than this threshold.\n", - "- The `min_curves` parameter specifies the minimum number of completed curves (i.e., fully completed training jobs) before early stopping will be considered. This should be larger than zero if `normalize_progression` is used. 
In general, we want a few completed curves to have a baseline for comparison.\n", - "\n", - "Note that `PercentileEarlyStoppingStrategy` does not make use of learning curve modeling or prediction. More sophisticated model-based methods will be available in future versions of Ax." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "949e8ab5", - "metadata": {}, - "outputs": [], - "source": [ - "percentile_early_stopping_strategy = PercentileEarlyStoppingStrategy(\n", - " # stop if in bottom 70% of runs at the same progression\n", - " percentile_threshold=70,\n", - " # the trial must have passed `min_progression` steps before early stopping is initiated\n", - " # note that we are using `normalize_progressions`, so this is on a scale of [0, 1]\n", - " min_progression=0.3,\n", - " # there must be `min_curves` completed trials and `min_curves` trials reporting data in\n", - " # order for early stopping to be applicable\n", - " min_curves=5,\n", - " # specify, e.g., [0, 1] if the first two trials should never be stopped\n", - " trial_indices_to_ignore=None,\n", - " # check for new data every 10 seconds\n", - " seconds_between_polls=10,\n", - " normalize_progressions=True,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "2665ca93", - "metadata": {}, - "source": [ - "## Creating the Ax Experiment\n", - "\n", - "In Ax, the Experiment object is the object that stores all the information about the problem setup." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12849b31", - "metadata": {}, - "outputs": [], - "source": [ - "experiment = Experiment(\n", - " name=\"torchx_mnist\",\n", - " search_space=search_space,\n", - " optimization_config=opt_config,\n", - " runner=ax_runner,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "49a4ed0e", - "metadata": {}, - "source": [ - "## Choosing the GenerationStrategy\n", - "\n", - "A [GenerationStrategy](https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStrategy)\n", - "is the abstract representation of how we would like to perform the\n", - "optimization. While this can be customized (if you’d like to do so, see\n", - "[this tutorial](https://ax.dev/tutorials/generation_strategy.html)),\n", - "in most cases Ax can automatically determine an appropriate strategy\n", - "based on the search space, optimization config, and the total number\n", - "of trials we want to run.\n", - "\n", - "Typically, Ax chooses to evaluate a number of random configurations\n", - "before starting a model-based Bayesian Optimization strategy.\n", - "\n", - "We remark that in Ax, generation strategies and early stopping strategies are separate, a design decision motivated by ease-of-use. However, we should acknowledge that jointly considering generation and stopping using a single strategy would likely be the \"proper\" formulation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e38d0237", - "metadata": {}, - "outputs": [], - "source": [ - "if SMOKE_TEST:\n", - " total_trials = 6\n", - "else:\n", - " total_trials = 15 # total evaluation budget\n", - "\n", - "gs = choose_generation_strategy(\n", - " search_space=experiment.search_space,\n", - " optimization_config=experiment.optimization_config,\n", - " num_trials=total_trials,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "78d86fea", - "metadata": {}, - "source": [ - "## Configuring the Scheduler\n", - "\n", - "The `Scheduler` acts as the loop control for the optimization.\n", - "It communicates with the backend to launch trials, check their status, retrieve (partial) results, and importantly for this tutorial, calls the early stopping strategy. If the early stopping strategy suggests a trial to be the stopped, the `Scheduler` communicates with the backend to terminate the trial.\n", - "\n", - "The ``Scheduler`` requires the ``Experiment`` and the ``GenerationStrategy``.\n", - "A set of options can be passed in via ``SchedulerOptions``. Here, we\n", - "configure the number of total evaluations as well as ``max_pending_trials``,\n", - "the maximum number of trials that should run concurrently. In our\n", - "local setting, this is the number of training jobs running as individual\n", - "processes, while in a remote execution setting, this would be the number\n", - "of machines you want to use in parallel.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "499fb9b5", - "metadata": {}, - "outputs": [], - "source": [ - "scheduler = Scheduler(\n", - " experiment=experiment,\n", - " generation_strategy=gs,\n", - " options=SchedulerOptions(\n", - " total_trials=total_trials,\n", - " max_pending_trials=5,\n", - " early_stopping_strategy=percentile_early_stopping_strategy,\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "78257ebb", - "metadata": {}, - "outputs": [], - "source": [ - "%%time\n", - "scheduler.run_all_trials()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "8c5afbe8", - "metadata": {}, - "source": [ - "## Results\n", - "\n", - "First, we examine the data stored on the experiment. This shows that each trial is associated with an entire learning curve, represented by the column \"steps\"." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "764365f0", - "metadata": {}, - "outputs": [], - "source": [ - "experiment.lookup_data().map_df.head(n=10)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "0033ed2e", - "metadata": {}, - "source": [ - "Below is a summary of the experiment, showing that a portion of trials have been early stopped." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "00f2b35f", - "metadata": {}, - "outputs": [], - "source": [ - "exp_to_df(experiment)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "f8909cf2", - "metadata": {}, - "source": [ - "We can give a very rough estimate of the amount of computational savings due to early stopping, by looking at the total number of steps used when early stopping is used versus the number of steps used if we ran all trials to completion. Note to do a true comparison, one should run full HPO loops with and without early stopping (as early stopping will influence the model and future points selected by the generation strategy). 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5abb3ce8", - "metadata": {}, - "outputs": [], - "source": [ - "map_df = experiment.lookup_data().map_df\n", - "trial_to_max_steps = map_df.groupby(\"trial_index\")[\"steps\"].max()\n", - "completed_trial_steps = trial_to_max_steps.iloc[0]\n", - "savings = 1.0 - trial_to_max_steps.sum() / (\n", - " completed_trial_steps * len(trial_to_max_steps)\n", - ")\n", - "# TODO format nicer\n", - "print(f\"A rough estimate of the computational savings is {100 * savings}%.\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "37df6964", - "metadata": {}, - "source": [ - "## Visualizations\n", - "\n", - "Finally, we show a visualization of learning curves versus actual elapsed wall time. This helps to illustrate that stopped trials make room for additional trials to be run." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c88cb8d0", - "metadata": {}, - "outputs": [], - "source": [ - "# helper function for getting trial start times\n", - "def time_started(row):\n", - " trial_index = row[\"trial_index\"]\n", - " return experiment.trials[trial_index].time_run_started\n", - "\n", - "\n", - "# helper function for getting trial completion times\n", - "def time_completed(row):\n", - " trial_index = row[\"trial_index\"]\n", - " return experiment.trials[trial_index].time_completed\n", - "\n", - "\n", - "# helper function for getting relevant data from experiment\n", - "# with early stopping into useful dfs\n", - "def early_stopping_exp_to_df(experiment):\n", - " trials_df = exp_to_df(experiment)\n", - " curve_df = experiment.lookup_data().map_df\n", - " training_row_df = (\n", - " curve_df.groupby(\"trial_index\").max().reset_index()[[\"trial_index\", \"steps\"]]\n", - " )\n", - " trials_df = trials_df.merge(training_row_df, on=\"trial_index\")\n", - " trials_df[\"time_started\"] = trials_df.apply(func=time_started, axis=1)\n", - " trials_df[\"time_completed\"] = trials_df.apply(func=time_completed, axis=1)\n", - " start_time = trials_df[\"time_started\"].min()\n", - " trials_df[\"time_started_rel\"] = (\n", - " trials_df[\"time_started\"] - start_time\n", - " ).dt.total_seconds()\n", - " trials_df[\"time_completed_rel\"] = (\n", - " trials_df[\"time_completed\"] - start_time\n", - " ).dt.total_seconds()\n", - " return trials_df, curve_df\n", - "\n", - "\n", - "def plot_curves_by_wall_time(trials_df, curve_df):\n", - " trials = set(curve_df[\"trial_index\"])\n", - " fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n", - " ax.set(xlabel=\"seconds since start\", ylabel=\"validation accuracy\")\n", - " for trial_index in trials:\n", - " this_trial_df = curve_df[curve_df[\"trial_index\"] == trial_index]\n", - " start_time_rel = trials_df[\"time_started_rel\"].iloc[trial_index]\n", - " completed_time_rel = trials_df[\"time_completed_rel\"].iloc[trial_index]\n", - " total_steps = trials_df.loc[trial_index, \"steps\"]\n", - " smoothed_curve = this_trial_df[\"mean\"].rolling(window=3).mean()\n", - " x = (\n", - " start_time_rel\n", - " + (completed_time_rel - start_time_rel)\n", - " / total_steps\n", - " * this_trial_df[\"steps\"]\n", - " )\n", - " ax.plot(\n", - " x,\n", - " smoothed_curve,\n", - " label=f\"trial #{trial_index}\" if trial_index % 2 == 1 else None,\n", - " )\n", - " ax.legend()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d7f52fed", - "metadata": {}, - "outputs": [], - "source": [ - "# wrap in try/except in case of flaky I/O issues\n", - "try:\n", - " trials_df, 
curve_df = early_stopping_exp_to_df(experiment)\n", - " plot_curves_by_wall_time(trials_df, curve_df)\n", - "except Exception as e:\n", - " print(f\"Encountered exception while plotting results: {e}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "193e2fc7", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "12fe3797", + "metadata": {}, + "source": [ + "## Trial-level early stopping in Ax\n", + "\n", + "This tutorial illustrates how to add a trial-level early stopping strategy to an Ax hyper-parameter optimization (HPO) loop. The goal of trial-level early stopping is to monitor the results of expensive evaluations and terminate those that are unlikely to produce promising results, freeing up resources to explore more configurations.\n", + "\n", + "Most of this tutorial is adapted from the [PyTorch Ax Multiobjective NAS Tutorial](https://pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html). The training job is different from the original in that we do not optimize `batch_size` or `epochs`. This was done for illustrative purposes, as each validation curve now has the same number of points. The companion training file `mnist_train_nas.py` has also been altered to log to Tensorboard during training.\n", + "\n", + "NOTE: Although the original NAS tutorial is for a multi-objective problem, this tutorial focuses on a single objective (validation accuracy) problem. Early stopping currently does not support \\\"true\\\" multi-objective stopping, although one can use [logical compositions of early stopping strategies](https://github.com/facebook/Ax/blob/main/ax/early_stopping/strategies/logical.py) to target multiple objectives separately. Early stopping for the multi-objective case is currently a work in progress." 
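Editor's note on the multi-objective remark above: for readers who do want early stopping to consider more than one metric, the sketch below shows roughly how two single-metric strategies could be combined via the logical composition module linked in the intro. This is a hedged illustration only; the `AndEarlyStoppingStrategy` class name, its `left`/`right` constructor arguments, and the second metric name (`val_loss`) are assumptions that may not match your Ax version, so check `ax/early_stopping/strategies/logical.py` before relying on it.

```python
# Sketch: stop a trial only if it looks unpromising on BOTH metrics.
# NOTE: class name and constructor arguments are assumptions based on
# ax/early_stopping/strategies/logical.py and may differ across Ax versions;
# "val_acc" matches this tutorial, while "val_loss" is a hypothetical second metric.
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
from ax.early_stopping.strategies.logical import AndEarlyStoppingStrategy

stop_on_acc = PercentileEarlyStoppingStrategy(
    metric_names=["val_acc"], percentile_threshold=70, normalize_progressions=True
)
stop_on_loss = PercentileEarlyStoppingStrategy(
    metric_names=["val_loss"], percentile_threshold=70, normalize_progressions=True
)
combined_ess = AndEarlyStoppingStrategy(left=stop_on_acc, right=stop_on_loss)
```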
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb953f30", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "\n", + "from pathlib import Path\n", + "\n", + "import torchx\n", + "\n", + "from ax.core import Experiment, Objective, ParameterType, RangeParameter, SearchSpace\n", + "from ax.core.optimization_config import OptimizationConfig\n", + "\n", + "from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy\n", + "from ax.metrics.tensorboard import TensorboardMetric\n", + "\n", + "from ax.modelbridge.dispatch_utils import choose_generation_strategy\n", + "\n", + "from ax.runners.torchx import TorchXRunner\n", + "\n", + "from ax.service.scheduler import Scheduler, SchedulerOptions\n", + "from ax.service.utils.report_utils import exp_to_df\n", + "\n", + "from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer\n", + "\n", + "from torchx import specs\n", + "from torchx.components import utils\n", + "\n", + "from matplotlib import pyplot as plt\n", + "\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a7bd328", + "metadata": {}, + "outputs": [], + "source": [ + "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "fe2cf6fe", + "metadata": {}, + "source": [ + "## Defining the TorchX App\n", + "\n", + "Our goal is to optimize the PyTorch Lightning training job defined in\n", + "[mnist_train_nas.py](https://github.com/pytorch/tutorials/tree/master/intermediate_source/mnist_train_nas.py).\n", + "To do this using TorchX, we write a helper function that takes in\n", + "the values of the architecture and hyperparameters of the training\n", + "job and creates a [TorchX AppDef](https://pytorch.org/torchx/latest/basics.html)\n", + "with the appropriate settings.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e21d309", + "metadata": {}, + "outputs": [], + "source": [ + "if SMOKE_TEST:\n", + "    epochs = 3\n", + "else:\n", + "    epochs = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b423923c", + "metadata": {}, + "outputs": [], + "source": [ + "def trainer(\n", + "    log_path: str,\n", + "    hidden_size_1: int,\n", + "    hidden_size_2: int,\n", + "    learning_rate: float,\n", + "    dropout: float,\n", + "    trial_idx: int = -1,\n", + ") -> specs.AppDef:\n", + "\n", + "    # define the log path so we can pass it to the TorchX AppDef\n", + "    if trial_idx >= 0:\n", + "        log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()\n", + "\n", + "    batch_size = 32\n", + "\n", + "    return utils.python(\n", + "        # command line args to the training script\n", + "        \"--log_path\",\n", + "        log_path,\n", + "        \"--hidden_size_1\",\n", + "        str(hidden_size_1),\n", + "        \"--hidden_size_2\",\n", + "        str(hidden_size_2),\n", + "        \"--learning_rate\",\n", + "        str(learning_rate),\n", + "        \"--epochs\",\n", + "        str(epochs),\n", + "        \"--dropout\",\n", + "        str(dropout),\n", + "        \"--batch_size\",\n", + "        str(batch_size),\n", + "        # other config options\n", + "        name=\"trainer\",\n", + "        script=\"tutorials/early_stopping/mnist_train_nas.py\",\n", + "        image=torchx.version.TORCHX_IMAGE,\n", + "    )" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "65f7011d", + "metadata": {}, + "source": [ + "## Setting up the Runner\n", + "\n", + "Ax’s [Runner](https://ax.dev/api/core.html#ax.core.runner.Runner)\n", + "abstraction allows writing
interfaces to various backends.\n", + "Ax already comes with Runner for TorchX, so we just need to\n", + "configure it. For the purpose of this tutorial, we run jobs locally\n", + "in a fully asynchronous fashion. In order to launch them on a cluster, you can instead specify a\n", + "different TorchX scheduler and adjust the configuration appropriately.\n", + "For example, if you have a Kubernetes cluster, you just need to change the\n", + "scheduler from ``local_cwd`` to ``kubernetes``.\n", + "\n", + "The training job launched by this runner will log partial results to Tensorboard, which will then be monitored by the early stopping strategy. We will show how this is done using an Ax \n", + "[TensorboardMetric](https://ax.dev/api/metrics.html#module-ax.metrics.tensorboard) below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "998e5835", + "metadata": {}, + "outputs": [], + "source": [ + "# Make a temporary dir to log our results into\n", + "log_dir = tempfile.mkdtemp()\n", + "\n", + "ax_runner = TorchXRunner(\n", + " tracker_base=\"/tmp/\",\n", + " component=trainer,\n", + " # NOTE: To launch this job on a cluster instead of locally you can\n", + " # specify a different scheduler and adjust args appropriately.\n", + " scheduler=\"local_cwd\",\n", + " component_const_params={\"log_path\": log_dir},\n", + " cfg={},\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "2fec7495", + "metadata": {}, + "source": [ + "## Setting up the SearchSpace\n", + "\n", + "First, we define our search space. Ax supports both range parameters\n", + "of type integer and float as well as choice parameters which can have\n", + "non-numerical types such as strings.\n", + "We will tune the hidden sizes, learning rate, and dropout parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf6f869f", + "metadata": {}, + "outputs": [], + "source": [ + "parameters = [\n", + " # NOTE: In a real-world setting, hidden_size_1 and hidden_size_2\n", + " # should probably be powers of 2, but in our simple example this\n", + " # would mean that num_params can't take on that many values, which\n", + " # in turn makes the Pareto frontier look pretty weird.\n", + " RangeParameter(\n", + " name=\"hidden_size_1\",\n", + " lower=16,\n", + " upper=128,\n", + " parameter_type=ParameterType.INT,\n", + " log_scale=True,\n", + " ),\n", + " RangeParameter(\n", + " name=\"hidden_size_2\",\n", + " lower=16,\n", + " upper=128,\n", + " parameter_type=ParameterType.INT,\n", + " log_scale=True,\n", + " ),\n", + " RangeParameter(\n", + " name=\"learning_rate\",\n", + " lower=1e-4,\n", + " upper=1e-2,\n", + " parameter_type=ParameterType.FLOAT,\n", + " log_scale=True,\n", + " ),\n", + " RangeParameter(\n", + " name=\"dropout\",\n", + " lower=0.0,\n", + " upper=0.5,\n", + " parameter_type=ParameterType.FLOAT,\n", + " ),\n", + "]\n", + "\n", + "search_space = SearchSpace(\n", + " parameters=parameters,\n", + " # NOTE: In practice, it may make sense to add a constraint\n", + " # hidden_size_2 <= hidden_size_1\n", + " parameter_constraints=[],\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "a8005e80", + "metadata": {}, + "source": [ + "## Setting up Metrics\n", + "\n", + "Ax has the concept of a Metric that defines properties of outcomes and how observations are obtained for these outcomes. This allows e.g. 
encoding how data is fetched from some distributed execution backend and post-processed before being passed as input to Ax.\n", + "\n", + "We will optimize the validation accuracy, which is a `TensorboardMetric` that points to the logging directory assigned above. Note that we have set `is_available_while_running`, allowing for the metric to be queried as the trial progresses. This is critical for the early stopping strategy to monitor partial results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0775a96e", + "metadata": {}, + "outputs": [], + "source": [ + "class MyTensorboardMetric(TensorboardMetric):\n", + "\n", + "    # NOTE: We need to tell the new Tensorboard metric how to build an event\n", + "    # multiplexer that reads the tensorboard logs for a trial. In this case\n", + "    # our convention is to just save a separate file per trial in\n", + "    # the pre-specified log dir.\n", + "    def _get_event_multiplexer_for_trial(self, trial):\n", + "        mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)\n", + "        mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)\n", + "        mul.Reload()\n", + "\n", + "        return mul\n", + "\n", + "    # This indicates whether the metric is queryable while the trial is\n", + "    # still running. This is required for early stopping to monitor the\n", + "    # progress of the running trial.\n", + "    @classmethod\n", + "    def is_available_while_running(cls):\n", + "        return True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5c5a7d0", + "metadata": {}, + "outputs": [], + "source": [ + "val_acc = MyTensorboardMetric(\n", + "    name=\"val_acc\",\n", + "    tag=\"val_acc\",\n", + "    lower_is_better=False,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "d4f3ba5d", + "metadata": {}, + "source": [ + "## Setting up the OptimizationConfig\n", + "\n", + "The `OptimizationConfig` specifies the objective for Ax to optimize." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ada66cf3", + "metadata": {}, + "outputs": [], + "source": [ + "opt_config = OptimizationConfig(\n", + "    objective=Objective(\n", + "        metric=val_acc,\n", + "        minimize=False,\n", + "    )\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "57aa9cf7", + "metadata": {}, + "source": [ + "## Defining an Early Stopping Strategy\n", + "\n", + "A `PercentileEarlyStoppingStrategy` is a simple method that stops a trial if its performance falls below a certain percentile of other trials at the same step (e.g., when `percentile_threshold` is 50, at a given point in time, if a trial ranks in the bottom 50% of trials, it is stopped). \n", + "- We make use of `normalize_progressions`, which normalizes the progression column (e.g. timestamp, epochs, training data used) to be in [0, 1]. This is useful because one doesn't need to know the maximum progression values of the curve (which might be, e.g., the total number of data points in the training dataset).\n", + "- The `min_progression` parameter specifies that trials should only be considered for stopping if the latest progression value is greater than this threshold.\n", + "- The `min_curves` parameter specifies the minimum number of completed curves (i.e., fully completed training jobs) before early stopping will be considered. This should be larger than zero if `normalize_progressions` is used.
In general, we want a few completed curves to have a baseline for comparison.\n", + "\n", + "Note that `PercentileEarlyStoppingStrategy` does not make use of learning curve modeling or prediction. More sophisticated model-based methods will be available in future versions of Ax." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "949e8ab5", + "metadata": {}, + "outputs": [], + "source": [ + "percentile_early_stopping_strategy = PercentileEarlyStoppingStrategy(\n", + " # stop if in bottom 70% of runs at the same progression\n", + " percentile_threshold=70,\n", + " # the trial must have passed `min_progression` steps before early stopping is initiated\n", + " # note that we are using `normalize_progressions`, so this is on a scale of [0, 1]\n", + " min_progression=0.3,\n", + " # there must be `min_curves` completed trials and `min_curves` trials reporting data in\n", + " # order for early stopping to be applicable\n", + " min_curves=5,\n", + " # specify, e.g., [0, 1] if the first two trials should never be stopped\n", + " trial_indices_to_ignore=None,\n", + " # check for new data every 10 seconds\n", + " seconds_between_polls=10,\n", + " normalize_progressions=True,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "2665ca93", + "metadata": {}, + "source": [ + "## Creating the Ax Experiment\n", + "\n", + "In Ax, the Experiment object is the object that stores all the information about the problem setup." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12849b31", + "metadata": {}, + "outputs": [], + "source": [ + "experiment = Experiment(\n", + " name=\"torchx_mnist\",\n", + " search_space=search_space,\n", + " optimization_config=opt_config,\n", + " runner=ax_runner,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "49a4ed0e", + "metadata": {}, + "source": [ + "## Choosing the GenerationStrategy\n", + "\n", + "A [GenerationStrategy](https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStrategy)\n", + "is the abstract representation of how we would like to perform the\n", + "optimization. While this can be customized (if you’d like to do so, see\n", + "[this tutorial](https://ax.dev/tutorials/generation_strategy.html)),\n", + "in most cases Ax can automatically determine an appropriate strategy\n", + "based on the search space, optimization config, and the total number\n", + "of trials we want to run.\n", + "\n", + "Typically, Ax chooses to evaluate a number of random configurations\n", + "before starting a model-based Bayesian Optimization strategy.\n", + "\n", + "We remark that in Ax, generation strategies and early stopping strategies are separate, a design decision motivated by ease-of-use. However, we should acknowledge that jointly considering generation and stopping using a single strategy would likely be the \"proper\" formulation." 
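Editor's note: to make the "random configurations first, then Bayesian optimization" behavior concrete, here is a minimal sketch of what an explicitly configured strategy typically looks like. It is only a sketch; the registry entries (`Models.SOBOL`, `Models.GPEI`) and step arguments follow the common Ax pattern but may vary across versions, and in this tutorial the strategy is instead built automatically by `choose_generation_strategy` (see the next cell in the diff).

```python
# Minimal sketch of an explicit generation strategy (not used in this tutorial):
# a handful of quasi-random Sobol trials, then GP-based Bayesian optimization.
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Models

explicit_gs = GenerationStrategy(
    steps=[
        GenerationStep(model=Models.SOBOL, num_trials=5),
        # num_trials=-1 means the final step has no trial limit
        GenerationStep(model=Models.GPEI, num_trials=-1),
    ]
)
```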
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e38d0237", + "metadata": {}, + "outputs": [], + "source": [ + "if SMOKE_TEST:\n", + "    total_trials = 6\n", + "else:\n", + "    total_trials = 15  # total evaluation budget\n", + "\n", + "gs = choose_generation_strategy(\n", + "    search_space=experiment.search_space,\n", + "    optimization_config=experiment.optimization_config,\n", + "    num_trials=total_trials,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "78d86fea", + "metadata": {}, + "source": [ + "## Configuring the Scheduler\n", + "\n", + "The `Scheduler` acts as the loop control for the optimization.\n", + "It communicates with the backend to launch trials, check their status, retrieve (partial) results, and, importantly for this tutorial, call the early stopping strategy. If the early stopping strategy suggests that a trial should be stopped, the `Scheduler` communicates with the backend to terminate the trial.\n", + "\n", + "The ``Scheduler`` requires the ``Experiment`` and the ``GenerationStrategy``.\n", + "A set of options can be passed in via ``SchedulerOptions``. Here, we\n", + "configure the number of total evaluations as well as ``max_pending_trials``,\n", + "the maximum number of trials that should run concurrently. In our\n", + "local setting, this is the number of training jobs running as individual\n", + "processes, while in a remote execution setting, this would be the number\n", + "of machines you want to use in parallel.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "499fb9b5", + "metadata": {}, + "outputs": [], + "source": [ + "scheduler = Scheduler(\n", + "    experiment=experiment,\n", + "    generation_strategy=gs,\n", + "    options=SchedulerOptions(\n", + "        total_trials=total_trials,\n", + "        max_pending_trials=5,\n", + "        early_stopping_strategy=percentile_early_stopping_strategy,\n", + "    ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78257ebb", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "scheduler.run_all_trials()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8c5afbe8", + "metadata": {}, + "source": [ + "## Results\n", + "\n", + "First, we examine the data stored on the experiment. This shows that each trial is associated with an entire learning curve, represented by the column \"step\"." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "764365f0", + "metadata": {}, + "outputs": [], + "source": [ + "experiment.lookup_data().map_df.head(n=10)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0033ed2e", + "metadata": {}, + "source": [ + "Below is a summary of the experiment, showing that a portion of trials have been early stopped." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00f2b35f", + "metadata": {}, + "outputs": [], + "source": [ + "exp_to_df(experiment)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f8909cf2", + "metadata": {}, + "source": [ + "We can give a very rough estimate of the computational savings due to early stopping by comparing the total number of steps used under early stopping to the number of steps that would have been used had all trials run to completion. Note that for a true comparison, one should run full HPO loops with and without early stopping (as early stopping will influence the model and future points selected by the generation strategy).
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5abb3ce8", + "metadata": {}, + "outputs": [], + "source": [ + "map_df = experiment.lookup_data().map_df\n", + "trial_to_max_steps = map_df.groupby(\"trial_index\")[\"step\"].max()\n", + "completed_trial_steps = trial_to_max_steps.iloc[0]\n", + "savings = 1.0 - trial_to_max_steps.sum() / (\n", + "    completed_trial_steps * len(trial_to_max_steps)\n", + ")\n", + "print(f\"A rough estimate of the computational savings is {100 * savings:.1f}%.\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "37df6964", + "metadata": {}, + "source": [ + "## Visualizations\n", + "\n", + "Finally, we show a visualization of learning curves versus actual elapsed wall time. This helps to illustrate that stopped trials make room for additional trials to be run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c88cb8d0", + "metadata": {}, + "outputs": [], + "source": [ + "# helper function for getting trial start times\n", + "def time_started(row):\n", + "    trial_index = row[\"trial_index\"]\n", + "    return experiment.trials[trial_index].time_run_started\n", + "\n", + "\n", + "# helper function for getting trial completion times\n", + "def time_completed(row):\n", + "    trial_index = row[\"trial_index\"]\n", + "    return experiment.trials[trial_index].time_completed\n", + "\n", + "\n", + "# helper function for getting relevant data from experiment\n", + "# with early stopping into useful dfs\n", + "def early_stopping_exp_to_df(experiment):\n", + "    trials_df = exp_to_df(experiment)\n", + "    curve_df = experiment.lookup_data().map_df\n", + "    training_row_df = (\n", + "        curve_df.groupby(\"trial_index\").max().reset_index()[[\"trial_index\", \"step\"]]\n", + "    )\n", + "    trials_df = trials_df.merge(training_row_df, on=\"trial_index\")\n", + "    trials_df[\"time_started\"] = trials_df.apply(func=time_started, axis=1)\n", + "    trials_df[\"time_completed\"] = trials_df.apply(func=time_completed, axis=1)\n", + "    start_time = trials_df[\"time_started\"].min()\n", + "    trials_df[\"time_started_rel\"] = (\n", + "        trials_df[\"time_started\"] - start_time\n", + "    ).dt.total_seconds()\n", + "    trials_df[\"time_completed_rel\"] = (\n", + "        trials_df[\"time_completed\"] - start_time\n", + "    ).dt.total_seconds()\n", + "    return trials_df, curve_df\n", + "\n", + "\n", + "def plot_curves_by_wall_time(trials_df, curve_df):\n", + "    trials = set(curve_df[\"trial_index\"])\n", + "    fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n", + "    ax.set(xlabel=\"seconds since start\", ylabel=\"validation accuracy\")\n", + "    for trial_index in trials:\n", + "        this_trial_df = curve_df[curve_df[\"trial_index\"] == trial_index]\n", + "        start_time_rel = trials_df[\"time_started_rel\"].iloc[trial_index]\n", + "        completed_time_rel = trials_df[\"time_completed_rel\"].iloc[trial_index]\n", + "        total_steps = trials_df.loc[trial_index, \"step\"]\n", + "        smoothed_curve = this_trial_df[\"mean\"].rolling(window=3).mean()\n", + "        x = (\n", + "            start_time_rel\n", + "            + (completed_time_rel - start_time_rel)\n", + "            / total_steps\n", + "            * this_trial_df[\"step\"]\n", + "        )\n", + "        ax.plot(\n", + "            x,\n", + "            smoothed_curve,\n", + "            label=f\"trial #{trial_index}\" if trial_index % 2 == 1 else None,\n", + "        )\n", + "    ax.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7f52fed", + "metadata": {}, + "outputs": [], + "source": [ + "# wrap in try/except in case of flaky I/O issues\n", + "try:\n", + "    trials_df,
curve_df = early_stopping_exp_to_df(experiment)\n", + " plot_curves_by_wall_time(trials_df, curve_df)\n", + "except Exception as e:\n", + " print(f\"Encountered exception while plotting results: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "193e2fc7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 }
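Editor's addendum to the Results section above: one can also tally how many trials were stopped early versus run to completion directly from the experiment object, which complements the rough savings estimate. This is a sketch rather than part of the tutorial; it assumes `Experiment.trials_by_status` and `TrialStatus.EARLY_STOPPED` are available in the Ax version being used.

```python
# Sketch: count early-stopped vs. completed trials after the Scheduler finishes.
# Assumes `experiment.trials_by_status` returns a mapping from TrialStatus to trials.
from ax.core.base_trial import TrialStatus

trials_by_status = experiment.trials_by_status
n_early_stopped = len(trials_by_status.get(TrialStatus.EARLY_STOPPED, []))
n_completed = len(trials_by_status.get(TrialStatus.COMPLETED, []))
print(f"{n_early_stopped} trial(s) were early stopped; {n_completed} ran to completion.")
```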