@@ -482,9 +482,11 @@ def test_model_from_config_torch_dtype_str(self):
         # test that from_pretrained works with torch_dtype being strings like "float32" for PyTorch backend
         model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float32")
         self.assertEqual(model.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         model = AutoModel.from_pretrained(TINY_T5, torch_dtype="float16")
         self.assertEqual(model.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
         with self.assertRaises(ValueError):
@@ -495,14 +497,22 @@ def test_model_from_config_torch_dtype_composite(self):
         Test that from_pretrained works with torch_dtype passed as a dict with one entry per sub-config in a composite config.
         Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
         """
+        # Load without dtype specified
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA)
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)
+
         # should be able to set torch_dtype as a simple string and the model loads it correctly
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
         self.assertEqual(model.language_model.dtype, torch.float16)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # should be able to set torch_dtype as a dict for each sub-config
         model = LlavaForConditionalGeneration.from_pretrained(
@@ -511,6 +521,7 @@ def test_model_from_config_torch_dtype_composite(self):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # should be able to set the values as torch.dtype (not str)
         model = LlavaForConditionalGeneration.from_pretrained(
@@ -519,6 +530,7 @@ def test_model_from_config_torch_dtype_composite(self):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # should be able to set the values in configs directly and pass it to `from_pretrained`
         config = copy.deepcopy(model.config)
@@ -529,13 +541,15 @@ def test_model_from_config_torch_dtype_composite(self):
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
         LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
         self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
+        self.assertIsInstance(model.config.torch_dtype, torch.dtype)

         # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
         with self.assertRaises(ValueError):
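
The property the new assertions pin down is that `from_pretrained` normalizes every accepted `torch_dtype` form (a string like "float32", an actual `torch.dtype`, "auto", or a per-sub-config dict for composite models) into a real `torch.dtype` on the resulting `model.config`. Below is a minimal usage sketch of the forms exercised by these tests; the checkpoint names are placeholders standing in for the TINY_T5 and TINY_LLAVA test constants, and the dict keys are inferred from the assertions above rather than taken as documented API.

import torch
from transformers import AutoModel, LlavaForConditionalGeneration

# A string dtype is accepted and normalized: after loading, config.torch_dtype
# holds torch.float16 itself, not the string that was passed in.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-t5", torch_dtype="float16")
assert model.dtype == torch.float16
assert isinstance(model.config.torch_dtype, torch.dtype)

# For a composite model, torch_dtype may be a dict keyed by sub-config name.
# The "" entry (an assumption here) would cover modules outside any named
# sub-config, e.g. the multi_modal_projector asserted on in the tests above.
model = LlavaForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-LlavaForConditionalGeneration",
    torch_dtype={"text_config": "float32", "vision_config": torch.float16, "": "bfloat16"},
)
assert model.language_model.dtype == torch.float32
assert model.vision_tower.dtype == torch.float16
assert isinstance(model.config.torch_dtype, torch.dtype)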