 import torch.nn.functional as F
 from huggingface_hub import snapshot_download
 from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
-                          BatchFeature)
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
+                          BatchEncoding, BatchFeature)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
 from tests.models.utils import (TokensTextLogprobs,
                                 TokensTextLogprobsPromptLogprobs)
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
-from vllm.config import TaskOption, TokenizerPoolConfig
+from vllm.config import TaskOption, TokenizerPoolConfig, _get_and_verify_dtype
 from vllm.connections import global_http_connection
 from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        identity, is_list_of)
+from vllm.utils import cuda_device_count_stateless, is_list_of
 
 logger = init_logger(__name__)
@@ -271,14 +270,18 @@ def video_assets() -> _VideoAssets:
 
 class HfRunner:
 
-    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
+    def get_default_device(self):
         from vllm.platforms import current_platform
+
+        return ("cpu" if current_platform.is_cpu()
+                or current_platform.is_openvino() else "cuda")
+
+    def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
         if x is None or isinstance(x, (bool, )):
             return x
 
         if device is None:
-            device = "cpu" if current_platform.is_cpu(
-            ) or current_platform.is_openvino() else "cuda"
+            device = self.device
 
         if isinstance(x, dict):
             return {k: self.wrap_device(v, device) for k, v in x.items()}
@@ -291,45 +294,59 @@ def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
     def __init__(
         self,
         model_name: str,
-        dtype: str = "half",
+        dtype: str = "auto",
         *,
         model_kwargs: Optional[dict[str, Any]] = None,
         is_sentence_transformer: bool = False,
         is_cross_encoder: bool = False,
         skip_tokenizer_init: bool = False,
         auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
-        postprocess_inputs: Callable[..., BatchEncoding] = identity,
     ) -> None:
-        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
-
         self.model_name = model_name
 
+        self.config = AutoConfig.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+        )
+        self.device = self.get_default_device()
+        self.dtype = torch_dtype = _get_and_verify_dtype(self.config, dtype)
+
+        model_kwargs = model_kwargs if model_kwargs is not None else {}
+        model_kwargs.setdefault("torch_dtype", torch_dtype)
+
         if is_sentence_transformer:
             # Lazy init required for AMD CI
             from sentence_transformers import SentenceTransformer
-            self.model = self.wrap_device(
-                SentenceTransformer(
-                    model_name,
-                    device="cpu",
-                    trust_remote_code=True,
-                ).to(dtype=torch_dtype))
+
+            self.model = SentenceTransformer(
+                model_name,
+                device=self.device,
+                model_kwargs=model_kwargs,
+                trust_remote_code=True,
+            )
         elif is_cross_encoder:
             # Lazy init required for AMD CI
             from sentence_transformers import CrossEncoder
-            self.model = CrossEncoder(model_name,
-                                      device="cpu",
-                                      trust_remote_code=True)
-            self.model.model = self.wrap_device(self.model.model)\
-                .to(dtype=torch_dtype)
+
+            self.model = CrossEncoder(
+                model_name,
+                device=self.device,
+                automodel_args=model_kwargs,
+                trust_remote_code=True,
+            )
         else:
-            model_kwargs = model_kwargs if model_kwargs is not None else {}
-            self.model = self.wrap_device(
-                auto_cls.from_pretrained(
-                    model_name,
-                    torch_dtype=torch_dtype,
-                    trust_remote_code=True,
-                    **model_kwargs,
-                ))
+            model = auto_cls.from_pretrained(
+                model_name,
+                trust_remote_code=True,
+                **model_kwargs,
+            )
+
+            if (getattr(model, "quantization_method", None) != "bitsandbytes"
+                    and len({p.device
+                             for p in model.parameters()}) < 2):
+                model = model.to(self.device)
+
+            self.model = model
 
         if not skip_tokenizer_init:
             self.tokenizer = AutoTokenizer.from_pretrained(
@@ -349,16 +366,13 @@ def __init__(
         if skip_tokenizer_init:
             self.tokenizer = self.processor.tokenizer
 
-        self.dtype = dtype
-        self.postprocess_inputs = postprocess_inputs
-
     def get_inputs(
         self,
         prompts: list[str],
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
-    ) -> list[BatchEncoding]:
+    ) -> list[Union[BatchFeature, BatchEncoding]]:
         if images is not None:
             assert len(prompts) == len(images)
 
@@ -368,7 +382,7 @@ def get_inputs(
         if audios is not None:
             assert len(prompts) == len(audios)
 
-        all_inputs: list[BatchEncoding] = []
+        all_inputs: list[Union[BatchFeature, BatchEncoding]] = []
         for i, prompt in enumerate(prompts):
             processor_kwargs: dict[str, Any] = {
                 "text": prompt,
@@ -384,7 +398,8 @@ def get_inputs(
                 processor_kwargs["sampling_rate"] = sr
 
             inputs = self.processor(**processor_kwargs)
-            inputs = self.postprocess_inputs(inputs, dtype=self.dtype)
+            if isinstance(inputs, BatchFeature):
+                inputs = inputs.to(dtype=self.dtype)
 
             all_inputs.append(inputs)
 
@@ -417,7 +432,7 @@ def generate(
         outputs: list[tuple[list[list[int]], list[str]]] = []
         for inputs in all_inputs:
             output_ids = self.model.generate(
-                **self.wrap_device(inputs, device=self.model.device.type),
+                **self.wrap_device(inputs),
                 use_cache=True,
                 **kwargs,
             )
@@ -488,7 +503,7 @@ def generate_greedy_logprobs(
         all_logprobs: list[list[torch.Tensor]] = []
         for inputs in all_inputs:
             output = self.model.generate(
-                **self.wrap_device(inputs, device=self.model.device.type),
+                **self.wrap_device(inputs),
                 use_cache=True,
                 do_sample=False,
                 max_new_tokens=max_tokens,
@@ -569,7 +584,7 @@ def generate_greedy_logprobs_limit(
 
         for inputs in all_inputs:
             output = self.model.generate(
-                **self.wrap_device(inputs, device=self.model.device.type),
+                **self.wrap_device(inputs),
                 use_cache=True,
                 do_sample=False,
                 max_new_tokens=max_tokens,
@@ -620,19 +635,15 @@ def generate_encoder_decoder_greedy_logprobs_limit(
             if images is not None and images[i] is not None:
                 processor_kwargs["images"] = images[i]
 
-            encoder_inputs = self.wrap_device(
-                self.processor(**processor_kwargs),
-                device=self.model.device.type,
-            )
+            encoder_inputs = self.processor(**processor_kwargs)
+            encoder_inputs = self.wrap_device(encoder_inputs)
 
             if decoder_prompt is None:
                 decoder_input_ids = None
             else:
-                decoder_input_ids = self.wrap_device(
-                    self.tokenizer(decoder_prompt,
-                                   return_tensors="pt").input_ids,
-                    device=self.model.device.type,
-                )
+                decoder_inputs = self.tokenizer(decoder_prompt,
+                                                return_tensors="pt")
+                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)
 
             output = self.model.generate(
                 decoder_input_ids=decoder_input_ids,
@@ -684,6 +695,7 @@ class VllmRunner:
     """
     The default value of some arguments have been modified from
     :class:`~vllm.LLM` as follows:
+
     - `trust_remote_code`: Set to `True` instead of `False` for convenience.
     - `seed`: Set to `0` instead of `None` for test reproducibility.
     - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage.
@@ -701,10 +713,8 @@ def __init__(
         tokenizer_mode: str = "auto",
         trust_remote_code: bool = True,
         seed: Optional[int] = 0,
-        # Use smaller max model length, otherwise bigger model cannot run due
-        # to kv cache size limit.
         max_model_len: int = 1024,
-        dtype: str = "half",
+        dtype: str = "auto",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
         block_size: int = 16,
@@ -1110,4 +1120,4 @@ def pytest_collection_modifyitems(config, items):
     skip_optional = pytest.mark.skip(reason="need --optional option to run")
     for item in items:
         if "optional" in item.keywords:
-            item.add_marker(skip_optional)
+            item.add_marker(skip_optional)
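
For reference, a minimal sketch of how the changed defaults might be exercised in a test. The model name and prompt are placeholders, and importing the runners directly from `tests.conftest` is an assumption; in practice they are usually obtained through the `hf_runner` / `vllm_runner` pytest fixtures.

```python
# Illustrative sketch only: model name and prompts are placeholders.
# With this change, dtype defaults to "auto" and is resolved from the model
# config via _get_and_verify_dtype, and HfRunner moves the model to the
# device returned by get_default_device().
from tests.conftest import HfRunner, VllmRunner

MODEL = "facebook/opt-125m"  # placeholder small model
PROMPTS = ["Hello, my name is"]

with HfRunner(MODEL) as hf_model:  # dtype="auto" is now the default
    hf_outputs = hf_model.generate_greedy(PROMPTS, max_tokens=8)

with VllmRunner(MODEL) as vllm_model:  # same default change applies here
    vllm_outputs = vllm_model.generate_greedy(PROMPTS, max_tokens=8)
```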