Skip to content

Commit

Permalink
Fix PyTorch nightly tests (#9983)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Jan 27, 2025
1 parent 9bffcd4 commit 7b95800
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions torch_geometric/nn/fx.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import copy
import warnings
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Type, Union

import torch
from torch import Tensor
from torch.nn import Module, ModuleDict, ModuleList, Sequential

try:
Expand Down Expand Up @@ -289,7 +290,7 @@ def is_leaf_module(self, module: Module, *args, **kwargs) -> bool:
# details on the rationale
# TODO: Revisit https://github.com/pyg-team/pytorch_geometric/pull/5021
@st.compatibility(is_backward_compatible=True)
def trace(self, root: st.Union[torch.nn.Module, st.Callable[..., Any]],
def trace(self, root: Union[torch.nn.Module, Callable[..., Any]],
concrete_args: Optional[Dict[str, Any]] = None) -> Graph:

if isinstance(root, torch.nn.Module):
Expand All @@ -303,17 +304,16 @@ def trace(self, root: st.Union[torch.nn.Module, st.Callable[..., Any]],
self.root = torch.nn.Module()
fn = root

tracer_cls: Optional[st.Type['Tracer']] = getattr(
tracer_cls: Optional[Type['Tracer']] = getattr(
self, '__class__', None)
self.graph = Graph(tracer_cls=tracer_cls)

self.tensor_attrs: Dict[st.Union[torch.Tensor, st.ScriptObject],
str] = {}
self.tensor_attrs: Dict[Union[Tensor, st.ScriptObject], str] = {}

def collect_tensor_attrs(m: torch.nn.Module,
prefix_atoms: st.List[str]):
prefix_atoms: List[str]):
for k, v in m.__dict__.items():
if isinstance(v, (torch.Tensor, st.ScriptObject)):
if isinstance(v, (Tensor, st.ScriptObject)):
self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])
for k, v in m.named_children():
collect_tensor_attrs(v, prefix_atoms + [k])
Expand Down

0 comments on commit 7b95800

Please sign in to comment.