Skip to content

Commit

Permalink
More assertive require_grad check (#4709)
Browse files Browse the repository at this point in the history
Fixes #4687
  • Loading branch information
kit1980 authored Nov 8, 2023
1 parent 4c0665e commit c214707
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 8 deletions.
4 changes: 3 additions & 1 deletion tools/torchfix/tests/fixtures/misc/checker/require_grad.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import random
import torch
x = torch.zeros(1)
x.require_grad = False
x.require_grad = True
grad = random.choice([False, True])
x.require_grad = grad

# Don't trigger
x.requires_grad = False
require_grad = False
x.require_grad = 10
3 changes: 2 additions & 1 deletion tools/torchfix/tests/fixtures/misc/checker/require_grad.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
3:1 TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`?
4:1 TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`?
5:1 TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`?
7:1 TOR002 Likely typo `require_grad` in assignment. Did you mean `requires_grad`?
6 changes: 6 additions & 0 deletions tools/torchfix/tests/fixtures/misc/codemod/require_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,9 @@
x = torch.zeros(1)
x.require_grad = False
x.require_grad = True

# from https://github.com/pytorch/test-infra/issues/4687
import torch.nn as nn
model = nn.Module()
for name, param in model.named_parameters():
param.require_grad = 'specific_layer' in name
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,9 @@ import torch
x = torch.zeros(1)
x.requires_grad = False
x.requires_grad = True

# from https://github.com/pytorch/test-infra/issues/4687
import torch.nn as nn
model = nn.Module()
for name, param in model.named_parameters():
param.requires_grad = 'specific_layer' in name
9 changes: 3 additions & 6 deletions tools/torchfix/torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ class TorchRequireGradVisitor(TorchVisitor):
MESSAGE = "Likely typo `require_grad` in assignment. Did you mean `requires_grad`?"

def visit_Assign(self, node):
# Look for any assignment with `require_grad` attribute on the left
# and `False` or `True` on the right.
# Look for any assignment with `require_grad` attribute on the left.
#
# If this causes false-positives on real code (unlikely),
# we can do type inference (not sure if feasible here) or
# at least check that `torch` is imported in the file.
# This is unlikely to cause false-positives on real code, especially
# because TorchFix only looks at files that have a `torch` string.
if m.matches(
node,
m.Assign(
Expand All @@ -28,7 +26,6 @@ def visit_Assign(self, node):
target=m.Attribute(attr=m.Name(value="require_grad"))
)
],
value=(m.Name("True") | m.Name("False")),
),
):
replacement = node.with_deep_changes(
Expand Down

0 comments on commit c214707

Please sign in to comment.