diff --git a/examples/how_to_scale_op.ipynb b/examples/how_to_scale_op.ipynb new file mode 100644 index 0000000..5d1054c --- /dev/null +++ b/examples/how_to_scale_op.ipynb @@ -0,0 +1,633 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) 2024 Graphcore Ltd. All rights reserved.\n", + "\n", + "# How to unit-scale an op\n", + "\n", + "The unit-scaled maximal update parametrisation, [u-μP](https://arxiv.org/abs/2407.17465), enables hyperparameter transfer and low-precision training by paying careful attention to the _scale_ (standard deviation, or 'std') of tensors in the forward and backward passes.\n", + "\n", + "In order to construct u-μP models, we need _scaled_ ops. A scaled op produces approximately unit-std outputs when given unit-std inputs. Likewise, in the backward pass, it produces unit-std input gradients when given unit-std output gradients. The [unit-scaling](https://github.com/graphcore-research/unit-scaling) library provides implementations of many common ops, but it can never be exhaustive, so in this notebook we walk through how to unit-scale an op for ourselves.\n", + "\n", + "Structure:\n", + " - **Introduction** - what is the task?\n", + " - **Empirical scaling** - scaling our op via simulation and empirical measurement.\n", + " - **Statistical scaling** - scaling our op via statistical analysis.\n", + " - **Scaling constraints** - a mechanism for supporting the cut-edge rule.\n", + " - **Summing up** - phew!" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "#### Imports/preamble/helpers (nothing much to see here - feel free to skip)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from math import e, erf, pi, sqrt, exp\n", + "from typing import *\n", + "\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torch import Tensor\n", + "\n", + "from unit_scaling.constraints import apply_constraint\n", + "from unit_scaling.scale import scale_fwd, scale_bwd\n", + "\n", + "matplotlib.rc(\"axes\", **{\"spines.top\": False, \"spines.right\": False})\n", + "matplotlib.rc(\"legend\", frameon=False)\n", + "\n", + "def jointplot(df: pd.DataFrame, *, x: str, y: str,\n", + " xlabel: Optional[str] = None, ylabel: Optional[str] = None) -> sns.JointGrid:\n", + " g = sns.JointGrid(data=df, x=x, y=y, height=4, ratio=2)\n", + " g.plot_joint(sns.scatterplot, s=8, lw=0.2)\n", + " g.plot_marginals(sns.histplot, bins=20)\n", + "\n", + " g.ax_joint.set_xticks([-2, -1, 0, 1, 2])\n", + " g.ax_joint.set_xlim(-2.5, 2.5)\n", + " g.ax_joint.set_xlabel(f\"${xlabel or x}$\")\n", + " g.ax_joint.set_yticks([-2, -1, 0, 1, 2])\n", + " g.ax_joint.set_ylim(-2.5, 2.5)\n", + " g.ax_joint.set_ylabel(f\"${ylabel or y}$\")\n", + "\n", + " x_rms = sqrt((df[x]**2).mean())\n", + " y_rms = sqrt((df[y]**2).mean())\n", + " g.ax_marg_x.set_title(f\"$\\\\mathrm{{RMS}}({xlabel or x})={x_rms:.2f}$\", fontsize=10)\n", + " g.ax_marg_y.set_title(f\"$\\\\mathrm{{RMS}}({ylabel or y})={y_rms:.2f}$\", fontsize=10)\n", + "\n", + "def opplot(name_to_fn: Dict[str, Callable[[Tensor], Tensor]]) -> None:\n", + " fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(9, 3))\n", + " x = torch.linspace(-2.5, 2.5, int(1e4)).requires_grad_()\n", + " for name, fn in name_to_fn.items():\n", + " x.grad = None\n", + " y 
= fn(x)\n", + " y.backward(torch.ones_like(y))\n", + " ax0.plot(x.detach(), y.detach(), label=f\"y={name}\")\n", + " ax1.plot(x.detach(), x.grad)\n", + " ax0.set_xlabel(\"x\")\n", + " ax0.set_ylabel(\"y\")\n", + " ax1.set_xlabel(\"x\")\n", + " ax1.set_ylabel(\"dy/dx\")\n", + " fig.legend(*ax0.get_legend_handles_labels(), loc=\"center left\", bbox_to_anchor=(.9, 0.5))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Introduction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our chosen task is to unit-scale `F.hardtanh`, an elementwise nonlinearity provided by PyTorch that looks like a harder/sharper version of `tanh`. It's defined as `F.hardtanh(x, a=-1, b=1) = clip(x, a, b)`, and is suitable as an illustrative example, as it permits both empirical and statistical scaling methods, as we'll see.\n", + "\n", + "The op and gradient look like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "opplot({\"F.hardtanh(x)\": F.hardtanh})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's the _unscaled_ op, but we want to find the equivalent _scaled_ op. To do this:\n", + "\n", + " - Replace thresholds `a` and `b` with a `mult` shape parameter, since the output should have zero-mean.\n", + " - Insert scaling factors separately into the forward and backward passes, to achieve unit scale.\n", + "\n", + "In code, that's\n", + "\n", + "```python\n", + "def hardtanh(x: Tensor, mult: float = 1) -> Tensor:\n", + " y_scale, grad_scale = ... # ???\n", + " x = scale_bwd(x, grad_scale)\n", + " y = F.hardtanh(x, -1/mult, 1/mult)\n", + " return scale_fwd(y, y_scale)\n", + "```\n", + "\n", + "which relies on the utilities `scale_fwd` and `scale_bwd` from `unit_scaling.scale`. These apply a muliplicative scaling factor in either the forward or backwards pass (compare with the straight-forward `x * scale` which applies the same scale in both foward and backward passes).\n", + "\n", + "**The remaining problem is how to choose `y_scale` and `grad_scale`, so that `y.std` ≈ 1 when `x.std` = 1 and also `x.grad.std` ≈ 1 when `y.grad.std` = 1.**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Empirical scaling\n", + "\n", + "The simplest way to choose these factors is to feed inputs from an appropriate distribution (typically unit Gaussian), into both forward and backward passes, and measure how the scale changes. 
Then set scaling factors to counteract this.\n", + "\n", + "Let's try:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.nn.functional.hardtanh\n", + " x.std = 0.999\n", + " y.std = 0.718\n", + " grad_x.std = 0.828\n" + ] + } + ], + "source": [ + "def check_scaling(fn: Callable[[Tensor], Tensor], **kwargs: Any) -> None:\n", + " x = torch.randn(int(1e6)).requires_grad_()\n", + " y = fn(x, **kwargs)\n", + " y.backward(torch.randn_like(y))\n", + "\n", + " name = f\"{fn.__module__}.{fn.__name__}\".replace(\"__main__.\", \"\")\n", + " print(name + (f\" {kwargs}\" if kwargs else \"\"))\n", + " for k, v in {\"x\": x, \"y\": y, \"grad_x\": x.grad}.items():\n", + " print(f\"{k:>10}.std = {v.std(correction=0).item():.3f}\")\n", + "\n", + "check_scaling(F.hardtanh)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The unscaled op therefore shrinks the scale in both forward and backward passes.\n", + "\n", + "We can pluck these scales, and use them to set `y_scale = 1 / empirical_y_std` and `grad_scale = 1 / empirical_grad_x_std`:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hardtanh_scaled_empirical\n", + " x.std = 1.000\n", + " y.std = 1.000\n", + " grad_x.std = 1.001\n" + ] + } + ], + "source": [ + "def hardtanh_scaled_empirical(x: Tensor) -> Tensor:\n", + " y_scale = 1 / 0.718\n", + " grad_scale = 1 / 0.826\n", + " x = scale_bwd(x, grad_scale)\n", + " y = F.hardtanh(x)\n", + " return scale_fwd(y, y_scale)\n", + "\n", + "check_scaling(hardtanh_scaled_empirical)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Success! We've unit-scaled hardtanh!\n", + "\n", + "However, we still haven't dealt with `mult`, as these static factors do not generalise to different clipping thresholds. 
We could do this by fitting a curve `f` to `y.std = f(mult)`, but we'll instead use this as an opportunity to highlight an alternative approach based on statistical analysis." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Statistical scaling\n", + "\n", + "Our second approach is to compute the output distribution and its standard deviation, given an input distribution.\n", + "\n", + "First, let's eyeball the forward-pass distribution for `y = F.hardtanh(x)` when $x \\sim \\mathcal{N}(0, 1)$. Note that this plot shows RMS, which is equal to the standard deviation when the mean is zero." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x = torch.randn(int(1e5)).requires_grad_()\n", + "y = F.hardtanh(x)\n", + "y.backward(torch.ones_like(y))\n", + "df = pd.DataFrame.from_dict(dict(x=x.detach(), y=y.detach(), grad_x=x.grad))\n", + "\n", + "jointplot(df, x=\"x\", y=\"y\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Recalling that `hardtanh(x, mult)` is defined as `clip(x, -1/mult, 1/mult)`, we can see that the output distribution is a mixture of three components:\n", + "\n", + " - truncated Gaussian with weight `p(-1/mult <= x <= 1/mult)`\n", + " - spike at `y=-1/mult` with weight `p(x < -1/mult)`\n", + " - spike at `y=1/mult` with weight `p(x > 1/mult)`\n", + "\n", + "So we can write out the pdf of $Y$ as a mixture distribution (where $\\alpha = \\mathrm{mult}$):\n", + "\n", + "$\\mathrm{P}(Y=y) = \\begin{cases}\\frac{1-Z}{2}\\, \\delta(y - \\alpha^{-1}) + \\frac{1-Z}{2}\\, \\delta(y + \\alpha^{-1}) + \\varphi(y) & \\textrm{if } {-\\alpha^{-1}} \\leq y \\leq \\alpha^{-1} \\\\\n", + "0 & \\textrm{otherwise}\n", + "\\end{cases}\n", + "$\n", + "\n", + "where $Z \\coloneqq \\mathrm{erf}(\\sqrt{\\frac{1}{2}}\\, \\alpha^{-1})$ is the probability that $X \\sim \\mathcal{N}(0,1)$ falls in the range $[-\\alpha^{-1}, \\alpha^{-1}]$, $\\delta(\\cdot)$ is the Dirac delta function and $\\varphi(\\cdot)$ is the Gaussian pdf (note that the $Z$ normaliser for a truncated Gaussian cancels exactly with the mixture weight).\n", + "\n", + "Next, by symmetry, we can observe that $\\mathrm{E}(Y) = 0$. 
Therefore the scale, $\\sigma_Y = \\sqrt{\\mathrm{E}(Y^2) - \\mathrm{E}(Y)^2} = \\sqrt{\\mathrm{E}(Y^2)}$.\n", + "\n", + "This expands to:\n", + "\n", + "$\\sigma_Y = \\sqrt{(1-Z)\\, \\alpha^{-2} + Z\\,(1 - 2 e^{-1/(2\\alpha^2)} / (Z \\alpha \\sqrt{2 \\pi}))}$\n", + "\n", + "where the first term is from the pair of spikes, and the second term from the variance of a symmetric [truncated Gaussian](https://en.wikipedia.org/wiki/Truncated_normal_distribution).\n", + "\n", + "Leading to the **forward scale**:\n", + "\n", + "$\\sigma_Y = \\sqrt{\\alpha^{-2} + (1 - \\alpha^{-2})\\,\\mathrm{erf}(\\sqrt{\\frac{1}{2}}\\, \\alpha^{-1}) - \\sqrt{\\frac{2}{\\pi}}\\, \\alpha^{-1}\\, e^{-\\frac{1}{2}\\alpha^{-2}}}$\n", + "\n", + "> Note: when $\\alpha\\!=\\!1$, this simplifies to $\\sqrt{1 - \\sqrt{2 / (\\pi e)}}$\n", + "\n", + "Let's test this rule by sweeping `mult` over a logarithmic range:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x = torch.randn(int(1e6)).requires_grad_()\n", + "mults = 2**torch.linspace(-3, 3, 31)\n", + "scales, grad_scales = [], []\n", + "for mult in mults:\n", + " x.grad = None\n", + " y = F.hardtanh(x, -1/mult, 1/mult)\n", + " y.backward(torch.randn_like(y))\n", + " scales.append(y.std(correction=0).item())\n", + " grad_scales.append(x.grad.std(correction=0).item())\n", + "\n", + "model_scales = [\n", + " sqrt(mult**-2 + (1 - mult**-2) * erf(1/(mult*sqrt(2))) - sqrt(2/pi) * mult**-1 * exp(-1/2 * mult**-2))\n", + " for mult in mults\n", + "]\n", + "_, ax = plt.subplots(figsize=(5, 3))\n", + "ax.plot(mults, scales, label=\"experiment\", zorder=1)\n", + "ax.plot(mults, model_scales, lw=4, label=\"model\", zorder=0)\n", + "ax.set_xscale(\"log\", base=2); ax.set_yscale(\"log\", base=2); ax.legend()\n", + "ax.set_xlabel(\"mult\"); ax.set_ylabel(\"y.std\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This looks pretty solid.\n", + "\n", + "Now for the backwards pass. Note that in this case we just feed in `y.grad = 1` in order to obtain the partial derivatives $\\partial y / \\partial x$ that define the scaling behaviour:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "jointplot(df, x=\"x\", y=\"grad_x\", ylabel=r\"\\frac{\\partial y}{\\partial x}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The partial derivatives are a square pulse, where $\\frac{\\partial y}{\\partial x}=1$ when ${-\\alpha^{-1}} \\leq x \\leq \\alpha^{-1}$ and $0$ otherwise.\n", + "\n", + "We assume that the output gradient $\\dot{Y}$ is independent of the inputs $X$, so that the input gradient is the product of two random variables: $\\dot{X} = \\dot{Y} \\Delta$, where $\\dot{Y} \\sim N(0, 1)$ and $\\Delta \\sim \\mathrm{Bernoulli}(Z)$, where $Z$ is defined as previously.\n", + "\n", + "Since they're independent, the expectation of the product is the product of the expectations, so $\\mathrm{E}(\\dot{X})=0$, and\n", + "\n", + "$\\sigma_{\\dot{X}} = \\sqrt{\\mathrm{E}((\\dot{Y} \\Delta)^2)} = \\sqrt{Z\\, \\mathrm{E}(\\dot{Y}^2) + (1-Z)\\,0} = \\sqrt{Z}$\n", + "\n", + "Therefore, we have the **backward scale**:\n", + "\n", + "$\\sigma_{\\dot{X}} = \\sqrt{\\mathrm{erf}(\\sqrt{\\frac{1}{2}}\\, \\alpha^{-1})}$" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model_grad_scales = [sqrt(erf(1/(mult*sqrt(2)))) for mult in mults]\n", + "\n", + "_, ax = plt.subplots(figsize=(5, 3))\n", + "ax.plot(mults, grad_scales, label=\"experiment\", zorder=1)\n", + "ax.plot(mults, model_grad_scales, lw=4, label=\"model\", zorder=0)\n", + "ax.set_xscale(\"log\", base=2); ax.set_yscale(\"log\", base=2); ax.legend()\n", + "ax.set_xlabel(\"mult\"); ax.set_ylabel(\"x.grad.std\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Putting these rules together, we can test our new, 'fancy statistical' version of scaled hardtanh:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hardtanh_scaled_statistical\n", + " x.std = 1.000\n", + " y.std = 1.000\n", + " grad_x.std = 1.001\n", + "\n", + "hardtanh_scaled_statistical {'mult': 0.25}\n", + " x.std = 1.000\n", + " y.std = 1.000\n", + " grad_x.std = 0.999\n", + "\n", + "hardtanh_scaled_statistical {'mult': 4}\n", + " x.std = 1.000\n", + " y.std = 1.000\n", + " grad_x.std = 0.999\n" + ] + } + ], + "source": [ + "def hardtanh_scaled_statistical(x: Tensor, mult: float = 1.0) -> Tensor:\n", + " Z = erf(1 / (mult * sqrt(2)))\n", + " y_scale = 1 / sqrt(Z + (1 - Z) / mult**2 - sqrt(2/pi) / mult * exp(-1/2 / mult**2))\n", + " grad_scale = 1 / sqrt(Z)\n", + " x = scale_bwd(x, grad_scale)\n", + " y = F.hardtanh(x, -1/mult, 1/mult)\n", + " return scale_fwd(y, y_scale)\n", + "\n", + "check_scaling(hardtanh_scaled_statistical); print()\n", + "check_scaling(hardtanh_scaled_statistical, mult=1/4); print()\n", + "check_scaling(hardtanh_scaled_statistical, mult=4)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Supporting scaling constraints\n", + "\n", + "The version we've got is great for maintaining scale in the forward and backward passes. 
Unit scaling, however, sometimes requires these scales to be kept consistent. This is enforced by the user according to the cut-edge rule (see the [unit scaling](https://arxiv.org/abs/2303.11257) paper for more detail) and supported in the op via a string argument `constraint`:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hardtanh {'constraint': 'to_output_scale'}\n", + " x.std = 1.000\n", + " y.std = 1.000\n", + " grad_x.std = 1.149\n", + "\n", + "hardtanh {'constraint': 'to_grad_input_scale'}\n", + " x.std = 1.001\n", + " y.std = 0.870\n", + " grad_x.std = 0.999\n", + "\n", + "hardtanh {'constraint': 'gmean'}\n", + " x.std = 1.000\n", + " y.std = 0.932\n", + " grad_x.std = 1.073\n", + "\n", + "hardtanh {'constraint': None}\n", + " x.std = 0.999\n", + " y.std = 1.000\n", + " grad_x.std = 1.002\n", + "\n" + ] + } + ], + "source": [ + "# (Final version)\n", + "def hardtanh(x: Tensor, mult: float = 1.0, constraint: Optional[str] = \"to_output_scale\") -> Tensor:\n", + " Z = erf(1 / (mult * sqrt(2)))\n", + " y_scale = 1 / sqrt(Z + (1 - Z) / mult**2 - sqrt(2/pi) / mult * exp(-1/2 / mult**2))\n", + " grad_scale = 1 / sqrt(Z)\n", + " y_scale, grad_scale = apply_constraint(constraint, y_scale, grad_scale)\n", + " x = scale_bwd(x, grad_scale)\n", + " y = F.hardtanh(x, -1/mult, 1/mult)\n", + " return scale_fwd(y, y_scale)\n", + "\n", + "check_scaling(hardtanh, constraint=\"to_output_scale\"); print()\n", + "check_scaling(hardtanh, constraint=\"to_grad_input_scale\"); print()\n", + "check_scaling(hardtanh, constraint=\"gmean\"); print()\n", + "check_scaling(hardtanh, constraint=None); print()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we can see the trade-off implied by constraints.
When constraining the scaling factors, either the forward or backward passes can be well-scaled (or some trade-off between them), but in general it isn't possible for both to have good scale when `y_scale == grad_scale`.\n", + "\n", + "With larger `mult`, the constrained scaling rule must relax the unit scale requirement, as the ideal forward and backward scales are more different:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "hardtanh {'mult': 4, 'constraint': 'to_output_scale'}\n", + " x.std = 0.999\n", + " y.std = 1.000\n", + " grad_x.std = 1.906\n", + "\n", + "hardtanh {'mult': 4, 'constraint': 'to_grad_input_scale'}\n", + " x.std = 1.001\n", + " y.std = 0.524\n", + " grad_x.std = 0.999\n" + ] + } + ], + "source": [ + "check_scaling(hardtanh, mult=4, constraint=\"to_output_scale\"); print()\n", + "check_scaling(hardtanh, mult=4, constraint=\"to_grad_input_scale\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The default constraint should generally be `\"to_output_scale\"`, which keeps forward and backward passes consistent while prioritising forward-pass scaling. An exception is for arguments that are typically trainable parameters, where the default constraint should be `None`.\n", + "\n", + "Let's take a look at the final scaled op (with default constraint):" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "opplot({\"F.hardtanh(x)\": F.hardtanh,\n", + " \"hardtanh(x)\": hardtanh,\n", + " \"hardtanh(x, mult=3)\": lambda x: hardtanh(x, mult=3)})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summing up\n", + "\n", + "That's it; unit-scaling an op can be a relatively systematic process, based on either empirical or statistical analysis.\n", + "\n", + "> Note: The main thing our `hardtanh` example missed is any dependence on input tensor shapes. Since it is an elementwise nonlinearity, these do not change the scaling behaviour. In general, a scaling rule would have to consider input shapes.\n", + "\n", + "To unit-scale an op for u-μP:\n", + "\n", + " 1. If needed, add a `mult` hyperparameter to control the shape of the op when the input scale is unit.\n", + " 1. Make distributional assumptions about the forward and backward pass inputs (typically: IID unit Gaussian).\n", + " 1. _Either:_ feed in data from these distributions and measure the change in `std`.\n", + " 1. _Or:_ do some maths to work out the theoretical change in `std`.\n", + " 1. Add the `apply_constraint` boilerplate to support constrained scales.\n", + " 1. Test it over a range of shapes and `mult` (if applicable), by feeding in artificial inputs & grad-outputs, measuring `std`.\n", + "\n", + "To use the compendium of ops for which we have proposed scaling rules, or to develop your own rules, see https://github.com/graphcore-research/unit-scaling.\n", + "\n", + "Thanks for reading, well done for reaching the end & happy scaling!" 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
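The forward and backward scale rules derived in the notebook can also be sanity-checked without PyTorch. The following is a minimal sketch, not part of the notebook or of the unit-scaling library: it uses only the Python standard library, and the names `forward_scale`, `backward_scale` and `simulate` are illustrative. It assumes, as the notebook does, that inputs and output gradients are independent unit Gaussians.

```python
import random
from math import erf, exp, pi, sqrt
from typing import Tuple


def forward_scale(mult: float) -> float:
    """Closed-form std of clip(x, -1/mult, 1/mult) for x ~ N(0, 1)."""
    a = 1.0 / mult  # clipping threshold, alpha^-1 in the notebook's notation
    z = erf(a / sqrt(2))  # probability that x lands inside [-a, a]
    return sqrt(a**2 + (1 - a**2) * z - sqrt(2 / pi) * a * exp(-0.5 * a**2))


def backward_scale(mult: float) -> float:
    """Closed-form std of the input gradient for a unit-Gaussian output gradient."""
    return sqrt(erf(1.0 / (mult * sqrt(2))))


def simulate(mult: float, n: int = 200_000, seed: int = 0) -> Tuple[float, float]:
    """Monte Carlo estimate of (y RMS, grad_x RMS); both are zero-mean, so RMS = std."""
    rng = random.Random(seed)
    a = 1.0 / mult
    sum_y2 = 0.0
    sum_g2 = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        grad_y = rng.gauss(0.0, 1.0)
        y = min(max(x, -a), a)  # hardtanh forward pass
        grad_x = grad_y if -a <= x <= a else 0.0  # square-pulse derivative
        sum_y2 += y * y
        sum_g2 += grad_x * grad_x
    return sqrt(sum_y2 / n), sqrt(sum_g2 / n)


if __name__ == "__main__":
    for mult in (0.25, 1.0, 4.0):
        y_std, grad_std = simulate(mult)
        print(f"mult={mult}: y.std {y_std:.3f} vs {forward_scale(mult):.3f}, "
              f"grad_x.std {grad_std:.3f} vs {backward_scale(mult):.3f}")
```

The simulated and closed-form values should agree to within sampling noise, which mirrors the "experiment" versus "model" curves plotted in the notebook. Taking the reciprocals of `forward_scale(mult)` and `backward_scale(mult)` gives the `y_scale` and `grad_scale` factors used in `hardtanh_scaled_statistical`.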