Skip to content

Commit 8dadbaa

Browse files
authored
Only offload if activation is on CUDA (#2466)
1 parent f3587e5 commit 8dadbaa

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

torchtune/training/_activation_offloading.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,15 @@ def pack_tensor(activation: torch.Tensor) -> int:
145 145
num_bytes = get_num_bytes_tensor(activation)
146 146
tensor_id = get_tensor_id()
147 147

148-
# only offload hefty bois if they're activations (our heuristic for that is to
149-
# check if they're not params or buffers)!
150-
if num_bytes >= self.min_tensor_size_bytes and (
151-
not isinstance(activation, torch.nn.Parameter)
152-
and not isinstance(activation, torch.nn.Buffer)
148+
# only offload hefty bois if they're activations on CUDA (our heuristic
149+
# for that is to check if they're not params or buffers)!
150+
if (
151+
activation.is_cuda
152+
and num_bytes >= self.min_tensor_size_bytes
153+
and (
154+
not isinstance(activation, torch.nn.Parameter)
155+
and not isinstance(activation, torch.nn.Buffer)
156+
)
153 157
):
154 158
if self.use_streams:
155 159
# First, sync back and dereference previously offloaded tensors

0 commit comments

Comments
 (0)