
Commit a08883f

99warriors authored and facebook-github-bot committed
allow self influence iteration options (#1002)
Summary: Pull Request resolved: #1002

- For self influence computation, there needs to be an iteration over both checkpoints and batches. This diff adds a `by_checkpoints` option. If true, the outer iteration is over checkpoints. If false, the outer iteration is over batches. Because self influence computation can be called through both the `influence` and `self_influence` methods, this option is added to both methods. Because only `TracInCP` and `TracInCPFast` should be used for self influence computation, only those classes are changed.
- To implement this option, the old `self_influence` method, which had the outer iteration over checkpoints, is renamed to the private `_self_influence_by_checkpoints` method. A new `_self_influence_by_batches` method is added, which has an outer iteration over batches, and re-uses `_self_influence_by_checkpoints` to compute self influence scores for a single batch (that method can accept both a single batch and a dataloader yielding batches). Because the logic of this method is the same for all classes, a helper method, `_self_influence_by_batches_helper`, is added to `captum.influence._utils.common`. Finally, the new `self_influence` method simply chooses whether to call `_self_influence_by_checkpoints` or `_self_influence_by_batches`.
- Documentation describing the two options for `by_checkpoints` is added to the `self_influence` and `influence` methods.
- `test_tracin_show_progress` now differentiates between 2 modes: "self influence by checkpoints" (the original test for the progress bar when calculating self influence scores, which checks whether the outer progress bar over checkpoints and the inner progress bars over batches both reach 100%), and the newly added mode "self influence by batches", which checks whether the progress bar over batches reaches 100%.
- `test_tracin_self_influence` now also checks whether computing self influence scores gives the same result regardless of whether `by_checkpoints` is True or False.

Reviewed By: NarineK

Differential Revision: D37743920

fbshipit-source-id: ead1bbc86e8eac477768113b9939556d9b1c0de1
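The two iteration orders this diff introduces can be sketched in plain Python. This is an illustrative simplification only: `load_checkpoint` and `batch_scores` below are hypothetical stand-ins for checkpoint loading and per-batch score computation, not Captum APIs. Both orders accumulate the same per-batch totals; they differ only in how often each checkpoint must be loaded, which mirrors what `test_tracin_self_influence` verifies.

```python
def self_influence_outer_checkpoints(checkpoints, batches, load_checkpoint, batch_scores):
    """Outer loop over checkpoints: each checkpoint is loaded exactly once,
    which is cheaper for large models."""
    totals = [0.0] * len(batches)
    for ckpt in checkpoints:
        model = load_checkpoint(ckpt)  # len(checkpoints) loads in total
        for i, batch in enumerate(batches):
            totals[i] += batch_scores(model, batch)
    return totals


def self_influence_outer_batches(checkpoints, batches, load_checkpoint, batch_scores):
    """Outer loop over batches: progress display is per-batch and more
    intuitive, but every checkpoint is reloaded for every batch
    (len(batches) * len(checkpoints) loads in total)."""
    totals = []
    for batch in batches:
        score = 0.0
        for ckpt in checkpoints:
            model = load_checkpoint(ckpt)
            score += batch_scores(model, batch)
        totals.append(score)
    return totals
```

Whichever order is chosen, the resulting scores are identical; only load cost and progress-bar granularity change.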
1 parent 1a10252 commit a08883f

File tree

5 files changed: +372 −75 lines changed

captum/influence/_core/tracincp.py (+86, −16)

```diff
@@ -30,6 +30,7 @@
     _get_k_most_influential_helper,
     _gradient_dot_product,
     _load_flexible_state_dict,
+    _self_influence_by_batches_helper,
 )
 from captum.log import log_usage
 from torch import Tensor
@@ -475,7 +476,8 @@ def _influence_route_to_helpers(

     if inputs is None:
         return influence_instance.self_influence(
-            influence_instance.train_dataloader, show_progress
+            influence_instance.train_dataloader,
+            show_progress,
         )
     elif k is None:
         return influence_instance._influence(_inputs, targets, show_progress)
@@ -727,11 +729,9 @@ def influence( # type: ignore[override]
                 requires "training dataset computations": computations for each
                 batch in the training dataset `train_dataset`, which may
                 take a long time. If `show_progress`is true, the progress of
-                "training dataset computations" will be displayed. In particular,
-                the number of batches for which computations have been performed
-                will be displayed. It will try to use tqdm if available for
-                advanced features (e.g. time estimation). Otherwise, it will
-                fallback to a simple output of progress.
+                "training dataset computations" will be displayed. It will try to
+                use tqdm if available for advanced features (e.g. time estimation).
+                Otherwise, it will fallback to a simple output of progress.
                 Default: False

         Returns:
@@ -926,7 +926,7 @@ def _get_k_most_influential(
                 (
                     f"Using {self.get_name()} to perform computation for "
                     f'getting {"proponents" if proponents else "opponents"}. '
-                    "Processing training batches: 100%"
+                    "Processing training batches"
                 )
             )
         )
@@ -943,7 +943,7 @@
             )
         )

-    def self_influence(
+    def _self_influence_by_checkpoints(
         self,
         inputs_dataset: Union[Tuple[Any, ...], DataLoader],
         show_progress: bool = False,
@@ -957,7 +957,11 @@
         will call `model` on that single batch, and if `inputs_dataset` yields
         batches, this will call `model` on each batch that is yielded. Therefore,
         please ensure that for both cases, the batch(es) that `model` is called
-        with are not too large, so that there will not be an out-of-memory error.
+        with are not too large, so that there will not be an out-of-memory error. This
+        implementation performs an outer iteration over checkpoints, and an inner
+        iteration over all batches that `inputs_dataset` represents. The pros of this
+        implementation are that the checkpoints do not need to be loaded too many
+        times.

         Args:
             batches (Tuple, or DataLoader): Either a single tuple of any, or a
@@ -976,13 +980,10 @@
                 displayed. In more detail, this computation will iterate over all
                 checkpoints (provided as the `checkpoints` initialization argument)
                 in an outer loop, and iterate over all batches that
-                `inputs_dataset` represents in an inner loop. Therefore, the
-                total number of (checkpoint, batch) combinations that need to be
-                iterated over is
-                (# of checkpoints x # of batches that `inputs_dataset` represents).
-                If `show_progress` is True, the total progress of both the outer
-                iteration over checkpoints and the inner iteration over batches is
-                displayed. It will try to use tqdm if available for advanced
+                `inputs_dataset` represents in an inner loop. Thus if
+                `show_progress` is True, the progress of both the outer
+                iteration and the inner iterations will be displayed. To show
+                progress, it will try to use tqdm if available for advanced
                 features (e.g. time estimation). Otherwise, it will fallback to a
                 simple output of progress.
                 Default: False
@@ -1097,6 +1098,75 @@ def get_checkpoint_contribution(checkpoint):

         return batches_self_tracin_scores

+    def self_influence(
+        self,
+        inputs_dataset: Union[Tuple[Any, ...], DataLoader],
+        show_progress: bool = False,
+        outer_loop_by_checkpoints: bool = False,
+    ) -> Tensor:
+        """
+        Computes self influence scores for the examples in `inputs_dataset`, which is
+        either a single batch or a Pytorch `DataLoader` that yields batches. Therefore,
+        the computed self influence scores are *not* for the examples in training
+        dataset `train_dataset` (unlike when computing self influence scores using the
+        `influence` method). Note that if `inputs_dataset` is a single batch, this
+        will call `model` on that single batch, and if `inputs_dataset` yields
+        batches, this will call `model` on each batch that is yielded. Therefore,
+        please ensure that for both cases, the batch(es) that `model` is called
+        with are not too large, so that there will not be an out-of-memory error.
+        Internally, this computation requires iterating both over the batches in
+        `inputs_dataset`, as well as different model checkpoints. There are two ways
+        this iteration can be done. If `outer_loop_by_checkpoints` is False, the outer
+        iteration will be over batches, and the inner iteration will be over
+        checkpoints. This has the pro that displaying the progress of the computation
+        is more intuitive, involving displaying the number of batches for which self
+        influence scores have been computed. If `outer_loop_by_checkpoints` is True,
+        the outer iteration will be over checkpoints, and the inner iteration will be
+        over batches. This has the pro that the checkpoints do not need to be loaded
+        for each batch. For large models, loading checkpoints can be time-intensive.
+
+        Args:
+            batches (Tuple, or DataLoader): Either a single tuple of any, or a
+                `DataLoader`, where each batch yielded is a tuple of any. In
+                either case, the tuple represents a single batch, where the last
+                element is assumed to be the labels for the batch. That is,
+                `model(*batch[0:-1])` produces the output for `model`,
+                and `batch[-1]` are the labels, if any. This is the same
+                assumption made for each batch yielded by training dataset
+                `train_dataset`. Please see documentation for the
+                `train_dataset` argument to `TracInCP.__init__` for
+                more details on the assumed structure of a batch.
+            show_progress (bool, optional): Computation of self influence scores can
+                take a long time if `inputs_dataset` represents many examples. If
+                `show_progress`is true, the progress of this computation will be
+                displayed. In more detail, if `outer_loop_by_checkpoints` is False,
+                this computation will iterate over all batches in an outer loop.
+                Thus if `show_progress` is True, the number of batches for which
+                self influence scores have been computed will be displayed. If
+                `outer_loop_by_checkpoints` is True, this computation will iterate
+                over all checkpoints (provided as the `checkpoints` initialization
+                argument) in an outer loop, and iterate over all batches that
+                `inputs_dataset` represents in an inner loop. Thus if
+                `show_progress` is True, the progress of both the outer
+                iteration and the inner iterations will be displayed. To show
+                progress, it will try to use tqdm if available for advanced
+                features (e.g. time estimation). Otherwise, it will fallback to a
+                simple output of progress.
+                Default: False
+            outer_loop_by_checkpoints (bool, optional): If performing an outer
+                iteration over checkpoints; see method description for more
+                details.
+                Default: False
+        """
+        if outer_loop_by_checkpoints:
+            return self._self_influence_by_checkpoints(inputs_dataset, show_progress)
+        return _self_influence_by_batches_helper(
+            self._self_influence_by_checkpoints,
+            self.get_name(),
+            inputs_dataset,
+            show_progress,
+        )
+
     def _basic_computation_tracincp(
         self,
         inputs: Tuple[Any, ...],
```

captum/influence/_core/tracincp_fast_rand_proj.py (+89, −15)

```diff
@@ -17,6 +17,7 @@
     _get_k_most_influential_helper,
     _jacobian_loss_wrt_inputs,
     _load_flexible_state_dict,
+    _self_influence_by_batches_helper,
     _tensor_batch_dot,
 )
 from captum.influence._utils.nearest_neighbors import (
@@ -263,11 +264,9 @@ def influence( # type: ignore[override]
                 requires "training dataset computations": computations for each
                 batch in the training dataset `train_dataset`, which may
                 take a long time. If `show_progress`is true, the progress of
-                "training dataset computations" will be displayed. In particular,
-                the number of batches for which computations have been performed
-                will be displayed. It will try to use tqdm if available for
-                advanced features (e.g. time estimation). Otherwise, it will
-                fallback to a simple output of progress.
+                "training dataset computations" will be displayed. It will try to
+                use tqdm if available for advanced features (e.g. time estimation).
+                Otherwise, it will fallback to a simple output of progress.
                 Default: False

         Returns:
@@ -466,7 +465,7 @@ def _get_k_most_influential( # type: ignore[override]
                 (
                     f"Using {self.get_name()} to perform computation for "
                     f'getting {"proponents" if proponents else "opponents"}. '
-                    "Processing training batches: 100%"
+                    "Processing training batches"
                 )
             )
         )
@@ -483,7 +482,7 @@
             )
         )

-    def self_influence(
+    def _self_influence_by_checkpoints(
         self,
         inputs_dataset: Union[Tuple[Any, ...], DataLoader],
         show_progress: bool = False,
@@ -497,7 +496,11 @@
         will call `model` on that single batch, and if `inputs_dataset` yields
         batches, this will call `model` on each batch that is yielded. Therefore,
         please ensure that for both cases, the batch(es) that `model` is called
-        with are not too large, so that there will not be an out-of-memory error.
+        with are not too large, so that there will not be an out-of-memory error. This
+        implementation performs an outer iteration over checkpoints, and an inner
+        iteration over all batches that `inputs_dataset` represents. The pros of this
+        implementation are that the checkpoints do not need to be loaded too many
+        times.

         Args:
             batches (Tuple, or DataLoader): Either a single tuple of any, or a
@@ -516,13 +519,10 @@
                 displayed. In more detail, this computation will iterate over all
                 checkpoints (provided as the `checkpoints` initialization argument)
                 in an outer loop, and iterate over all batches that
-                `inputs_dataset` represents in an inner loop. Therefore, the
-                total number of (checkpoint, batch) combinations that need to be
-                iterated over is
-                (# of checkpoints x # of batches that `inputs_dataset` represents).
-                If `show_progress` is True, the total progress of both the outer
-                iteration over checkpoints and the inner iteration over batches is
-                displayed. It will try to use tqdm if available for advanced
+                `inputs_dataset` represents in an inner loop. Thus if
+                `show_progress` is True, the progress of both the outer
+                iteration and the inner iterations will be displayed. To show
+                progress, it will try to use tqdm if available for advanced
                 features (e.g. time estimation). Otherwise, it will fallback to a
                 simple output of progress.
                 Default: False
@@ -619,6 +619,75 @@ def get_checkpoint_contribution(checkpoint):

         return batches_self_tracin_scores

+    def self_influence(
+        self,
+        inputs_dataset: Union[Tuple[Any, ...], DataLoader],
+        show_progress: bool = False,
+        outer_loop_by_checkpoints: bool = False,
+    ) -> Tensor:
+        """
+        Computes self influence scores for the examples in `inputs_dataset`, which is
+        either a single batch or a Pytorch `DataLoader` that yields batches. Therefore,
+        the computed self influence scores are *not* for the examples in training
+        dataset `train_dataset` (unlike when computing self influence scores using the
+        `influence` method). Note that if `inputs_dataset` is a single batch, this
+        will call `model` on that single batch, and if `inputs_dataset` yields
+        batches, this will call `model` on each batch that is yielded. Therefore,
+        please ensure that for both cases, the batch(es) that `model` is called
+        with are not too large, so that there will not be an out-of-memory error.
+        Internally, this computation requires iterating both over the batches in
+        `inputs_dataset`, as well as different model checkpoints. There are two ways
+        this iteration can be done. If `outer_loop_by_checkpoints` is False, the outer
+        iteration will be over batches, and the inner iteration will be over
+        checkpoints. This has the pro that displaying the progress of the computation
+        is more intuitive, involving displaying the number of batches for which self
+        influence scores have been computed. If `outer_loop_by_checkpoints` is True,
+        the outer iteration will be over checkpoints, and the inner iteration will be
+        over batches. This has the pro that the checkpoints do not need to be loaded
+        for each batch. For large models, loading checkpoints can be time-intensive.
+
+        Args:
+            batches (Tuple, or DataLoader): Either a single tuple of any, or a
+                `DataLoader`, where each batch yielded is a tuple of any. In
+                either case, the tuple represents a single batch, where the last
+                element is assumed to be the labels for the batch. That is,
+                `model(*batch[0:-1])` produces the output for `model`,
+                and `batch[-1]` are the labels, if any. This is the same
+                assumption made for each batch yielded by training dataset
+                `train_dataset`. Please see documentation for the
+                `train_dataset` argument to `TracInCP.__init__` for
+                more details on the assumed structure of a batch.
+            show_progress (bool, optional): Computation of self influence scores can
+                take a long time if `inputs_dataset` represents many examples. If
+                `show_progress`is true, the progress of this computation will be
+                displayed. In more detail, if `outer_loop_by_checkpoints` is False,
+                this computation will iterate over all batches in an outer loop.
+                Thus if `show_progress` is True, the number of batches for which
+                self influence scores have been computed will be displayed. If
+                `outer_loop_by_checkpoints` is True, this computation will iterate
+                over all checkpoints (provided as the `checkpoints` initialization
+                argument) in an outer loop, and iterate over all batches that
+                `inputs_dataset` represents in an inner loop. Thus if
+                `show_progress` is True, the progress of both the outer
+                iteration and the inner iterations will be displayed. To show
+                progress, it will try to use tqdm if available for advanced
+                features (e.g. time estimation). Otherwise, it will fallback to a
+                simple output of progress.
+                Default: False
+            outer_loop_by_checkpoints (bool, optional): If performing an outer
+                iteration over checkpoints; see method description for more
+                details.
+                Default: False
+        """
+        if outer_loop_by_checkpoints:
+            return self._self_influence_by_checkpoints(inputs_dataset, show_progress)
+        return _self_influence_by_batches_helper(
+            self._self_influence_by_checkpoints,
+            self.get_name(),
+            inputs_dataset,
+            show_progress,
+        )
+

 def _basic_computation_tracincp_fast(
     influence_instance: TracInCPFast,
@@ -946,6 +1015,7 @@ def self_influence(
         self,
         inputs_dataset: Union[Tuple[Any, ...], DataLoader],
         show_progress: bool = False,
+        outer_loop_by_checkpoints: bool = False,
     ) -> Tensor:
         """
         NOT IMPLEMENTED - no need to implement `TracInCPFastRandProj.self_influence`,
@@ -985,6 +1055,10 @@ def self_influence(
                 if available for advanced features (e.g. time estimation).
                 Otherwise, it will fallback to a simple output of progress.
                 Default: False
+            outer_loop_by_checkpoints (bool, optional): If performing an outer
+                iteration over checkpoints; see method description for more
+                details.
+                Default: False

         Returns:
             self_influence_scores (Tensor): This is a 1D tensor containing the self
```

0 commit comments