
Commit afa6c7c

sayakpaul authored and a-r-r-o-w committed
[Tests] clean up and refactor gradient checkpointing tests (#9494)
* check.
* fixes
* fixes
* updates
* fixes
* fixes
1 parent 5d5af74 commit afa6c7c

15 files changed: +180 -273 lines changed
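
In short, the duplicated per-model `test_gradient_checkpointing` bodies are replaced by two shared tests on `ModelTesterMixin` in `tests/models/test_modeling_common.py`: `test_effective_gradient_checkpointing` (losses and parameter gradients must match with and without checkpointing, within tolerances) and `test_gradient_checkpointing_is_applied` (the expected submodules must get their `gradient_checkpointing` flag enabled). A minimal sketch of the resulting per-model pattern; `SomeModelTests` and `SomeDiffusersModel` are illustrative stand-ins, not names from this diff:

import unittest

from .test_modeling_common import ModelTesterMixin  # relative import path varies per test directory


class SomeModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = SomeDiffusersModel  # hypothetical model; must have _supports_gradient_checkpointing = True

    def test_gradient_checkpointing_is_applied(self):
        # class names of the submodules expected to have gradient checkpointing enabled
        super().test_gradient_checkpointing_is_applied(expected_set={"SomeDiffusersModel"})

    def test_effective_gradient_checkpointing(self):
        # tolerances can be loosened per model, as done for DiTTransformer2DModel below
        super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)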

tests/models/autoencoders/test_models_vae.py

Lines changed: 25 additions & 84 deletions
@@ -39,7 +39,6 @@
     load_hf_numpy,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_accelerator_with_training,
     require_torch_gpu,
     skip_mps,
     slow,
@@ -170,52 +169,17 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass

+    @unittest.skip("Not tested.")
     def test_training(self):
         pass

-    @require_torch_accelerator_with_training
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"Decoder", "Encoder"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

     def test_from_pretrained_hub(self):
         model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
@@ -329,9 +293,11 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass

+    @unittest.skip("Not tested.")
     def test_forward_with_norm_groups(self):
         pass

@@ -364,9 +330,20 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_outputs_equivalence(self):
         pass

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DecoderTiny", "EncoderTiny"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @unittest.skip(
+        "Gradient checkpointing is supported but this test doesn't apply to this class because it's forward is a bit different from the rest."
+    )
+    def test_effective_gradient_checkpointing(self):
+        pass
+

 class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
     model_class = ConsistencyDecoderVAE
@@ -443,55 +420,17 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass

+    @unittest.skip("Not tested.")
     def test_training(self):
         pass

-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            if "post_quant_conv" in name:
-                continue
-
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"Encoder", "TemporalDecoder"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


 class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
@@ -522,9 +461,11 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass

+    @unittest.skip("Not tested.")
     def test_forward_with_norm_groups(self):
         pass

tests/models/test_modeling_common.py

Lines changed: 97 additions & 0 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import inspect
 import json
 import os
@@ -57,6 +58,7 @@
     require_torch_gpu,
     require_torch_multi_gpu,
     run_test_in_subprocess,
+    torch_all_close,
     torch_device,
 )

@@ -785,6 +787,101 @@ def test_enable_disable_gradient_checkpointing(self):
         model.disable_gradient_checkpointing()
         self.assertFalse(model.is_gradient_checkpointing)

+    @require_torch_accelerator_with_training
+    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        inputs_dict_copy = copy.deepcopy(inputs_dict)
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        assert not model.is_gradient_checkpointing and model.training
+
+        out = model(**inputs_dict).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model.zero_grad()
+
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
+
+        # re-instantiate the model now enabling gradient checkpointing
+        torch.manual_seed(0)
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict_copy).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
+
+        # compare the output and parameters gradients
+        self.assertTrue((loss - loss_2).abs() < loss_tolerance)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+
+        for name, param in named_params.items():
+            if "post_quant_conv" in name:
+                continue
+            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
+
+    @unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
+    def test_gradient_checkpointing_is_applied(
+        self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
+    ):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+        if self.model_class.__name__ in [
+            "UNetSpatioTemporalConditionModel",
+            "AutoencoderKLTemporalDecoder",
+        ]:
+            return
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        if attention_head_dim is not None:
+            init_dict["attention_head_dim"] = attention_head_dim
+        if num_attention_heads is not None:
+            init_dict["num_attention_heads"] = num_attention_heads
+        if block_out_channels is not None:
+            init_dict["block_out_channels"] = block_out_channels
+
+        model_class_copy = copy.copy(self.model_class)
+
+        modules_with_gc_enabled = {}
+
+        # now monkey patch the following function:
+        #     def _set_gradient_checkpointing(self, module, value=False):
+        #         if hasattr(module, "gradient_checkpointing"):
+        #             module.gradient_checkpointing = value
+
+        def _set_gradient_checkpointing_new(self, module, value=False):
+            if hasattr(module, "gradient_checkpointing"):
+                module.gradient_checkpointing = value
+                modules_with_gc_enabled[module.__class__.__name__] = True
+
+        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
+
+        model = model_class_copy(**init_dict)
+        model.enable_gradient_checkpointing()
+
+        print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
+
+        assert set(modules_with_gc_enabled.keys()) == expected_set
+        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+
     def test_deprecated_kwargs(self):
         has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
         has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
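
For intuition about what `test_effective_gradient_checkpointing` asserts, here is a small self-contained toy in plain PyTorch (not diffusers code; `torch.utils.checkpoint.checkpoint` stands in for `enable_gradient_checkpointing()`): recomputing activations during the backward pass should leave the loss and the parameter gradients numerically unchanged, up to the tolerances used above.

import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
net_ckpt = copy.deepcopy(net)  # identical weights; this copy will run with activation checkpointing

x = torch.randn(2, 8)
labels = torch.randn(2, 8)

# baseline: ordinary forward/backward
loss = (net(x) - labels).mean()
loss.backward()

# checkpointed: activations are discarded in forward and recomputed in backward
loss_ckpt = (checkpoint(net_ckpt, x, use_reentrant=False) - labels).mean()
loss_ckpt.backward()

assert torch.allclose(loss, loss_ckpt, atol=1e-5)
for p, p_ckpt in zip(net.parameters(), net_ckpt.parameters()):
    assert torch.allclose(p.grad, p_ckpt.grad, atol=5e-5)

The shared test performs the same comparison at the model level: it instantiates the model twice with the same seed and state dict, enables gradient checkpointing on the second copy, and compares losses and per-parameter gradients.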

tests/models/transformers/test_models_dit_transformer2d.py

Lines changed: 7 additions & 0 deletions
@@ -84,6 +84,13 @@ def test_correct_class_remapping_from_dict_config(self):
         model = Transformer2DModel.from_config(init_dict)
         assert isinstance(model, DiTTransformer2DModel)

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DiTTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    def test_effective_gradient_checkpointing(self):
+        super().test_effective_gradient_checkpointing(loss_tolerance=1e-4)
+
     def test_correct_class_remapping_from_pretrained_config(self):
         config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
         model = Transformer2DModel.from_config(config)

tests/models/transformers/test_models_pixart_transformer2d.py

Lines changed: 4 additions & 0 deletions
@@ -92,6 +92,10 @@ def test_output(self):
             expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
         )

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"PixArtTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
     def test_correct_class_remapping_from_dict_config(self):
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
         model = Transformer2DModel.from_config(init_dict)

tests/models/transformers/test_models_transformer_allegro.py

Lines changed: 4 additions & 0 deletions
@@ -77,3 +77,7 @@ def prepare_init_args_and_inputs_for_common(self):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"AllegroTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_aura_flow.py

Lines changed: 4 additions & 0 deletions
@@ -74,6 +74,10 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"AuraFlowTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
     @unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
     def test_set_attn_processor_for_determinism(self):
         pass

tests/models/transformers/test_models_transformer_cogvideox.py

Lines changed: 4 additions & 0 deletions
@@ -81,3 +81,7 @@ def prepare_init_args_and_inputs_for_common(self):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"CogVideoXTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_cogview3plus.py

Lines changed: 4 additions & 0 deletions
@@ -83,3 +83,7 @@ def prepare_init_args_and_inputs_for_common(self):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"CogView3PlusTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 4 additions & 0 deletions
@@ -111,3 +111,7 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
             torch.allclose(output_1, output_2, atol=1e-5),
             msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
         )
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"FluxTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_latte.py

Lines changed: 4 additions & 0 deletions
@@ -86,3 +86,7 @@ def test_output(self):
         super().test_output(
             expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
         )
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"LatteTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 8 additions & 0 deletions
@@ -84,6 +84,10 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_set_attn_processor_for_determinism(self):
         pass

+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"SD3Transformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+

 class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SD3Transformer2DModel
@@ -139,3 +143,7 @@ def prepare_init_args_and_inputs_for_common(self):
     @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
     def test_set_attn_processor_for_determinism(self):
         pass
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"SD3Transformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
