Merge pull request #52 from graphcore-research/awf/pt23
Update to PyTorch 2.2
awf authored Apr 22, 2024

2 parents d60dbc3 + 600fe18 commit 2394ee4
Showing 7 changed files with 98 additions and 96 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,4 +4,4 @@ einops
numpy
seaborn
tabulate
torch==2.1
torch>=2.2
5 changes: 3 additions & 2 deletions unit_scaling/analysis.py
@@ -282,8 +282,6 @@ def _rename(s: str) -> str:
s = s.replace("transformer_", "")
return s

p.set_yticklabels([_rename(item.get_text()) for item in p.get_yticklabels()])

plt.axvline(2**-14, color="grey", dashes=(3, 1))
plt.axvline(2**-7, color="grey", dashes=(1, 3))
plt.axvline(240, color="grey", dashes=(1, 3))
@@ -438,6 +436,9 @@ def draw_arrow(node_a: Node, node_b: Node, direction: str) -> None:
for direction in ["fwd", "bwd"]:
draw_error_bar(n, direction)

p.set_yticks(p.get_yticks())
p.set_yticklabels([_rename(item.get_text()) for item in p.get_yticklabels()])

return p # type: ignore[no-any-return]
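The relocated tick-labelling above is now preceded by p.set_yticks(p.get_yticks()), presumably to satisfy newer Matplotlib, which warns ("FixedFormatter should only be used together with FixedLocator") when tick labels are overridden without first fixing the tick locations. A minimal sketch of the same pattern, independent of this repository:

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [2**-14, 2**-7, 240])
# Pin the current ticks (a FixedLocator) before overriding their labels;
# calling set_yticklabels alone triggers the FixedFormatter/FixedLocator warning.
ticks = ax.get_yticks()
ax.set_yticks(ticks)
ax.set_yticklabels([f"{t:g}" for t in ticks])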


19 changes: 10 additions & 9 deletions unit_scaling/docs.py
@@ -41,20 +41,21 @@ def _validate(
raise ValueError(f"unsupported arg '{arg}' has no default value")

@wraps(f)
def f_new(*args: Any, **kwargs: Any) -> T:
def _validate_args_supported(*args: Any, **kwargs: Any) -> T:
arg_values = dict(zip(argspec.args, args))
full_kwargs = {**arg_values, **kwargs}
for arg_name, arg_value in full_kwargs.items():
arg_default_value = default_kwargs[arg_name]
if arg_name in unsupported_args and arg_value != arg_default_value:
raise ValueError(
f"Support for the '{arg_name}' argument has not been implemented"
" for the unit-scaling library. Please remove it or replace it"
" with its default value."
)
if arg_name in unsupported_args:
arg_default_value = default_kwargs[arg_name]
if arg_value != arg_default_value:
raise ValueError(
f"Support for the '{arg_name}' argument has not been"
" implemented for the unit-scaling library."
" Please remove it or replace it with its default value."
)
return f(*args, **kwargs)

return f_new
return _validate_args_supported


def _get_docstring_from_target(
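In the rewritten wrapper above, the default_kwargs lookup now happens only once an argument is known to be unsupported, so a default value is only required for the arguments being policed, and the inner function gets a self-describing name. A self-contained sketch of the same decorator-style check (the helper name and signature here are illustrative, not the library's API):

from functools import wraps
from inspect import getfullargspec
from typing import Any, Callable, Set, TypeVar

T = TypeVar("T")


def reject_unsupported_args(
    f: Callable[..., T], unsupported_args: Set[str]
) -> Callable[..., T]:
    argspec = getfullargspec(f)
    defaults = argspec.defaults or ()
    default_kwargs = dict(zip(argspec.args[len(argspec.args) - len(defaults):], defaults))
    for arg in unsupported_args:
        if arg not in default_kwargs:
            raise ValueError(f"unsupported arg '{arg}' has no default value")

    @wraps(f)
    def _validate_args_supported(*args: Any, **kwargs: Any) -> T:
        full_kwargs = {**dict(zip(argspec.args, args)), **kwargs}
        for arg_name, arg_value in full_kwargs.items():
            # Only consult default_kwargs for unsupported args, so supported
            # args without defaults are never looked up here.
            if arg_name in unsupported_args and arg_value != default_kwargs[arg_name]:
                raise ValueError(
                    f"'{arg_name}' is unsupported; remove it or pass its default value."
                )
        return f(*args, **kwargs)

    return _validate_args_supported

A wrapped function then raises only when a caller overrides one of the named unsupported arguments.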
8 changes: 4 additions & 4 deletions unit_scaling/tests/test_analysis.py
@@ -82,14 +82,14 @@ def forward(self, x: Tensor) -> Tensor: # pragma: no covers
"layer": [
"x",
"x",
"relu",
"relu",
"y",
"y",
"linear_weight",
"linear_weight",
"linear_bias",
"linear_bias",
"linear",
"linear",
"z",
"z",
"sum_1",
"sum_1",
],
18 changes: 11 additions & 7 deletions unit_scaling/tests/transforms/test_track_scales.py
@@ -19,18 +19,18 @@
)


def get_target(node: Node) -> Union[str, Callable]: # type: ignore[type-arg]
def get_target_or_node_name(node: Node) -> Union[str, Callable[..., Any]]:
return node.meta["clean_name"] if isinstance(node.target, str) else node.target


def get_targets(graph: Graph) -> Set[Union[str, Callable]]: # type: ignore[type-arg]
return set(get_target(node) for node in graph.nodes)
return set(get_target_or_node_name(node) for node in graph.nodes)


def get_target_map(
graph: Graph,
) -> Dict[Union[str, Callable], Dict[str, Any]]: # type: ignore[type-arg]
return {get_target(node): node.meta for node in graph.nodes}
return {get_target_or_node_name(node): node.meta for node in graph.nodes}


def test_track_scales() -> None:
@@ -181,12 +181,16 @@ def forward(self, idxs: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover
model(idxs)

graph = model.scales_graph()

# Version-dependent, see https://github.com/graphcore-research/unit-scaling/pull/52
var_lhs_flatten = "x"
var_lhs_view = "x_1"
expected_targets = {
"idxs",
"emb_weight",
F.embedding,
"flatten",
"view",
var_lhs_flatten,
var_lhs_view,
"linear_weight",
"linear_bias",
F.linear,
@@ -203,7 +207,7 @@ def forward(self, idxs: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover

graph = prune_same_scale_tensors(graph)
graph_targets = get_targets(graph)
expected_targets -= {"flatten", "view"}
expected_targets -= {var_lhs_flatten, var_lhs_view}
assert graph_targets == expected_targets

graph = prune_same_scale_tensors(graph, rtol=2**-4)
@@ -234,7 +238,7 @@ def forward(self, a: Tensor) -> Tensor: # pragma: no cover
operator.mul,
F.relu,
operator.sub,
"sum_1",
"f",
"output",
}
graph_targets = get_targets(graph)
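The var_lhs_flatten / var_lhs_view indirection above is needed because the symbolic names that graph tracing assigns to intermediate tensors changed between PyTorch releases (hence the pull-request link in the comment). A small torch.fx illustration, separate from the test suite, of where such names come from:

import torch
from torch import nn
from torch.fx import symbolic_trace


class Tiny(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.flatten(x).view(-1)


# Intermediate node names ("flatten", "view", or "x"/"x_1" under other tracer
# versions) are generated automatically, so tests that assert on them can be
# version-dependent; node.target is more stable where a real target exists.
for node in symbolic_trace(Tiny()).graph.nodes:
    print(node.op, node.name, node.target)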
13 changes: 7 additions & 6 deletions unit_scaling/tests/transforms/test_unit_scale.py
@@ -2,6 +2,7 @@

import logging
import math
import re
from typing import Tuple

import torch
@@ -94,14 +95,14 @@ def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]: # pragma: no cover
loss.backward()

expected_logs = [
"unit scaling function: add\n",
"unit scaling function: iadd\n",
"unit scaling function: iadd_1 (residual-add, tau=0.5)",
"unit scaling function: add_1 (residual-add, tau=0.5)",
r"unit scaling function: (input_2)\n",
r"unit scaling function: (input_4)\n",
r"unit scaling function: (skip_1|input_3) \(residual-add, tau=0\.5\)",
r"unit scaling function: (add_1|input_6) \(residual-add, tau=0\.5\)",
]
print(caplog.text)

for log_msg in expected_logs:
assert log_msg in caplog.text
assert re.search(log_msg, caplog.text)


def test_fp8_unit_scaling(caplog: LogCaptureFixture) -> None:
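In the hunk above, the expected log messages become regular expressions and the membership check becomes re.search, so the assertions tolerate node names that differ between PyTorch versions. An illustrative check (the log line here is made up):

import re

log_text = "unit scaling function: input_3 (residual-add, tau=0.5)\n"
pattern = r"unit scaling function: (skip_1|input_3) \(residual-add, tau=0\.5\)"
# A plain `pattern in log_text` test would require an exact node name;
# re.search accepts either alternative.
assert re.search(pattern, log_text)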
129 changes: 62 additions & 67 deletions unit_scaling/transforms/utils.py
@@ -2,23 +2,22 @@

"""Utilities for working with transforms."""

import copy
import functools
from contextlib import contextmanager
from copy import copy, deepcopy
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
no_type_check,
)
from unittest.mock import patch

import torch
import torch._dynamo
from torch import Tensor, nn
from torch.fx.graph import Graph
@@ -35,58 +34,41 @@
_unit_scaled_functions = [getattr(U, f) for f in U.__all__]


def _get_patched_allowed_function_ids(
non_recurse_functions: Iterable[Callable[..., Any]],
) -> Set[int]:
allowed_function_ids = copy(torch._dynamo.allowed_functions._allowed_function_ids)
for v in nn.modules.__dict__.values():
if isinstance(v, type) and v not in nn.modules.loss.__dict__.values():
i = id(v)
if i in allowed_function_ids:
allowed_function_ids.remove(i)
for f in non_recurse_functions:
allowed_function_ids.add(id(f))
return allowed_function_ids # type: ignore[no-any-return]


def _patched_call_function( # type: ignore[no-untyped-def]
self,
tx,
args,
kwargs,
): # pragma: no cover
if isinstance(self.obj, torch._dynamo.variables.NNModuleVariable):
module_attr = getattr(self.fn, "__module__", "")
if (
module_attr is not None
and module_attr.startswith("torch.nn.modules.module")
or self.is_constant
):
return self.obj.call_method( # type: ignore[no-untyped-call]
tx, self.fn.__name__, args, kwargs, constant=self.is_constant
).add_options(self)
return super(
torch._dynamo.variables.functions.UserMethodVariable, self
).call_function(tx, args, kwargs)


@contextmanager
def _expand_modules_patch(non_recurse_functions): # type: ignore[no-untyped-def]
patcher_a = patch(
"torch._dynamo.allowed_functions._allowed_function_ids",
new=_get_patched_allowed_function_ids(non_recurse_functions),
)
patcher_b = patch(
"torch._dynamo.variables.functions.UserMethodVariable.call_function",
new=_patched_call_function,
)
with patcher_a, patcher_b:
yield (patcher_a.start(), patcher_b.start())


def patch_to_expand_modules(
fn: Callable[..., T], non_recurse_functions: Iterable[Callable[..., Any]] = ()
) -> Callable[..., T]:
def torch_nn_modules_to_user_modules(mod: nn.Module) -> Any:
"""
Convert torch.nn module classes to `trivial_subclass` versions.
By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or
:mod:`torch.nn.functional` functions when capturing the FX graph.
This function makes `torch.nn` modules into user modules.
To use this with a :class:`torch.nn.Module`, the typical use case
is to call `torch_nn_modules_to_user_modules(module)`; the module is modified in place.
"""

for n, submod in mod.named_modules():
# Mirroring the check in https://github.com/pytorch/pytorch/blob/72662bf05b3499ce96aae9183a489c78f0c44c84/torch/_dynamo/variables/functions.py#L335 # noqa: E501
if submod.__module__.startswith("torch.nn."):
# Generate a new name, so e.g. torch.nn.modules.sparse.Embedding
# becomes trivial_subclass_modules_sparse_Embedding
modulename = submod.__module__
modulename = modulename.replace("torch.nn.", "", 1)
modulename = modulename.replace(".", "_")
newtypename = "trivial_subclass_" + modulename + "_" + type(submod).__name__

# Create a new type object deriving from type(submod)
newmodtype = type(newtypename, (type(submod),), {})

# Initialize and copy state using pickle
newsubmod = newmodtype.__new__(newmodtype) # type: ignore [call-overload]
newsubmod.__setstate__(submod.__getstate__())

# Update module in mod
setattr(mod, n, newsubmod)


def patch_to_expand_modules(fn: Callable[..., T]) -> Callable[..., T]:
"""By default TorchDynamo doesn't recurse into :mod:`torch.nn` modules or
:mod:`torch.nn.functional` functions when capturing the FX graph.
Any function which is wrapped in
@@ -98,21 +80,32 @@ def patch_to_expand_modules(
is to call `module = torch._dynamo.optimize(backend)(module)`, followed by
`module.forward = patch_to_expand_modules(module.forward)`.
In addition, any functions the user *doesn't* wish to recurse into can be passed
into `non_recurse_functions` and these will not be expanded.
This should be used in conjunction with :func:`torch_nn_modules_to_user_modules`.
Args:
fn (Callable[..., T]): the function to be patched.
non_recurse_functions (Iterable[Callable[..., Any]], optional): functions which
the user does not wish to be recursed into. Defaults to ().
Returns:
Callable[..., T]: the new version of `fn` with patching applied.
"""

def _patched_call_function( # type: ignore[no-untyped-def]
self,
tx,
args,
kwargs,
): # pragma: no cover
# Removing the check in https://github.com/pytorch/pytorch/blob/72662bf05b3499ce96aae9183a489c78f0c44c84/torch/_dynamo/variables/functions.py#L335 # noqa: E501
return super(
torch._dynamo.variables.functions.UserMethodVariable, self
).call_function(tx, args, kwargs)

@functools.wraps(fn)
def new_fn(*args: Any, **kwargs: Any) -> Any:
with _expand_modules_patch(non_recurse_functions):
with patch(
"torch._dynamo.variables.functions.UserMethodVariable.call_function",
new=_patched_call_function,
):
return fn(*args, **kwargs)

return new_fn
@@ -211,22 +204,24 @@ def apply_transform(
Returns:
nn.Module: the transformed module.
"""
module = deepcopy(module)
module = copy.deepcopy(module)

torch_nn_modules_to_user_modules(module)

if not hasattr(module, "backends"):
module.backends = []
module.backends.append(backend)
if not hasattr(module, "non_recurse_functions"):
module.non_recurse_functions = list(_unit_scaled_functions)
module.non_recurse_functions += non_recurse_functions

for v in non_recurse_functions:
torch._dynamo.allow_in_graph(v)

backend = _compose_backends(module.backends)

def new_forward(*args: Any, **kwargs: Any) -> Any:
if module.rerun_transform:
torch._dynamo.reset()
dynamo_module = torch._dynamo.optimize(backend)(module)
module.dynamo_forward = patch_to_expand_modules(
dynamo_module.forward, module.non_recurse_functions
)
module.dynamo_forward = patch_to_expand_modules(dynamo_module.forward)
module.rerun_transform = False
with patch.object(module, "forward", module.base_forward):
return module.dynamo_forward(*args, **kwargs)
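Read together, the docstrings and the apply_transform changes above describe a three-step flow: convert torch.nn submodules to user modules, mark any functions that should stay opaque with torch._dynamo.allow_in_graph, then compile with Dynamo and patch the compiled forward. A condensed, untested sketch of that call sequence, assuming the two helpers are importable from unit_scaling.transforms.utils and using a placeholder backend and model:

import copy

import torch
import torch.nn.functional as F
from torch import nn

from unit_scaling.transforms.utils import (
    patch_to_expand_modules,
    torch_nn_modules_to_user_modules,
)


class MLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def printing_backend(gm, example_inputs):
    # Placeholder backend: show the captured FX graph, then run it unmodified.
    print(gm.graph)
    return gm.forward


module = copy.deepcopy(MLP())                 # apply_transform also deep-copies first
torch_nn_modules_to_user_modules(module)      # let Dynamo recurse into torch.nn layers
torch._dynamo.allow_in_graph(F.gelu)          # stand-in for a non-recurse function
dynamo_module = torch._dynamo.optimize(printing_backend)(module)
dynamo_module.forward = patch_to_expand_modules(dynamo_module.forward)
dynamo_module(torch.randn(2, 4))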
