Avoid unnecessary tensor construction when creating input masks for permutation/ablation #1527
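Summary: the cross-tensor ablation path used to build a `mask == feature_idx` comparison tensor for every input tensor on every feature iteration, plus an all-zeros placeholder mask for tensors the feature never touches. This change precomputes, once per attribution call, a `feature_idx_to_tensor_idx` mapping from each feature id to the input tensors that contain it; the per-feature loop then only constructs masks where they are needed and records `None` for untouched tensors, and the downstream consumers are updated to skip `None`. Short illustrative sketches follow the relevant hunks below.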

56 changes: 38 additions & 18 deletions captum/attr/_core/feature_ablation.py
@@ -3,7 +3,18 @@
 # pyre-strict
 
 import math
-from typing import Any, Callable, cast, Generator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
 from captum._utils.common import (
@@ -465,13 +476,21 @@ def _attribute_with_cross_tensor_feature_masks(
         attrib_type: dtype,
         **kwargs: Any,
     ) -> Tuple[List[Tensor], List[Tensor]]:
+        feature_idx_to_tensor_idx: Dict[int, List[int]] = {}
+        for i, mask in enumerate(formatted_feature_mask):
+            for feature_idx in torch.unique(mask):
+                if feature_idx.item() not in feature_idx_to_tensor_idx:
+                    feature_idx_to_tensor_idx[feature_idx.item()] = []
+                feature_idx_to_tensor_idx[feature_idx.item()].append(i)
+
         for (
             current_inputs,
             current_mask,
         ) in self._ablation_generator(
             formatted_inputs,
             baselines,
             formatted_feature_mask,
+            feature_idx_to_tensor_idx,
             **kwargs,
         ):
             # modified_eval has (n_feature_perturbed * n_outputs) elements
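The new `feature_idx_to_tensor_idx` dict records, for every feature id appearing in the masks, which input tensors contain it, so the per-feature loop can skip unrelated tensors entirely. A minimal standalone sketch of how it is populated (the toy masks are invented for illustration; `setdefault` replaces the explicit "not in dict" check but is behaviorally identical):

```python
import torch

# Toy feature masks for two input tensors (values are feature ids).
masks = (
    torch.tensor([[0, 0, 1]]),
    torch.tensor([[1, 2]]),
)

feature_idx_to_tensor_idx = {}
for i, mask in enumerate(masks):
    for feature_idx in torch.unique(mask):
        feature_idx_to_tensor_idx.setdefault(feature_idx.item(), []).append(i)

print(feature_idx_to_tensor_idx)  # {0: [0], 1: [0, 1], 2: [1]}
```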
@@ -511,27 +530,28 @@ def _ablation_generator(
         inputs: Tuple[Tensor, ...],
         baselines: BaselineType,
         input_mask: Tuple[Tensor, ...],
+        feature_idx_to_tensor_idx: Dict[int, List[int]],
         **kwargs: Any,
     ) -> Generator[
         Tuple[
             Tuple[Tensor, ...],
-            Tuple[Tensor, ...],
+            Tuple[Optional[Tensor], ...],
         ],
         None,
         None,
     ]:
-        unique_feature_ids = torch.unique(
-            torch.cat([mask.flatten() for mask in input_mask])
-        ).tolist()
-
         if isinstance(baselines, torch.Tensor):
             baselines = baselines.reshape((1,) + tuple(baselines.shape))
 
         # Process one feature per time, rather than processing every input tensor
-        for feature_idx in unique_feature_ids:
+        for feature_idx in feature_idx_to_tensor_idx.keys():
             ablated_inputs, current_masks = (
                 self._construct_ablated_input_across_tensors(
-                    inputs, input_mask, baselines, feature_idx
+                    inputs,
+                    input_mask,
+                    baselines,
+                    feature_idx,
+                    feature_idx_to_tensor_idx[feature_idx],
                 )
             )
             yield ablated_inputs, current_masks
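Iterating `feature_idx_to_tensor_idx.keys()` replaces the removed `torch.unique(torch.cat(...))` call, which flattened and concatenated every mask into a fresh tensor on each attribution call just to enumerate feature ids. A sketch of the equivalence, reusing the toy masks from the previous example:

```python
import torch

masks = (torch.tensor([[0, 0, 1]]), torch.tensor([[1, 2]]))

# Old approach: flatten, concatenate, and deduplicate every mask.
unique_feature_ids = torch.unique(
    torch.cat([mask.flatten() for mask in masks])
).tolist()

# New approach: the ids are already the keys of the precomputed mapping.
feature_idx_to_tensor_idx = {0: [0], 1: [0, 1], 2: [1]}

assert sorted(unique_feature_ids) == sorted(feature_idx_to_tensor_idx.keys())
```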
@@ -542,18 +562,17 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
 
         ablated_inputs = []
-        current_masks = []
+        current_masks: List[Optional[Tensor]] = []
         for i, input_tensor in enumerate(inputs):
-            mask = input_mask[i]
-            tensor_mask = mask == feature_idx
-            if not tensor_mask.any():
+            if i not in tensor_idxs:
                 ablated_inputs.append(input_tensor)
-                current_masks.append(torch.zeros_like(tensor_mask))
+                current_masks.append(None)
                 continue
-            tensor_mask = tensor_mask.to(input_tensor.device).long()
+            tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
             baseline = baselines[i] if isinstance(baselines, tuple) else baselines
             if isinstance(baseline, torch.Tensor):
                 baseline = baseline.reshape(
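With the membership list available, the per-tensor loop tests `i not in tensor_idxs` instead of materializing `mask == feature_idx` just to discover it is all-False, and records `None` rather than a `torch.zeros_like` placeholder. A simplified sketch of the new control flow, assuming a scalar baseline and ignoring the baseline broadcasting handled in the real method:

```python
import torch
from typing import List, Optional, Tuple

def ablate_one_feature(
    inputs: Tuple[torch.Tensor, ...],
    input_mask: Tuple[torch.Tensor, ...],
    baseline: float,
    feature_idx: int,
    tensor_idxs: List[int],
) -> Tuple[Tuple[torch.Tensor, ...], Tuple[Optional[torch.Tensor], ...]]:
    ablated_inputs = []
    current_masks: List[Optional[torch.Tensor]] = []
    for i, input_tensor in enumerate(inputs):
        if i not in tensor_idxs:
            # Feature absent from this tensor: pass it through untouched.
            ablated_inputs.append(input_tensor)
            current_masks.append(None)
            continue
        tensor_mask = (input_mask[i] == feature_idx).to(input_tensor.device).long()
        ablated_inputs.append(input_tensor * (1 - tensor_mask) + baseline * tensor_mask)
        current_masks.append(tensor_mask)
    return tuple(ablated_inputs), tuple(current_masks)
```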
@@ -1173,7 +1192,7 @@ def _process_ablated_out(
     def _process_ablated_out_full(
         self,
         modified_eval: Tensor,
-        current_mask: Tuple[Tensor, ...],
+        current_mask: Tuple[Optional[Tensor], ...],
         flattened_initial_eval: Tensor,
         inputs: TensorOrTupleOfTensorsGeneric,
         n_outputs: int,
@@ -1195,9 +1214,10 @@ def _process_ablated_out_full(
 
         if self.use_weights:
             for weight, mask in zip(weights, current_mask):
-                weight += mask.float().sum(dim=0)
+                if mask is not None:
+                    weight += mask.float().sum(dim=0)
         for i, mask in enumerate(current_mask):
-            if inputs[i].numel() == 0:
+            if mask is None or inputs[i].numel() == 0:
                 continue
             eval_diff = eval_diff.reshape(
                 eval_diff_shape + (inputs[i].dim() - 1) * (1,)
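Since untouched tensors now carry a `None` mask, both accumulation loops guard before dereferencing; behavior for real masks is unchanged. A toy check of the guarded weight update (the shapes are invented):

```python
import torch

current_mask = (None, torch.ones((2, 4), dtype=torch.long))
weights = [torch.zeros(3), torch.zeros(4)]

for weight, mask in zip(weights, current_mask):
    if mask is not None:
        weight += mask.float().sum(dim=0)  # None entries are simply skipped

print(weights[0])  # unchanged: tensor([0., 0., 0.])
print(weights[1])  # tensor([2., 2., 2., 2.])
```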
21 changes: 13 additions & 8 deletions captum/attr/_core/feature_permutation.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 
 # pyre-strict
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
@@ -26,15 +26,15 @@ def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
 
 
 def _permute_features_across_tensors(
-    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Tensor, ...]
+    inputs: Tuple[Tensor, ...], feature_masks: Tuple[Optional[Tensor], ...]
 ) -> Tuple[Tensor, ...]:
     """
     Permutes features across multiple input tensors using the corresponding
     feature masks.
     """
     permuted_outputs = []
     for input_tensor, feature_mask in zip(inputs, feature_masks):
-        if not feature_mask.any():
+        if feature_mask is None or not feature_mask.any():
             permuted_outputs.append(input_tensor)
             continue
         n = input_tensor.size(0)
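`_permute_features_across_tensors` now accepts `Optional` masks and passes a tensor through when its mask is `None`, exactly as it already did for all-False masks. A self-contained sketch of the permutation pattern, where masked positions are shuffled along the batch dimension (the real function delegates to Captum's permutation helper, whose details may differ):

```python
import torch
from typing import Optional, Tuple

def permute_across_tensors(
    inputs: Tuple[torch.Tensor, ...],
    feature_masks: Tuple[Optional[torch.Tensor], ...],
) -> Tuple[torch.Tensor, ...]:
    permuted_outputs = []
    for input_tensor, feature_mask in zip(inputs, feature_masks):
        if feature_mask is None or not feature_mask.any():
            # Nothing to permute in this tensor.
            permuted_outputs.append(input_tensor)
            continue
        perm = torch.randperm(input_tensor.size(0))
        mask = feature_mask.to(dtype=input_tensor.dtype)
        permuted_outputs.append(input_tensor[perm] * mask + input_tensor * (1 - mask))
    return tuple(permuted_outputs)
```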
@@ -103,7 +103,7 @@ def __init__(
         forward_func: Callable[..., Union[int, float, Tensor, Future[Tensor]]],
         perm_func: Callable[[Tensor, Tensor], Tensor] = _permute_feature,
         perm_func_cross_tensor: Callable[
-            [Tuple[Tensor, ...], Tuple[Tensor, ...]], Tuple[Tensor, ...]
+            [Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]], Tuple[Tensor, ...]
         ] = _permute_features_across_tensors,
     ) -> None:
         r"""
@@ -392,9 +392,14 @@ def _construct_ablated_input_across_tensors(
         input_mask: Tuple[Tensor, ...],
         baselines: BaselineType,
         feature_idx: int,
-    ) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
-        feature_masks = tuple(
-            (mask == feature_idx).to(inputs[0].device) for mask in input_mask
-        )
+        tensor_idxs: List[int],
+    ) -> Tuple[Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
+        current_masks: List[Optional[Tensor]] = []
+        for i, mask in enumerate(input_mask):
+            if i in tensor_idxs:
+                current_masks.append((mask == feature_idx).to(inputs[0].device))
+            else:
+                current_masks.append(None)
+        feature_masks = tuple(current_masks)
         permuted_outputs = self.perm_func_cross_tensor(inputs, feature_masks)
         return permuted_outputs, feature_masks
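A toy run of this new mask construction, showing that only tensors listed in `tensor_idxs` get a boolean mask (masks and indices invented for illustration):

```python
import torch

input_mask = (torch.tensor([[0, 0, 1]]), torch.tensor([[1, 2]]))
feature_idx, tensor_idxs = 2, [1]  # feature 2 appears only in the second tensor

feature_masks = tuple(
    (mask == feature_idx) if i in tensor_idxs else None
    for i, mask in enumerate(input_mask)
)
print(feature_masks)  # (None, tensor([[False,  True]]))
```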