Commit 523f6e7

Fix: dtype cannot be str (#36262)
* fix
* this wasn't supposed to be here, revert
* refine tests a bit more
1 parent 3f9ff19 commit 523f6e7

2 files changed: 18 additions, 4 deletions

src/transformers/modeling_utils.py

Lines changed: 4 additions & 4 deletions
@@ -1252,13 +1252,13 @@ def _get_torch_dtype(
                 for key, curr_dtype in torch_dtype.items():
                     if hasattr(config, key):
                         value = getattr(config, key)
+                        curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
                         value.torch_dtype = curr_dtype
                 # main torch dtype for modules that aren't part of any sub-config
                 torch_dtype = torch_dtype.get("")
+                torch_dtype = torch_dtype if not isinstance(torch_dtype, str) else getattr(torch, torch_dtype)
                 config.torch_dtype = torch_dtype
-                if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
-                    torch_dtype = getattr(torch, torch_dtype)
-                elif torch_dtype is None:
+                if torch_dtype is None:
                     torch_dtype = torch.float32
             else:
                 raise ValueError(
@@ -1269,7 +1269,7 @@ def _get_torch_dtype(
             dtype_orig = cls._set_default_torch_dtype(torch_dtype)
         else:
             # set fp32 as the default dtype for BC
-            default_dtype = str(torch.get_default_dtype()).split(".")[-1]
+            default_dtype = torch.get_default_dtype()
             config.torch_dtype = default_dtype
             for key in config.sub_configs.keys():
                 value = getattr(config, key)
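
The change above normalizes any string dtype (such as "float16") into a real torch.dtype via getattr(torch, ...) before it is written to the config, so config.torch_dtype never ends up holding a plain string. A minimal standalone sketch of that conversion pattern; the helper name as_torch_dtype is hypothetical and not part of the library:

import torch

def as_torch_dtype(dtype):
    # Hypothetical helper mirroring the one-liner used in the patch:
    # strings such as "float16" are resolved to torch attributes,
    # while real torch.dtype values pass through unchanged.
    return dtype if not isinstance(dtype, str) else getattr(torch, dtype)

assert as_torch_dtype("float16") is torch.float16
assert as_torch_dtype(torch.bfloat16) is torch.bfloat16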

tests/utils/test_modeling_utils.py

Lines changed: 14 additions & 0 deletions
@@ -482,9 +482,11 @@ def test_model_from_config_torch_dtype_str(self):
         # test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32")
         self.assertEqual(model.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16")
         self.assertEqual(model.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
         with self.assertRaises(ValueError):
@@ -495,14 +497,22 @@ def test_model_from_config_torch_dtype_composite(self):
         Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
         Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
         """
+        # Load without dtype specified
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA)
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
+
         # should be able to set torch_dtype as a simple string and the model loads it correctly
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
         self.assertEqual(model.language_model.dtype, torch.float16)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # should be able to set torch_dtype as a dict for each sub-config
         model = LlavaForConditionalGeneration.from_pretrained(
@@ -511,6 +521,7 @@ def test_model_from_config_torch_dtype_composite(self):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # should be able to set the values as torch.dtype (not str)
         model = LlavaForConditionalGeneration.from_pretrained(
@@ -519,6 +530,7 @@ def test_model_from_config_torch_dtype_composite(self):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # should be able to set the values in configs directly and pass it to `from_pretrained`
         config = copy.deepcopy(model.config)
@@ -529,13 +541,15 @@ def test_model_from_config_torch_dtype_composite(self):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
         LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
 
         # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
         with self.assertRaises(ValueError):
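
The new assertions all pin down the same user-visible behavior: after loading with a string torch_dtype, model.config.torch_dtype holds a torch.dtype rather than a str. A rough way to check this outside the test suite, assuming the fix is installed; the checkpoint name below is only illustrative (the tests use their own tiny fixtures such as TINY_T5 and TINY_LLAVA):

import torch
from transformers import AutoModel

# Any small PyTorch checkpoint works here; this name is an assumption.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-t5", torch_dtype="float16")

print(model.dtype)                      # torch.float16
print(type(model.config.torch_dtype))   # a torch.dtype, not a str, with this fix
assert isinstance(model.config.torch_dtype, torch.dtype)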
