Skip to content

Commit

Permalink
Fixed issue #32, where some torchvision ops weren't being decorated (…
Browse files Browse the repository at this point in the history
…since they're added separately to the PyTorch namespace)
  • Loading branch information
JohnMark Taylor committed Dec 27, 2024
1 parent f38d342 commit 334c673
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

setup(
name="torchlens",
version="0.1.25",
version="0.1.26",
description="A package for extracting activations from PyTorch models",
long_description="A package for extracting activations from PyTorch models. Contains functionality for "
"extracting model activations, visualizing a model's computational graph, and "
Expand Down
15 changes: 15 additions & 0 deletions torchlens/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,5 +415,20 @@ def my_get_overridable_functions() -> List:
return func_names


TORCHVISION_FUNCS = [
("torch.ops.torchvision.nms", "_op"),
("torch.ops.torchvision.deform_conv2d", "_op"),
("torch.ops.torchvision.ps_roi_align", "_op"),
("torch.ops.torchvision.ps_roi_pool", "_op"),
("torch.ops.torchvision.roi_align", "_op"),
("torch.ops.torchvision.roi_pool", "_op")]

OVERRIDABLE_FUNCS = my_get_overridable_functions()
ORIG_TORCH_FUNCS = OVERRIDABLE_FUNCS + IGNORED_FUNCS

try:
import torchvision

ORIG_TORCH_FUNCS += TORCHVISION_FUNCS
except ModuleNotFoundError:
pass

0 comments on commit 334c673

Please sign in to comment.