@@ -25,31 +25,6 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
25
25
)
26
26
27
27
28
- def _permute_features_across_tensors (
29
- inputs : Tuple [Tensor , ...], feature_masks : Tuple [Optional [Tensor ], ...]
30
- ) -> Tuple [Tensor , ...]:
31
- """
32
- Permutes features across multiple input tensors using the corresponding
33
- feature masks.
34
- """
35
- permuted_outputs = []
36
- for input_tensor , feature_mask in zip (inputs , feature_masks ):
37
- if feature_mask is None or not feature_mask .any ():
38
- permuted_outputs .append (input_tensor )
39
- continue
40
- n = input_tensor .size (0 )
41
- assert n > 1 , "cannot permute features with batch_size = 1"
42
- perm = torch .randperm (n )
43
- no_perm = torch .arange (n )
44
- while (perm == no_perm ).all ():
45
- perm = torch .randperm (n )
46
- permuted_x = (
47
- input_tensor [perm ] * feature_mask .to (dtype = input_tensor .dtype )
48
- ) + (input_tensor * feature_mask .bitwise_not ().to (dtype = input_tensor .dtype ))
49
- permuted_outputs .append (permuted_x )
50
- return tuple (permuted_outputs )
51
-
52
-
53
28
class FeaturePermutation (FeatureAblation ):
54
29
r"""
55
30
A perturbation based approach to compute attribution, which
@@ -102,9 +77,6 @@ def __init__(
102
77
self ,
103
78
forward_func : Callable [..., Union [int , float , Tensor , Future [Tensor ]]],
104
79
perm_func : Callable [[Tensor , Tensor ], Tensor ] = _permute_feature ,
105
- perm_func_cross_tensor : Callable [
106
- [Tuple [Tensor , ...], Tuple [Optional [Tensor ], ...]], Tuple [Tensor , ...]
107
- ] = _permute_features_across_tensors ,
108
80
) -> None :
109
81
r"""
110
82
Args:
@@ -117,14 +89,9 @@ def __init__(
117
89
which applies a random permutation, this argument only needs
118
90
to be provided if a custom permutation behavior is desired.
119
91
Default: `_permute_feature`
120
- perm_func_cross_tensor (Callable, optional): Similar to perm_func,
121
- except it can permute grouped features across multiple input
122
- tensors, rather than taking each input tensor independently.
123
- Default: `_permute_features_across_tensors`
124
92
"""
125
93
FeatureAblation .__init__ (self , forward_func = forward_func )
126
94
self .perm_func = perm_func
127
- self .perm_func_cross_tensor = perm_func_cross_tensor
128
95
129
96
# suppressing error caused by the child class not having a matching
130
97
# signature to the parent
0 commit comments