Skip to content

Commit 393aefc

Browse files
authored
[tests] fix audioldm2 for transformers main. (#11522)
fix audioldm2 for transformers main.
1 parent 6674a51 commit 393aefc

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@
4040
logging,
4141
replace_example_docstring,
4242
)
43+
from ...utils.import_utils import is_transformers_version
4344
from ...utils.torch_utils import randn_tensor
4445
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
4546
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
@@ -312,8 +313,19 @@ def generate_language_model(
312313
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
313314
The sequence of generated hidden-states.
314315
"""
316+
cache_position_kwargs = {}
317+
if is_transformers_version("<", "4.52.0.dev0"):
318+
cache_position_kwargs["input_ids"] = inputs_embeds
319+
cache_position_kwargs["model_kwargs"] = model_kwargs
320+
else:
321+
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
322+
cache_position_kwargs["device"] = (
323+
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
324+
)
325+
cache_position_kwargs["model_kwargs"] = model_kwargs
315326
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
316-
model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
327+
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
328+
317329
for _ in range(max_new_tokens):
318330
# prepare model inputs
319331
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)

0 commit comments

Comments
 (0)