Skip to content

Commit af474e5

Browse files
YiYi Xusayakpaul
YiYi Xu
authored andcommitted
a few fix for shard checkpoints (#8656)
fix Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent 6da31e4 commit af474e5

File tree

3 files changed

+5
-1
lines changed

3 files changed

+5
-1
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
819819
offload_folder=offload_folder,
820820
offload_state_dict=offload_state_dict,
821821
dtype=torch_dtype,
822-
force_hook=force_hook,
822+
force_hooks=force_hook,
823823
strict=True,
824824
)
825825
model._undo_temp_convert_self_to_deprecated_attention_blocks()

tests/models/test_modeling_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,7 @@ def test_sharded_checkpoints(self):
898898
self.assertTrue(actual_num_shards == expected_num_shards)
899899

900900
new_model = self.model_class.from_pretrained(tmp_dir)
901+
new_model = new_model.to(torch_device)
901902

902903
torch.manual_seed(0)
903904
new_output = new_model(**inputs_dict)
@@ -933,6 +934,7 @@ def test_sharded_checkpoints_device_map(self):
933934
self.assertTrue(actual_num_shards == expected_num_shards)
934935

935936
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
937+
new_model = new_model.to(torch_device)
936938

937939
torch.manual_seed(0)
938940
new_output = new_model(**inputs_dict)

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1039,6 +1039,7 @@ def test_ip_adapter_plus(self):
10391039
def test_load_sharded_checkpoint_from_hub(self):
10401040
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
10411041
loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy")
1042+
loaded_model = loaded_model.to(torch_device)
10421043
new_output = loaded_model(**inputs_dict)
10431044

10441045
assert loaded_model
@@ -1049,6 +1050,7 @@ def test_load_sharded_checkpoint_from_hub_local(self):
10491050
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
10501051
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
10511052
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
1053+
loaded_model = loaded_model.to(torch_device)
10521054
new_output = loaded_model(**inputs_dict)
10531055

10541056
assert loaded_model

0 commit comments

Comments
 (0)