File tree Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -145,11 +145,15 @@ def pack_tensor(activation: torch.Tensor) -> int:
145
145
num_bytes = get_num_bytes_tensor (activation )
146
146
tensor_id = get_tensor_id ()
147
147
148
- # only offload hefty bois if they're activations (our heuristic for that is to
149
- # check if they're not params or buffers)!
150
- if num_bytes >= self .min_tensor_size_bytes and (
151
- not isinstance (activation , torch .nn .Parameter )
152
- and not isinstance (activation , torch .nn .Buffer )
148
+ # only offload hefty bois if they're activations on CUDA (our heuristic
149
+ # for that is to check if they're not params or buffers)!
150
+ if (
151
+ activation .is_cuda
152
+ and num_bytes >= self .min_tensor_size_bytes
153
+ and (
154
+ not isinstance (activation , torch .nn .Parameter )
155
+ and not isinstance (activation , torch .nn .Buffer )
156
+ )
153
157
):
154
158
if self .use_streams :
155
159
# First, sync back and dereference previously offloaded tensors
You can’t perform that action at this time.
0 commit comments