Skip to content

Commit 3a655d9

Browse files
committed
Check shape and remove deprecated APIs in scheduling_ddpm_flax.py
`model_output.shape` may only have rank 1. There are warnings related to use of random keys. ``` tests/schedulers/test_scheduler_flax.py: 13 warnings /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type] ```
1 parent 56bd7e6 commit 3a655d9

File tree

3 files changed

+17
-13
lines changed

3 files changed

+17
-13
lines changed

docker/diffusers-flax-cpu/Dockerfile

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ RUN apt update && \
1313
ca-certificates \
1414
libsndfile1-dev \
1515
libgl1 \
16-
python3.8 \
16+
python3.11 \
1717
python3-pip \
18-
python3.8-venv && \
18+
python3.11-venv && \
1919
rm -rf /var/lib/apt/lists
2020

2121
# make sure to use venv
@@ -27,9 +27,9 @@ ENV PATH="/opt/venv/bin:$PATH"
2727
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
2828
python3 -m uv pip install --upgrade --no-cache-dir \
2929
clu \
30-
"jax[cpu]>=0.2.16,!=0.3.2" \
31-
"flax>=0.4.1" \
32-
"jaxlib>=0.1.65" && \
30+
"jax[cpu]>=0.4.26" \
31+
"flax>=0.8.2" \
32+
"jaxlib>=0.4.26" && \
3333
python3 -m uv pip install --no-cache-dir \
3434
accelerate \
3535
datasets \

docker/diffusers-flax-tpu/Dockerfile

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ RUN apt update && \
1313
ca-certificates \
1414
libsndfile1-dev \
1515
libgl1 \
16-
python3.8 \
16+
python3.11 \
1717
python3-pip \
18-
python3.8-venv && \
18+
python3.11-venv && \
1919
rm -rf /var/lib/apt/lists
2020

2121
# make sure to use venv
@@ -26,12 +26,12 @@ ENV PATH="/opt/venv/bin:$PATH"
2626
# follow the instructions here: https://cloud.google.com/tpu/docs/run-in-container#train_a_jax_model_in_a_docker_container
2727
RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
2828
python3 -m pip install --no-cache-dir \
29-
"jax[tpu]>=0.2.16,!=0.3.2" \
29+
"jax[tpu]>=0.4.26" \
3030
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
3131
python3 -m uv pip install --upgrade --no-cache-dir \
3232
clu \
33-
"flax>=0.4.1" \
34-
"jaxlib>=0.1.65" && \
33+
"flax>=0.8.2" \
34+
"jaxlib>=0.4.26" && \
3535
python3 -m uv pip install --no-cache-dir \
3636
accelerate \
3737
datasets \

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,13 @@ def step(
222222
t = timestep
223223

224224
if key is None:
225-
key = jax.random.PRNGKey(0)
225+
key = jax.random.key(0)
226226

227-
if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
227+
if (
228+
len(model_output.shape) > 1
229+
and model_output.shape[1] == sample.shape[1] * 2
230+
and self.config.variance_type in ["learned", "learned_range"]
231+
):
228232
model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
229233
else:
230234
predicted_variance = None
@@ -264,7 +268,7 @@ def step(
264268

265269
# 6. Add noise
266270
def random_variance():
267-
split_key = jax.random.split(key, num=1)
271+
split_key = jax.random.split(key, ())
268272
noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)
269273
return (self._get_variance(state, t, predicted_variance=predicted_variance) ** 0.5) * noise
270274

0 commit comments

Comments
 (0)