
Commit 2db93a7

sarahtranfb authored and facebook-github-bot committed
Remove unneeded feature permutation utils
Summary: Looping over the features has moved up into `_construct_ablated_input_across_tensors` after D71435704, so we don't need this anymore.

Reviewed By: cyrjano

Differential Revision: D72064893

fbshipit-source-id: 0f3ac2c967a577f4bc3f688893319aac3e51000d
1 parent 315a3d5 commit 2db93a7
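
For context, the helper being deleted applies the same masked-permutation pattern as the per-tensor `_permute_feature` that stays in the module: shuffle the batch, then take the shuffled values only where the feature mask is set. A minimal illustrative sketch of that pattern (the name `permute_masked` is made up here and is not part of Captum):

import torch
from torch import Tensor

def permute_masked(x: Tensor, feature_mask: Tensor) -> Tensor:
    # Illustrative only, not Captum's exact code.
    # Draw a batch permutation, retrying until it is not the identity.
    n = x.size(0)
    assert n > 1, "cannot permute features with batch_size = 1"
    perm = torch.randperm(n)
    while (perm == torch.arange(n)).all():
        perm = torch.randperm(n)
    # Where the mask is True, take values from the permuted batch;
    # elsewhere, keep the original values.
    keep = feature_mask.bitwise_not().to(dtype=x.dtype)
    return x[perm] * feature_mask.to(dtype=x.dtype) + x * keep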

1 file changed: +0 additions, -33 deletions


captum/attr/_core/feature_permutation.py

@@ -25,31 +25,6 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
     )
 
 
-def _permute_features_across_tensors(
-    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
-) -> Tuple[Tensor, ...]:
-    """
-    Permutes features across multiple input tensors using the corresponding
-    feature masks.
-    """
-    permuted_outputs = []
-    for input_tensor, feature_mask in zip(inputs, feature_masks):
-        if feature_mask is None or not feature_mask.any():
-            permuted_outputs.append(input_tensor)
-            continue
-        n = input_tensor.size(0)
-        assert n > 1, "cannot permute features with batch_size = 1"
-        perm = torch.randperm(n)
-        no_perm = torch.arange(n)
-        while (perm == no_perm).all():
-            perm = torch.randperm(n)
-        permuted_x = (
-            input_tensor[perm] * feature_mask.to(dtype=input_tensor.dtype)
-        ) + (input_tensor * feature_mask.bitwise_not().to(dtype=input_tensor.dtype))
-        permuted_outputs.append(permuted_x)
-    return tuple(permuted_outputs)
-
-
 class FeaturePermutation(FeatureAblation):
     r"""
     A perturbation based approach to compute attribution, which
@@ -102,9 +77,6 @@ def __init__(
         self,
         forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]],
         perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature,
-        perm_func_cross_tensor: Callable[
-            [Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]], Tuple[Tensor, ...]
-        ] = _permute_features_across_tensors,
     ) -> None:
         r"""
         Args:
@@ -117,14 +89,9 @@ def __init__(
                 which applies a random permutation, this argument only needs
                 to be provided if a custom permutation behavior is desired.
                 Default: `_permute_feature`
-            perm_func_cross_tensor (Callable, optional): Similar to perm_func,
-                except it can permute grouped features across multiple input
-                tensors, rather than taking each input tensor independently.
-                Default: `_permute_features_across_tensors`
         """
         FeatureAblation.__init__(self, forward_func=forward_func)
         self.perm_func = perm_func
-        self.perm_func_cross_tensor = perm_func_cross_tensor
 
         # suppressing error caused by the child class not having a matching
         # signature to the parent
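
With the cross-tensor permutation argument gone, `FeaturePermutation` is constructed from just the forward function (plus an optional per-tensor `perm_func`); per the summary above, looping over grouped features now happens in `_construct_ablated_input_across_tensors` in the parent `FeatureAblation`. A hedged usage sketch with a toy model (model and shapes chosen only for illustration):

import torch
from captum.attr import FeaturePermutation

model = torch.nn.Linear(6, 2)        # toy model, illustration only
inputs = torch.randn(8, 6)           # batch of 8 examples, 6 features

fp = FeaturePermutation(model)       # no perm_func_cross_tensor parameter anymore
attributions = fp.attribute(inputs, target=0)
print(attributions.shape)            # torch.Size([8, 6]): one attribution per input feature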
