|
4 | 4 | import warnings
|
5 | 5 | from abc import abstractmethod
|
6 | 6 | from os.path import join
|
7 |
| -from typing import ( |
8 |
| - Any, |
9 |
| - Callable, |
10 |
| - Iterator, |
11 |
| - List, |
12 |
| - NamedTuple, |
13 |
| - Optional, |
14 |
| - Tuple, |
15 |
| - Type, |
16 |
| - Union, |
17 |
| -) |
| 7 | +from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union |
18 | 8 |
|
19 | 9 | import torch
|
20 | 10 | from captum._utils.av import AV
|
21 |
| -from captum._utils.common import _get_module_from_name, _parse_version |
22 |
| -from captum._utils.gradient import ( |
23 |
| - _compute_jacobian_wrt_params, |
24 |
| - _compute_jacobian_wrt_params_with_sample_wise_trick, |
25 |
| -) |
| 11 | +from captum._utils.common import _parse_version |
26 | 12 | from captum._utils.progress import NullProgress, progress
|
27 | 13 | from captum.influence._core.influence import DataInfluence
|
28 | 14 | from captum.influence._utils.common import (
|
29 | 15 | _check_loss_fn,
|
| 16 | + _compute_jacobian_sample_wise_grads_per_batch, |
30 | 17 | _format_inputs_dataset,
|
31 | 18 | _get_k_most_influential_helper,
|
32 | 19 | _gradient_dot_product,
|
| 20 | + _influence_route_to_helpers, |
33 | 21 | _load_flexible_state_dict,
|
34 | 22 | _self_influence_by_batches_helper,
|
| 23 | + _set_active_parameters, |
| 24 | + KMostInfluentialResults, |
35 | 25 | )
|
36 | 26 | from captum.log import log_usage
|
37 | 27 | from torch import Tensor
|
|
69 | 59 | """
|
70 | 60 |
|
71 | 61 |
|
72 |
| -class KMostInfluentialResults(NamedTuple): |
73 |
| - """ |
74 |
| - This namedtuple stores the results of using the `influence` method. This method |
75 |
| - is implemented by all subclasses of `TracInCPBase` to calculate |
76 |
| - proponents / opponents. The `indices` field stores the indices of the |
77 |
| - proponents / opponents for each example in the test dataset. For example, if |
78 |
| - finding opponents, `indices[i][j]` stores the index in the training data of the |
79 |
| - example with the `j`-th highest influence score on the `i`-th example in the test |
80 |
| - dataset. Similarly, the `influence_scores` field stores the actual influence |
81 |
| - scores, so that `influence_scores[i][j]` is the influence score of example |
82 |
| - `indices[i][j]` in the training data on example `i` of the test dataset. |
83 |
| - Please see `TracInCPBase.influence` for more details. |
84 |
| - """ |
85 |
| - |
86 |
| - indices: Tensor |
87 |
| - influence_scores: Tensor |
88 |
| - |
89 |
| - |
90 | 62 | class TracInCPBase(DataInfluence):
|
91 | 63 | """
|
92 | 64 | To implement the `influence` method, classes inheriting from `TracInCPBase` will
|
@@ -448,34 +420,6 @@ def get_name(cls: Type["TracInCPBase"]) -> str:
|
448 | 420 | return cls.__name__
|
449 | 421 |
|
450 | 422 |
|
451 |
| -def _influence_route_to_helpers( |
452 |
| - influence_instance: TracInCPBase, |
453 |
| - inputs: Union[Tuple[Any, ...], DataLoader], |
454 |
| - k: Optional[int] = None, |
455 |
| - proponents: bool = True, |
456 |
| - **kwargs, |
457 |
| -) -> Union[Tensor, KMostInfluentialResults]: |
458 |
| - """ |
459 |
| - This is a helper function called by `TracInCP.influence` and |
460 |
| - `TracInCPFast.influence`. Those methods share a common logic in that they assume |
461 |
| - an instance of their respective classes implement 2 private methods |
462 |
| - (``_influence`, `_get_k_most_influential`), and the logic of |
463 |
| - which private method to call is common, as described in the documentation of the |
464 |
| - `influence` method. The arguments and return values of this function are the exact |
465 |
| - same as the `influence` method. Note that `influence_instance` refers to the |
466 |
| - instance for which the `influence` method was called. |
467 |
| - """ |
468 |
| - if k is None: |
469 |
| - return influence_instance._influence(inputs, **kwargs) |
470 |
| - else: |
471 |
| - return influence_instance._get_k_most_influential( |
472 |
| - inputs, |
473 |
| - k, |
474 |
| - proponents, |
475 |
| - **kwargs, |
476 |
| - ) |
477 |
| - |
478 |
| - |
479 | 423 | class TracInCP(TracInCPBase):
|
480 | 424 | def __init__(
|
481 | 425 | self,
|
@@ -630,23 +574,7 @@ def __init__(
|
630 | 574 | """
|
631 | 575 | self.layer_modules = None
|
632 | 576 | if layers is not None:
|
633 |
| - assert isinstance(layers, List), "`layers` should be a list!" |
634 |
| - assert len(layers) > 0, "`layers` cannot be empty!" |
635 |
| - assert isinstance( |
636 |
| - layers[0], str |
637 |
| - ), "`layers` should contain str layer names." |
638 |
| - self.layer_modules = [ |
639 |
| - _get_module_from_name(self.model, layer) for layer in layers |
640 |
| - ] |
641 |
| - for layer, layer_module in zip(layers, self.layer_modules): |
642 |
| - for name, param in layer_module.named_parameters(): |
643 |
| - if not param.requires_grad: |
644 |
| - warnings.warn( |
645 |
| - "Setting required grads for layer: {}, name: {}".format( |
646 |
| - ".".join(layer), name |
647 |
| - ) |
648 |
| - ) |
649 |
| - param.requires_grad = True |
| 577 | + self.layer_modules = _set_active_parameters(model, layers) |
650 | 578 |
|
651 | 579 | @log_usage()
|
652 | 580 | def influence( # type: ignore[override]
|
@@ -1463,19 +1391,6 @@ def _basic_computation_tracincp(
|
1463 | 1391 | argument is only used if `sample_wise_grads_per_batch` was true in
|
1464 | 1392 | initialization.
|
1465 | 1393 | """
|
1466 |
| - if self.sample_wise_grads_per_batch: |
1467 |
| - return _compute_jacobian_wrt_params_with_sample_wise_trick( |
1468 |
| - self.model, |
1469 |
| - inputs, |
1470 |
| - targets, |
1471 |
| - loss_fn, |
1472 |
| - reduction_type, |
1473 |
| - self.layer_modules, |
1474 |
| - ) |
1475 |
| - return _compute_jacobian_wrt_params( |
1476 |
| - self.model, |
1477 |
| - inputs, |
1478 |
| - targets, |
1479 |
| - loss_fn, |
1480 |
| - self.layer_modules, |
| 1394 | + return _compute_jacobian_sample_wise_grads_per_batch( |
| 1395 | + self, inputs, targets, loss_fn, reduction_type |
1481 | 1396 | )
|
0 commit comments