@@ -50,19 +50,65 @@ def __init__(
50
50
self .input_kwargs_to_split = set (input_kwargs_to_split )
51
51
52
52
def forward (self , * args , ** kwargs ) -> Union [torch .Tensor , Tuple [torch .Tensor ]]:
53
- r"""Forward method of `SplitInferenceModule`.
53
+ r"""Forward method for the `SplitInferenceModule`.
54
54
55
- All inputs that should be split should be passed as keyword arguments. Only those keywords arguments will be
56
- split that are specified in `inputs_to_split` when initializing the module.
55
+ This method processes the input by splitting specified keyword arguments along a given dimension, running the
56
+ underlying module on each split, and then concatenating the results. The splitting is controlled by the
57
+ `split_size` and `split_dim` parameters specified during initialization.
58
+
59
+ Args:
60
+ *args (`Any`):
61
+ Positional arguments that are passed directly to the `module` without modification.
62
+ **kwargs (`Dict[str, torch.Tensor]`):
63
+ Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
64
+ entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
65
+ arguments are passed unchanged.
66
+
67
+ Returns:
68
+ `Union[torch.Tensor, Tuple[torch.Tensor]]`:
69
+ The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred
70
+ without it.
71
+ - If the underlying module returns a single tensor, the result will be a single concatenated tensor
72
+ along the same `split_dim` after processing all splits.
73
+ - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
74
+ along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
75
+
76
+ Workflow:
77
+ 1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
78
+ `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
79
+ 2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
80
+ that were passed.
81
+ 3. The output tensors from each split are concatenated back together along `split_dim` before returning.
82
+
83
+ Example:
84
+ ```python
85
+ >>> import torch
86
+
87
+ >>> model = torch.nn.Linear(1000, 1000)
88
+ >>> split_module = SplitInferenceModule(
89
+ ... model, split_size=2, split_dim=0, input_kwargs_to_split=["input"]
90
+ ... )
91
+
92
+ >>> input_tensor = torch.randn(42, 1000)
93
+ >>> # Will split the tensor into 21 slices of shape [2, 1000].
94
+ >>> output = split_module(input=input_tensor)
95
+ ```
96
+
97
+ This method is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
98
+ them into smaller chunks, processing each chunk separately, and then reassembling the results.
99
+
100
+ It is also possible to nest `SplitInferenceModule` across different split dimensions.
57
101
"""
58
102
split_inputs = {}
59
103
104
+ # 1. Split inputs that were specified during initialization and also present in passed kwargs
60
105
for key in list (kwargs .keys ()):
61
106
if key not in self .input_kwargs_to_split or not torch .is_tensor (kwargs [key ]):
62
107
continue
63
108
split_inputs [key ] = torch .split (kwargs [key ], self .split_size , self .split_dim )
64
109
kwargs .pop (key )
65
110
111
+ # 2. Invoke forward pass across each split
66
112
results = []
67
113
for split_input in zip (* split_inputs .values ()):
68
114
inputs = dict (zip (split_inputs .keys (), split_input ))
@@ -71,6 +117,7 @@ def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
71
117
intermediate_tensor_or_tensor_tuple = self .module (* args , ** inputs )
72
118
results .append (intermediate_tensor_or_tensor_tuple )
73
119
120
+ # 3. Concatenate split results to obtain final outputs
74
121
if isinstance (results [0 ], torch .Tensor ):
75
122
return torch .cat (results , dim = self .split_dim )
76
123
elif isinstance (results [0 ], tuple ):
0 commit comments