
Commit 06acd02

Revert "Merge remote-tracking branch 'upstream/master' into optim-wip-clip-vis"
This reverts commit 6301aa5, reversing changes made to aea385a.
1 parent: 6301aa5 · commit: 06acd02

35 files changed: +552 −1565 lines

README.md

+1 −1

@@ -49,7 +49,7 @@ Captum can also be used by application engineers who are using trained models in
 
 **Installation Requirements**
 - Python >= 3.6
-- PyTorch >= 1.6
+- PyTorch >= 1.2
 
 
 ##### Installing the latest release

captum/__init__.py

+0 −7

@@ -1,10 +1,3 @@
 #!/usr/bin/env python3
-import captum.attr as attr  # noqa
-import captum.concept as concept  # noqa
-import captum.influence as influence  # noqa
-import captum.log as log  # noqa
-import captum.metrics as metrics  # noqa
-import captum.robust as robust  # noqa
-
 
 __version__ = "0.5.0"
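Note (not part of the diff): with these eager submodule imports reverted, `import captum` alone no longer exposes `captum.attr` and the other subpackages, so downstream code imports the submodule it needs explicitly. A minimal, hypothetical usage sketch (the toy model is an assumption, not from this commit):

import torch.nn as nn

import captum.attr as attr  # explicit submodule import is required

model = nn.Sequential(nn.Linear(4, 2))  # stand-in model for illustration only
ig = attr.IntegratedGradients(model)    # subpackage APIs are reached via the explicit import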

captum/_utils/av.py

+1 −1

@@ -47,7 +47,7 @@ def __init__(
             identifier: Optional[str] = None,
             layer: Optional[str] = None,
             num_id: Optional[str] = None,
-        ) -> None:
+        ):
             r"""
             Loads into memory the list of all activation file paths associated
             with the input `model_id`.

captum/_utils/models/linear_model/train.py

+70 −69

@@ -99,6 +99,7 @@ def sgd_train_linear_model(
         This will return the final training loss (averaged with
         `running_loss_window`)
     """
+
     loss_window: List[torch.Tensor] = []
     min_avg_loss = None
     convergence_counter = 0
@@ -144,77 +145,77 @@ def get_point(datapoint):
             if model.linear.bias is not None:
                 model.linear.bias.zero_()
 
-    with torch.enable_grad():
-        optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
-        if reduce_lr:
-            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-                optim, factor=0.5, patience=patience, threshold=threshold
-            )
-
-        t1 = time.time()
-        epoch = 0
-        i = 0
-        while epoch < max_epoch:
-            while True:  # for x, y, w in dataloader
-                if running_loss_window is None:
-                    running_loss_window = x.shape[0] * len(dataloader)
-
-                y = y.view(x.shape[0], -1)
-                if w is not None:
-                    w = w.view(x.shape[0], -1)
-
-                i += 1
-
-                out = model(x)
-
-                loss = loss_fn(y, out, w)
-                if reg_term is not None:
-                    reg = torch.norm(model.linear.weight, p=reg_term)
-                    loss += reg.sum() * alpha
-
-                if len(loss_window) >= running_loss_window:
-                    loss_window = loss_window[1:]
-                loss_window.append(loss.clone().detach())
-                assert len(loss_window) <= running_loss_window
-
-                average_loss = torch.mean(torch.stack(loss_window))
-                if min_avg_loss is not None:
-                    # if we haven't improved by at least `threshold`
-                    if average_loss > min_avg_loss or torch.isclose(
-                        min_avg_loss, average_loss, atol=threshold
-                    ):
-                        convergence_counter += 1
-                        if convergence_counter >= patience:
-                            converged = True
-                            break
-                    else:
-                        convergence_counter = 0
-                if min_avg_loss is None or min_avg_loss >= average_loss:
-                    min_avg_loss = average_loss.clone()
-
-                if debug:
-                    print(
-                        f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
-                        + "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
-                    )
-
-                loss.backward()
-                optim.step()
-                model.zero_grad()
-                if scheduler:
-                    scheduler.step(average_loss)
-
-                temp = next(data_iter, None)
-                if temp is None:
-                    break
-                x, y, w = get_point(temp)
-
-            if converged:
+    optim = torch.optim.SGD(model.parameters(), lr=initial_lr)
+    if reduce_lr:
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+            optim, factor=0.5, patience=patience, threshold=threshold
+        )
+
+    t1 = time.time()
+    epoch = 0
+    i = 0
+    while epoch < max_epoch:
+        while True:  # for x, y, w in dataloader
+            if running_loss_window is None:
+                running_loss_window = x.shape[0] * len(dataloader)
+
+            y = y.view(x.shape[0], -1)
+            if w is not None:
+                w = w.view(x.shape[0], -1)
+
+            i += 1
+
+            out = model(x)
+
+            loss = loss_fn(y, out, w)
+            if reg_term is not None:
+                reg = torch.norm(model.linear.weight, p=reg_term)
+                loss += reg.sum() * alpha
+
+            if len(loss_window) >= running_loss_window:
+                loss_window = loss_window[1:]
+            loss_window.append(loss.clone().detach())
+            assert len(loss_window) <= running_loss_window
+
+            average_loss = torch.mean(torch.stack(loss_window))
+            if min_avg_loss is not None:
+                # if we haven't improved by at least `threshold`
+                if average_loss > min_avg_loss or torch.isclose(
+                    min_avg_loss, average_loss, atol=threshold
+                ):
+                    convergence_counter += 1
+                    if convergence_counter >= patience:
+                        converged = True
+                        break
+                else:
+                    convergence_counter = 0
+            if min_avg_loss is None or min_avg_loss >= average_loss:
+                min_avg_loss = average_loss.clone()
+
+            if debug:
+                print(
+                    f"lr={optim.param_groups[0]['lr']}, Loss={loss},"
+                    + "Aloss={average_loss}, min_avg_loss={min_avg_loss}"
+                )
+
+            loss.backward()
+
+            optim.step()
+            model.zero_grad()
+            if scheduler:
+                scheduler.step(average_loss)
+
+            temp = next(data_iter, None)
+            if temp is None:
                 break
+            x, y, w = get_point(temp)
+
+        if converged:
+            break
 
-            epoch += 1
-            data_iter = iter(dataloader)
-            x, y, w = get_point(next(data_iter))
+        epoch += 1
+        data_iter = iter(dataloader)
+        x, y, w = get_point(next(data_iter))
 
     t2 = time.time()
     return {
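Note (not part of the diff): the hunk above mostly re-indents the training loop after dropping the `with torch.enable_grad():` wrapper. As a rough, self-contained sketch of the same loop shape (plain SGD on a linear model, a running window of recent losses, and a ReduceLROnPlateau scheduler stepped on the averaged loss), using synthetic data and an arbitrary window size rather than Captum's actual implementation:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic regression data standing in for the `dataloader` argument.
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.1 * torch.randn(64, 1)
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = torch.nn.Linear(3, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, factor=0.5, patience=10)

loss_window = []   # most recent batch losses
window_size = 8    # illustrative; the code above derives the window from the data size
for epoch in range(100):
    for xb, yb in loader:
        loss = torch.nn.functional.mse_loss(model(xb), yb)
        loss_window = (loss_window + [loss.detach()])[-window_size:]
        average_loss = torch.stack(loss_window).mean()

        loss.backward()
        optim.step()
        model.zero_grad()
        scheduler.step(average_loss)  # halve the lr when the averaged loss plateaus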

captum/_utils/progress.py

+2 −2

@@ -12,7 +12,7 @@
 
 
 class DisableErrorIOWrapper(object):
-    def __init__(self, wrapped: TextIO) -> None:
+    def __init__(self, wrapped: TextIO):
         """
         The wrapper around a TextIO object to ignore write errors like tqdm
         https://github.com/tqdm/tqdm/blob/bcce20f771a16cb8e4ac5cc5b2307374a2c0e535/tqdm/utils.py#L131
@@ -48,7 +48,7 @@ def __init__(
         total: int = None,
         file: TextIO = None,
         mininterval: float = 0.5,
-    ) -> None:
+    ):
         """
         Simple progress output used when tqdm is unavailable.
         Same as tqdm, output to stderr channel
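Note (not part of the diff): DisableErrorIOWrapper forwards writes to a wrapped stream while swallowing write errors, so a broken stderr cannot crash progress reporting. A loose sketch of that idea with made-up names (not Captum's class):

import sys

class QuietStream:
    def __init__(self, wrapped):
        self._wrapped = wrapped

    def write(self, text):
        try:
            return self._wrapped.write(text)
        except (OSError, ValueError):
            return 0  # ignore write failures instead of raising

    def __getattr__(self, name):
        return getattr(self._wrapped, name)  # delegate everything else

stream = QuietStream(sys.stderr)
stream.write("progress: 50%\n")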

captum/_utils/sample_gradient.py

+14 −14

@@ -1,14 +1,14 @@
 from collections import defaultdict
 from enum import Enum
-from typing import cast, DefaultDict, Iterable, List, Tuple, Union
+from typing import cast, Iterable, Tuple, Union
 
 import torch
 from captum._utils.common import _format_tensor_into_tuples, _register_backward_hook
 from torch import Tensor
 from torch.nn import Module
 
 
-def _reset_sample_grads(module: Module) -> None:
+def _reset_sample_grads(module: Module):
     module.weight.sample_grad = 0  # type: ignore
     if module.bias is not None:
         module.bias.sample_grad = 0  # type: ignore
@@ -100,19 +100,19 @@ class SampleGradientWrapper:
     - https://github.com/pytorch/opacus/tree/main/opacus/grad_sample
     """
 
-    def __init__(self, model) -> None:
+    def __init__(self, model):
         self.model = model
         self.hooks_added = False
-        self.activation_dict: DefaultDict[Module, List[Tensor]] = defaultdict(list)
-        self.gradient_dict: DefaultDict[Module, List[Tensor]] = defaultdict(list)
-        self.forward_hooks: List[torch.utils.hooks.RemovableHandle] = []
-        self.backward_hooks: List[torch.utils.hooks.RemovableHandle] = []
+        self.activation_dict = defaultdict(list)
+        self.gradient_dict = defaultdict(list)
+        self.forward_hooks = []
+        self.backward_hooks = []
 
-    def add_hooks(self) -> None:
+    def add_hooks(self):
         self.hooks_added = True
         self.model.apply(self._register_module_hooks)
 
-    def _register_module_hooks(self, module: torch.nn.Module) -> None:
+    def _register_module_hooks(self, module: torch.nn.Module):
         if isinstance(module, tuple(SUPPORTED_MODULES.keys())):
             self.forward_hooks.append(
                 module.register_forward_hook(self._forward_hook_fn)
@@ -126,7 +126,7 @@ def _forward_hook_fn(
         module: Module,
         module_input: Union[Tensor, Tuple[Tensor, ...]],
         module_output: Union[Tensor, Tuple[Tensor, ...]],
-    ) -> None:
+    ):
         inp_tuple = _format_tensor_into_tuples(module_input)
         self.activation_dict[module].append(inp_tuple[0].clone().detach())
 
@@ -135,11 +135,11 @@ def _backward_hook_fn(
         module: Module,
         grad_input: Union[Tensor, Tuple[Tensor, ...]],
         grad_output: Union[Tensor, Tuple[Tensor, ...]],
-    ) -> None:
+    ):
         grad_output_tuple = _format_tensor_into_tuples(grad_output)
         self.gradient_dict[module].append(grad_output_tuple[0].clone().detach())
 
-    def remove_hooks(self) -> None:
+    def remove_hooks(self):
         self.hooks_added = False
 
         for hook in self.forward_hooks:
@@ -151,11 +151,11 @@ def remove_hooks(self) -> None:
         self.forward_hooks = []
         self.backward_hooks = []
 
-    def _reset(self) -> None:
+    def _reset(self):
         self.activation_dict = defaultdict(list)
         self.gradient_dict = defaultdict(list)
 
-    def compute_param_sample_gradients(self, loss_blob, loss_mode="mean") -> None:
+    def compute_param_sample_gradients(self, loss_blob, loss_mode="mean"):
         assert (
             loss_mode.upper() in LossMode.__members__
         ), f"Provided loss mode {loss_mode} is not valid"

captum/attr/_core/lime.py

+6 −6

@@ -512,17 +512,17 @@ def attribute(
         if show_progress:
             attr_progress.close()
 
-        combined_interp_inps = torch.cat(interpretable_inps).float()
+        combined_interp_inps = torch.cat(interpretable_inps).double()
         combined_outputs = (
             torch.cat(outputs)
             if len(outputs[0].shape) > 0
             else torch.stack(outputs)
-        ).float()
+        ).double()
         combined_sim = (
             torch.cat(similarities)
             if len(similarities[0].shape) > 0
             else torch.stack(similarities)
-        ).float()
+        ).double()
         dataset = TensorDataset(
             combined_interp_inps, combined_outputs, combined_sim
         )
@@ -734,7 +734,7 @@ def __init__(
 
             forward_func (callable): The forward function of the model or any
                     modification of it
-            interpretable_model (Model, optional): Model object to train
+            interpretable_model (optional, Model): Model object to train
                     interpretable model.
 
                    This argument is optional and defaults to SkLearnLasso(alpha=0.01),
@@ -760,7 +760,7 @@ def __init__(
                    Note that calling fit multiple times should retrain the
                    interpretable model, each attribution call reuses
                    the same given interpretable model object.
-            similarity_func (callable, optional): Function which takes a single sample
+            similarity_func (optional, callable): Function which takes a single sample
                    along with its corresponding interpretable representation
                    and returns the weight of the interpretable sample for
                    training the interpretable model.
@@ -793,7 +793,7 @@ def __init__(
 
                    kwargs includes baselines, feature_mask, num_interp_features
                    (integer, determined from feature mask).
-            perturb_func (callable, optional): Function which returns a single
+            perturb_func (optional, callable): Function which returns a single
                    sampled input, which is a binary vector of length
                    num_interp_features, or a generator of such tensors.
 
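Note (not part of the diff): the first hunk's only change is the cast of the collected samples from float32 to float64 before the surrogate-model training set is built. A toy sketch of that step with synthetic stand-ins for the accumulated lists (not Lime's actual attribute flow):

import torch
from torch.utils.data import DataLoader, TensorDataset

interpretable_inps = [torch.randint(0, 2, (1, 4)) for _ in range(3)]  # binary masks
outputs = [torch.rand(1) for _ in range(3)]                           # model outputs
similarities = [torch.rand(1) for _ in range(3)]                      # sample weights

combined_interp_inps = torch.cat(interpretable_inps).double()
combined_outputs = torch.cat(outputs).double()
combined_sim = torch.cat(similarities).double()

dataset = TensorDataset(combined_interp_inps, combined_outputs, combined_sim)
loader = DataLoader(dataset, batch_size=2)
print(next(iter(loader))[0].dtype)  # torch.float64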

captum/attr/_core/saliency.py

+3 −3

@@ -43,9 +43,9 @@ def attribute(
         r"""
         Args:
 
-            inputs (tensor or tuple of tensors): Input for which saliency
-                        is computed. If forward_func takes a single tensor
-                        as input, a single input tensor should be provided.
+            inputs (tensor or tuple of tensors): Input for which integrated
+                        gradients are computed. If forward_func takes a single
+                        tensor as input, a single input tensor should be provided.
                         If forward_func takes multiple tensors as input, a tuple
                         of the input tensors should be provided. It is assumed
                         that for all given input tensors, dimension 0 corresponds
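Note (not part of the diff): for context on the docstring above, a small usage sketch of Saliency.attribute on a toy model (the model, inputs, and target are assumptions for illustration):

import torch
from torch import nn

from captum.attr import Saliency

model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 2))
inputs = torch.randn(4, 3, requires_grad=True)  # dimension 0 is the batch dimension

saliency = Saliency(model)
attributions = saliency.attribute(inputs, target=0)  # gradient magnitudes w.r.t. output 0
print(attributions.shape)  # torch.Size([4, 3])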
