Skip to content

Fix invocation of some slow Flax tests #3058

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 11, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions tests/test_pipelines_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap

from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline

Expand Down Expand Up @@ -70,14 +69,12 @@ def test_dummy_all_tpus(self):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images

assert images.shape == (num_samples, 1, 64, 64, 3)
if jax.device_count() == 8:
Expand Down Expand Up @@ -105,14 +102,12 @@ def test_stable_diffusion_v1_4(self):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images

assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
Expand All @@ -136,14 +131,12 @@ def test_stable_diffusion_v1_4_bfloat_16(self):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images

assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
Expand Down Expand Up @@ -211,14 +204,12 @@ def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)

images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images

assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
Expand Down