[Inductor][CPP] Fix node name for wgt delete (pytorch#147056)
**Summary**
This is a regression caused by a change in the FX node name. In commit 71010bf, both the name and the target of the `get_attr` node in `V.graph.graph.nodes` were `_frozen_param2`, so looking the attribute up by the node's name worked. On the latest main, however, the node name has changed to `_reorder_linear_weight` while its target still refers to the registered attribute, so name-based lookups no longer find the frozen weight. This PR fixes the regression by using the node's target instead of its name.
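
For background, here is a minimal standalone sketch (plain `torch.fx`, not part of this PR; the module below is illustrative) of why the two can differ: `node.target` of a `get_attr` node is the attribute path under which the tensor is registered on the owning module, while `node.name` is just the graph value name, which FX sanitizes and later passes may rename.

```python
import torch
from torch import fx


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 32, bias=False)

    def forward(self, x):
        # Direct use of the parameter is traced as a get_attr node whose
        # target is the dotted attribute path "linear.weight".
        return x @ self.linear.weight


gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "get_attr":
        # node.target is the registered attribute path; node.name is a
        # sanitized value name that graph passes may rewrite, so they can diverge.
        print(node.name, node.target)  # e.g. linear_weight  linear.weight
```

In the frozen-weight case touched by this PR, the constant is registered as a flat attribute (e.g. `_frozen_param2`), so `hasattr`/`getattr`/`delattr` on `V.graph.module` keyed by `node.target` keep resolving it even after a pass renames the node.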

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_cpp_weight_prune
```

Pull Request resolved: pytorch#147056
Approved by: https://github.com/jgong5
leslie-fang-intel authored and pytorchmergebot committed Feb 14, 2025
1 parent 10bc8f2 commit bd019c0
Showing 2 changed files with 31 additions and 6 deletions.
test/inductor/test_cpu_select_algorithm.py (23 additions, 0 deletions)
```diff
@@ -2051,6 +2051,29 @@ def forward(self, x):
         self.common(mod, (v,), atol=atol, rtol=rtol)
         self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
+    @inductor_config.patch({"freezing": True})
+    @patches
+    @torch.no_grad
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    def test_cpp_weight_prune(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(32, 128, bias=False)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        v = torch.randn(2, 32).to(torch.bfloat16)
+        mod = M().eval().to(torch.bfloat16)
+        torch._dynamo.reset()
+        torch._inductor.metrics.reset()
+        counters.clear()
+        with verify(torch.bfloat16) as (atol, rtol):
+            self.common(mod, (v,), atol=atol, rtol=rtol)
+        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
+        self.assertEqual(counters["inductor"]["select_algorithm_weight_prune"], 1)
+
     @patches
     @torch.no_grad
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
```
torch/_inductor/codegen/cpp_gemm_template.py (8 additions, 6 deletions)
```diff
@@ -383,9 +383,9 @@ def get_candidates(input_nodes, new_input_nodes):
             # Case may happen when the candidate tensor is used by more than 1 get_attr node
             # https://github.com/pytorch/pytorch/issues/134998
             if node.op == "get_attr" and hasattr(
-                V.graph.module, node.name
+                V.graph.module, node.target
             ):  # candidate tensor might already be deleted
-                comp_tensor = getattr(V.graph.module, node.name)
+                comp_tensor = getattr(V.graph.module, node.target)
                 if isinstance(comp_tensor, torch.Tensor) and share_storage(
                     candidate_tensor, comp_tensor
                 ):
@@ -395,13 +395,15 @@ def get_candidates(input_nodes, new_input_nodes):
             # The get_attr node has only 1 user fx node
             # The candidate tensor has been used by only 1 get_attr node
             if (
-                node.name == candidate_node.get_name()
+                node.op == "get_attr"
+                and node.target == candidate_node.get_name()
                 and len(node.users) == 1
                 and candidate_tensor_users == 1
             ):
-                del V.graph.constants[node.name]
-                delattr(V.graph.module, node.name)
-                delattr(V.graph.graph.owning_module, node.name)
+                del V.graph.constants[node.target]
+                delattr(V.graph.module, node.target)
+                delattr(V.graph.graph.owning_module, node.target)
+                counters["inductor"]["select_algorithm_weight_prune"] += 1
 
 
 def gen_2d_view_of_epilogue_buf(
```
