Merge pull request #57 from graphcore-research/fix-transform-recursion
Fix recursion in torch_nn_modules_to_user_modules()
DouglasOrr authored Jul 16, 2024

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature.
Commit 0a6c9af (2 parents: ebdbebc + f4402e6)
Showing 2 changed files with 10 additions and 7 deletions.
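
For context on the fix (a reading of the diff below, not part of the commit message): iterating `mod.named_modules()` yields nested submodules under dotted names such as `"0.0"`, so `setattr(mod, n, newsubmod)` on the top-level module cannot replace anything below the first level of nesting. The patch switches to `mod.named_children()` plus an explicit recursive call, so each replacement is applied to the immediate parent. A minimal sketch of the difference on a toy two-level model (illustrative only, not the library's code):

    import torch.nn as nn

    model = nn.Sequential(nn.Sequential(nn.Linear(4, 4)))

    # Flat traversal: the inner Linear appears under the dotted name "0.0", so
    # setattr() on the root registers a new entry rather than replacing it.
    # (named_modules() is materialised up front so the mutation below does not
    # disturb iteration.)
    for name, submod in list(model.named_modules()):
        if isinstance(submod, nn.Linear):
            setattr(model, name, nn.Identity())
    print(type(model[0][0]))  # still nn.Linear: the nested module was not replaced

    # Child-wise recursion: each replacement happens on the immediate parent.
    def replace_linears(mod: nn.Module) -> None:
        for name, child in mod.named_children():
            replace_linears(child)
            if isinstance(child, nn.Linear):
                setattr(mod, name, nn.Identity())

    model = nn.Sequential(nn.Sequential(nn.Linear(4, 4)))
    replace_linears(model)
    print(type(model[0][0]))  # nn.Identity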
unit_scaling/tests/transforms/test_unit_scale.py (8 changes: 4 additions & 4 deletions)
@@ -40,17 +40,17 @@ def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
             return input, input.sum()
 
     input = randn(2**6, 2**10, requires_grad=True)
-    model = MLPLayer(2**10)
+    model = nn.Sequential(MLPLayer(2**10))
     model = unit_scale(model, replace={custom_gelu: F.gelu})
     output, loss = model(input)
     loss.backward()
 
     assert_unit_scaled(
         output,
         input.grad,
-        model.layer_norm.weight.grad,
-        model.l1.weight.grad,
-        model.l2.weight.grad,
+        model[0].layer_norm.weight.grad,
+        model[0].l1.weight.grad,
+        model[0].l2.weight.grad,
         abs=0.2,
     )
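
Wrapping MLPLayer in nn.Sequential gives the model a nested submodule, presumably to exercise exactly the case the old traversal mishandled; the asserted parameters are then reached through `model[0]`. `assert_unit_scaled` itself presumably checks that each tensor's standard deviation is close to 1 within the `abs` tolerance; a rough sketch of that reading (an assumption, not the project's actual helper):

    import torch

    def assert_unit_scaled_sketch(*tensors: torch.Tensor, abs: float = 0.1) -> None:
        # Assumed semantics: every tensor (activations and gradients alike)
        # should have a standard deviation of roughly 1.
        for t in tensors:
            assert t is not None, "expected backward() to populate this gradient"
            std = t.detach().float().std().item()
            assert 1 - abs <= std <= 1 + abs, f"std {std:.3f} is not within {abs} of 1"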

unit_scaling/transforms/utils.py (9 changes: 6 additions & 3 deletions)
@@ -34,7 +34,7 @@
 _unit_scaled_functions = [getattr(U, f) for f in U.__all__]
 
 
-def torch_nn_modules_to_user_modules(mod: nn.Module) -> Any:
+def torch_nn_modules_to_user_modules(mod: nn.Module) -> None:
     """
     Convert torch.nn.module classes to `trivial_subclass` versions.
@@ -47,7 +47,9 @@ def torch_nn_modules_to_user_modules(mod: nn.Module) -> Any:
     is to call `module = torch_nn_modules_to_user_modules(module)`.
     """
 
-    for n, submod in mod.named_modules():
+    for n, submod in mod.named_children():
+        torch_nn_modules_to_user_modules(submod)
+
         # Mirroring the check at https://github.com/pytorch/pytorch/blob/34bce27f0d12bf7226b37dfe365660aad456701a/torch/_dynamo/variables/nn_module.py#L307 # noqa: E501
         if submod.__module__.startswith(("torch.nn.", "torch.ao.")):
             # Generate a new name, so e.g. torch.nn.modules.sparse.Embedding
@@ -62,7 +64,8 @@ def torch_nn_modules_to_user_modules(mod: nn.Module) -> Any:
 
             # Initialize and copy state using pickle
             newsubmod = newmodtype.__new__(newmodtype)  # type: ignore [call-overload]
-            newsubmod.__setstate__(submod.__getstate__())
+            state = submod.__getstate__()  # type: ignore [no-untyped-call]
+            newsubmod.__setstate__(state)
 
             # Update module in mod
             setattr(mod, n, newsubmod)
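
Besides the traversal change, the state-copy step is split across two lines, apparently so the mypy ignore can be attached to the `__getstate__` call. The pattern itself allocates the generated subclass without running `__init__` and transfers state via the pickle protocol; a standalone sketch of that idea, with a hypothetical subclass standing in for the generated `trivial_subclass` type (not the library's actual machinery):

    import torch.nn as nn

    # Hypothetical stand-in for a generated `trivial_subclass` wrapper type.
    class UserLinear(nn.Linear):
        pass

    src = nn.Linear(8, 8)

    # Allocate the subclass without calling __init__, then copy all state via
    # the pickle protocol, mirroring the newmodtype.__new__() / __getstate__()
    # / __setstate__() sequence in the function above.
    dst = UserLinear.__new__(UserLinear)
    dst.__setstate__(src.__getstate__())

    assert isinstance(dst, UserLinear)
    assert dst.in_features == 8 and dst.weight.shape == (8, 8)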
