Skip to content

Commit c80eda9

Browse files
authored
[Tests] Test layerwise casting with training (#10765)
* Add a test to check if we can train with layerwise casting.
* updates
* updates
* style
1 parent 7fb481f commit c80eda9

File tree

3 files changed

+44
-0
lines changed

3 files changed

+44
-0
lines changed

tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ def test_forward_with_norm_groups(self):
114114
def test_set_attn_processor_for_determinism(self):
115115
return
116116

117+
@unittest.skip(
118+
"Test not supported because of 'weight_norm_fwd_first_dim_kernel' not implemented for 'Float8_e4m3fn'"
119+
)
120+
def test_layerwise_casting_training(self):
121+
return super().test_layerwise_casting_training()
122+
117123
@unittest.skip(
118124
"The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
119125
"cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"

tests/models/test_modeling_common.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,36 @@ def test_variant_sharded_ckpt_right_format(self):
13381338
# Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
13391339
assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
13401340

1341+
def test_layerwise_casting_training(self):
    # Smoke test: with layerwise casting enabled, one forward pass plus a
    # backward pass must complete without raising for each
    # (storage_dtype, compute_dtype) pair listed at the bottom.
    # There is no assertion on the loss value; reaching the end is the check.
    device_type = torch.device(torch_device).type

    def run_training_step(storage_dtype, compute_dtype):
        # The bfloat16-compute case is silently skipped when running on CPU.
        if compute_dtype == torch.bfloat16 and device_type == "cpu":
            return

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device, dtype=compute_dtype)
        model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        model.train()

        # Bring float32 inputs to the compute dtype before the forward pass.
        inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)

        # NOTE(review): the nesting of the statements below was not recoverable
        # from the extracted diff. The loss is assumed to be computed inside the
        # autocast context and backward() called outside of it - confirm against
        # the repository before applying.
        with torch.amp.autocast(device_type=device_type):
            output = model(**inputs_dict)
            if isinstance(output, dict):
                output = output.to_tuple()[0]

            # Random regression target with the batch size of the main input.
            batch_size = inputs_dict[self.main_input_name].shape[0]
            target = torch.randn((batch_size,) + self.output_shape).to(torch_device)
            target = cast_maybe_tensor_dtype(target, torch.float32, compute_dtype)
            loss = torch.nn.functional.mse_loss(output, target)

        loss.backward()

    # Same pairs, in the same order, as the original explicit calls.
    dtype_pairs = (
        (torch.float16, torch.float32),
        (torch.float8_e4m3fn, torch.float32),
        (torch.float8_e5m2, torch.float32),
        (torch.float8_e4m3fn, torch.bfloat16),
    )
    for storage_dtype, compute_dtype in dtype_pairs:
        run_training_step(storage_dtype, compute_dtype)
1370+
13411371
def test_layerwise_casting_inference(self):
13421372
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
13431373

tests/models/unets/test_models_unet_1d.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def test_ema_training(self):
6060
def test_training(self):
6161
pass
6262

63+
@unittest.skip("Test not supported.")
64+
def test_layerwise_casting_training(self):
65+
pass
66+
6367
def test_determinism(self):
6468
super().test_determinism()
6569

@@ -239,6 +243,10 @@ def test_ema_training(self):
239243
def test_training(self):
240244
pass
241245

246+
@unittest.skip("Test not supported.")
247+
def test_layerwise_casting_training(self):
248+
pass
249+
242250
def prepare_init_args_and_inputs_for_common(self):
243251
init_dict = {
244252
"in_channels": 14,

0 commit comments

Comments
 (0)