Skip to content

Commit 661a0b3

Browse files
committed
add explanation of SplitInferenceModule
1 parent 12f0ae1 commit 661a0b3

File tree

1 file changed

+50
-3
lines changed

1 file changed

+50
-3
lines changed

src/diffusers/pipelines/free_noise_utils.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,65 @@ def __init__(
5050
self.input_kwargs_to_split = set(input_kwargs_to_split)
5151

5252
def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
53-
r"""Forward method of `SplitInferenceModule`.
53+
r"""Forward method for the `SplitInferenceModule`.
5454
55-
All inputs that should be split should be passed as keyword arguments. Only those keywords arguments will be
56-
split that are specified in `inputs_to_split` when initializing the module.
55+
This method processes the input by splitting specified keyword arguments along a given dimension, running the
56+
underlying module on each split, and then concatenating the results. The splitting is controlled by the
57+
`split_size` and `split_dim` parameters specified during initialization.
58+
59+
Args:
60+
*args (`Any`):
61+
Positional arguments that are passed directly to the `module` without modification.
62+
**kwargs (`Dict[str, torch.Tensor]`):
63+
Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
64+
entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
65+
arguments are passed unchanged.
66+
67+
Returns:
68+
`Union[torch.Tensor, Tuple[torch.Tensor]]`:
69+
The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred
70+
without it.
71+
- If the underlying module returns a single tensor, the result will be a single concatenated tensor
72+
along the same `split_dim` after processing all splits.
73+
- If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
74+
along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
75+
76+
Workflow:
77+
1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
78+
`torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
79+
2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
80+
that were passed.
81+
3. The output tensors from each split are concatenated back together along `split_dim` before returning.
82+
83+
Example:
84+
```python
85+
>>> import torch
86+
87+
>>> model = nn.Linear(1000, 1000)
88+
>>> split_module = SplitInferenceModule(
89+
... model, split_size=2, split_dim=0, input_kwargs_to_split=["input_data"]
90+
... )
91+
92+
>>> input_tensor = torch.randn(42, 1000)
93+
>>> # Will split the tensor into 21 slices of shape [2, 1000].
94+
>>> output = split_module(input_data=input_tensor)
95+
```
96+
97+
This method is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
98+
them into smaller chunks, processing each chunk separately, and then reassembling the results.
99+
100+
It is also possible to nest `SplitInferenceModule` across different split dimensions.
57101
"""
58102
split_inputs = {}
59103

104+
# 1. Split inputs that were specified during initialization and also present in passed kwargs
60105
for key in list(kwargs.keys()):
61106
if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
62107
continue
63108
split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
64109
kwargs.pop(key)
65110

111+
# 2. Invoke forward pass across each split
66112
results = []
67113
for split_input in zip(*split_inputs.values()):
68114
inputs = dict(zip(split_inputs.keys(), split_input))
@@ -71,6 +117,7 @@ def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
71117
intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
72118
results.append(intermediate_tensor_or_tensor_tuple)
73119

120+
# 3. Concatenate split restuls to obtain final outputs
74121
if isinstance(results[0], torch.Tensor):
75122
return torch.cat(results, dim=self.split_dim)
76123
elif isinstance(results[0], tuple):

0 commit comments

Comments
 (0)