Add enable_cross_tensor_attribution flag to attribute_future #1546

Open · wants to merge 1 commit into master

captum/attr/_core/feature_ablation.py: 115 additions & 78 deletions (193 changed lines)
@@ -729,6 +729,7 @@ def attribute_future(
         feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         perturbations_per_eval: int = 1,
         show_progress: bool = False,
+        enable_cross_tensor_attribution: bool = False,
         **kwargs: Any,
     ) -> Future[TensorOrTupleOfTensorsGeneric]:
         r"""
@@ -743,17 +744,18 @@ def attribute_future(
         formatted_additional_forward_args = _format_additional_forward_args(
             additional_forward_args
         )
-        num_examples = formatted_inputs[0].shape[0]
         formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs)
 
         assert (
             isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
         ), "Perturbations per evaluation must be an integer and at least 1."
         with torch.no_grad():
+            attr_progress = None
             if show_progress:
                 attr_progress = self._attribute_progress_setup(
                     formatted_inputs,
                     formatted_feature_mask,
+                    enable_cross_tensor_attribution,
                     **kwargs,
                     perturbations_per_eval=perturbations_per_eval,
                 )
@@ -768,7 +770,7 @@
                 formatted_additional_forward_args,
             )
 
-            if show_progress:
+            if attr_progress is not None:
                 attr_progress.update()
 
             processed_initial_eval_fut: Optional[
@@ -788,101 +790,136 @@
                 )
             )
 
-            # The will be the same amount futures as modified_eval down there,
-            # since we cannot add up the evaluation result adhoc under async mode.
-            all_modified_eval_futures: List[
-                List[Future[Tuple[List[Tensor], List[Tensor]]]]
-            ] = [[] for _ in range(len(inputs))]
-            # Iterate through each feature tensor for ablation
-            for i in range(len(formatted_inputs)):
-                # Skip any empty input tensors
-                if torch.numel(formatted_inputs[i]) == 0:
-                    continue
-
-                for (
-                    current_inputs,
-                    current_add_args,
-                    current_target,
-                    current_mask,
-                ) in self._ith_input_ablation_generator(
-                    i,
-                    formatted_inputs,
-                    formatted_additional_forward_args,
-                    target,
-                    baselines,
-                    formatted_feature_mask,
-                    perturbations_per_eval,
-                    **kwargs,
-                ):
-                    # modified_eval has (n_feature_perturbed * n_outputs) elements
-                    # shape:
-                    # agg mode: (*initial_eval.shape)
-                    # non-agg mode:
-                    # (feature_perturbed * batch_size, *initial_eval.shape[1:])
-                    modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
-                        self.forward_func,
-                        current_inputs,
-                        current_target,
-                        current_add_args,
-                    )
-
-                    if show_progress:
-                        attr_progress.update()
-
-                    if not isinstance(modified_eval, torch.Future):
-                        raise AssertionError(
-                            "when using attribute_future, modified_eval should have "
-                            f"Future type rather than {type(modified_eval)}"
-                        )
-                    if processed_initial_eval_fut is None:
-                        raise AssertionError(
-                            "processed_initial_eval_fut should not be None"
-                        )
-
-                    # Need to collect both initial eval and modified_eval
-                    eval_futs: Future[
-                        List[
-                            Future[
-                                Union[
-                                    Tuple[
-                                        List[Tensor],
-                                        List[Tensor],
-                                        Tensor,
-                                        Tensor,
-                                        int,
-                                        dtype,
-                                    ],
-                                    Tensor,
-                                ]
-                            ]
-                        ]
-                    ] = collect_all(
-                        [
-                            processed_initial_eval_fut,
-                            modified_eval,
-                        ]
-                    )
-
-                    ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
-                        eval_futs.then(
-                            lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
-                                eval_futs=eval_futs,
-                                current_inputs=current_inputs,
-                                current_mask=current_mask,
-                                i=i,
-                                perturbations_per_eval=perturbations_per_eval,
-                                num_examples=num_examples,
-                                formatted_inputs=formatted_inputs,
-                            )
-                        )
-                    )
-
-                    all_modified_eval_futures[i].append(ablated_out_fut)
-
-            if show_progress:
-                attr_progress.close()
-
-        return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
+            if enable_cross_tensor_attribution:
+                raise NotImplementedError("Not supported yet")
+            else:
+                # pyre-fixme[7]: Expected `Future[Variable[TensorOrTupleOfTensorsGeneric
+                #  <:[Tensor, typing.Tuple[Tensor, ...]]]]` but got
+                #  `Future[Union[Tensor, typing.Tuple[Tensor, ...]]]`
+                return self._attribute_with_independent_feature_masks_future(  # type: ignore # noqa: E501 line too long
+                    formatted_inputs,
+                    formatted_additional_forward_args,
+                    target,
+                    baselines,
+                    formatted_feature_mask,
+                    perturbations_per_eval,
+                    attr_progress,
+                    processed_initial_eval_fut,
+                    is_inputs_tuple,
+                    **kwargs,
+                )
+
+    def _attribute_with_independent_feature_masks_future(
+        self,
+        formatted_inputs: Tuple[Tensor, ...],
+        formatted_additional_forward_args: Optional[Tuple[object, ...]],
+        target: TargetType,
+        baselines: BaselineType,
+        formatted_feature_mask: Tuple[Tensor, ...],
+        perturbations_per_eval: int,
+        attr_progress: Optional[Union[SimpleProgress[IterableType], tqdm]],
+        processed_initial_eval_fut: Future[
+            Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype]
+        ],
+        is_inputs_tuple: bool,
+        **kwargs: Any,
+    ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]:
+        num_examples = formatted_inputs[0].shape[0]
+        # There will be as many futures as modified_eval results below,
+        # since we cannot sum up the evaluation results ad hoc in async mode.
+        all_modified_eval_futures: List[
+            List[Future[Tuple[List[Tensor], List[Tensor]]]]
+        ] = [[] for _ in range(len(formatted_inputs))]
+        # Iterate through each feature tensor for ablation
+        for i in range(len(formatted_inputs)):
+            # Skip any empty input tensors
+            if torch.numel(formatted_inputs[i]) == 0:
+                continue
+
+            for (
+                current_inputs,
+                current_add_args,
+                current_target,
+                current_mask,
+            ) in self._ith_input_ablation_generator(
+                i,
+                formatted_inputs,
+                formatted_additional_forward_args,
+                target,
+                baselines,
+                formatted_feature_mask,
+                perturbations_per_eval,
+                **kwargs,
+            ):
+                # modified_eval has (n_feature_perturbed * n_outputs) elements
+                # shape:
+                # agg mode: (*initial_eval.shape)
+                # non-agg mode:
+                # (feature_perturbed * batch_size, *initial_eval.shape[1:])
+                modified_eval: Union[Tensor, Future[Tensor]] = _run_forward(
+                    self.forward_func,
+                    current_inputs,
+                    current_target,
+                    current_add_args,
+                )
+
+                if attr_progress is not None:
+                    attr_progress.update()
+
+                if not isinstance(modified_eval, torch.Future):
+                    raise AssertionError(
+                        "when using attribute_future, modified_eval should have "
+                        f"Future type rather than {type(modified_eval)}"
+                    )
+                if processed_initial_eval_fut is None:
+                    raise AssertionError(
+                        "processed_initial_eval_fut should not be None"
+                    )
+
+                # Need to collect both initial eval and modified_eval
+                eval_futs: Future[
+                    List[
+                        Future[
+                            Union[
+                                Tuple[
+                                    List[Tensor],
+                                    List[Tensor],
+                                    Tensor,
+                                    Tensor,
+                                    int,
+                                    dtype,
+                                ],
+                                Tensor,
+                            ]
+                        ]
+                    ]
+                ] = collect_all(
+                    [
+                        processed_initial_eval_fut,
+                        modified_eval,
+                    ]
+                )
+
+                ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = (
+                    eval_futs.then(
+                        lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut(  # type: ignore # noqa: E501 line too long
+                            eval_futs=eval_futs,
+                            current_inputs=current_inputs,
+                            current_mask=current_mask,
+                            i=i,
+                            perturbations_per_eval=perturbations_per_eval,
+                            num_examples=num_examples,
+                            formatted_inputs=formatted_inputs,
+                        )
+                    )
+                )
+
+                all_modified_eval_futures[i].append(ablated_out_fut)
+
+        if attr_progress is not None:
+            attr_progress.close()
+
+        return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple)  # type: ignore # noqa: E501 line too long
 
     # pyre-fixme[3] return type must be annotated
     def _attribute_progress_setup(
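For context, here is a minimal sketch (not part of this PR) of how the new flag might be exercised once the patch is applied. `FeatureAblation`, `attribute_future`, and the Future-returning forward requirement come from the diff above; `weights` and `forward_fn` are illustrative stand-ins.

```python
# Illustrative usage sketch; assumes this patch is applied.
import torch
from captum.attr import FeatureAblation

weights = torch.tensor([1.0, 2.0, 3.0])


def forward_fn(x: torch.Tensor) -> torch.futures.Future:
    # attribute_future requires the forward function to return a Future:
    # the diff above asserts that modified_eval is a torch.Future.
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result((x * weights).sum(dim=1))
    return fut


ablator = FeatureAblation(forward_fn)
inputs = torch.randn(4, 3)

# Default path: the flag is False, so attribution falls through to the
# unchanged independent-feature-mask logic, now housed in the new
# _attribute_with_independent_feature_masks_future helper.
attr = ablator.attribute_future(inputs, perturbations_per_eval=2).wait()
print(attr.shape)  # torch.Size([4, 3]): one attribution per input feature

# Cross-tensor path: currently a stub that raises NotImplementedError.
try:
    ablator.attribute_future(inputs, enable_cross_tensor_attribution=True)
except NotImplementedError as err:
    print(err)  # "Not supported yet"
```

Because the flag defaults to False, existing callers keep the current behavior unchanged; only an explicit opt-in reaches the NotImplementedError stub.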