-
Notifications
You must be signed in to change notification settings - Fork 6k
Check shape and remove deprecated APIs in scheduling_ddpm_flax.py #7703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
|
||
if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: | ||
if ( | ||
len(model_output.shape) > 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why are we adding this len() > 1
check here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because input could have shape like x.shape == (10,)
. Accessing index 1
in the next line is an error in that case.
@@ -264,7 +268,7 @@ def step( | |||
|
|||
# 6. Add noise | |||
def random_variance(): | |||
split_key = jax.random.split(key, num=1) | |||
split_key = jax.random.split(key, ()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we make this change here? num=1
would still work with jax.random.key
, no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will be an error soon in newer versions of JAX, https://github.com/google/jax/blob/343e18fcb693c3f1b2cace56a4faea8fd3e2cadd/jax/_src/random.py#L104-L108.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case, I'd use:
split_key = jax.random.split(key, ()) | |
split_key = jax.random.split(key, num=(1,)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That doesn't fix the batching dimension error. In any case, I have done jax.random.split(key, num=1)[0]
to remove the batch dimension.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah yes you are right!
3a655d9
to
3e6cde9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! However, please note that the project dependencies are defined in
Line 109 in 26a7851
"jax>=0.4.1", |
If we change this code, we need to ensure that it still works with the existing versions (can probably be confirmed with the CI tests), or upgrade the dependencies if necessary.
docker/diffusers-flax-cpu/Dockerfile
Outdated
@@ -13,9 +13,9 @@ RUN apt update && \ | |||
ca-certificates \ | |||
libsndfile1-dev \ | |||
libgl1 \ | |||
python3.8 \ | |||
python3.10 \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The diffusers codebase is compatible with Python 3.8
, and the other Dockerfiles use that version too:
python3.8 \ |
Is it a requirement to upgrade?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I reverted the upgrade. But note Python 3.8 is already not supported and even Python 3.10 will lose support EOY, https://scientific-python.org/specs/spec-0000/ and https://jax.readthedocs.io/en/latest/deprecation.html.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, you are right, we should probably upgrade at the project level cc @yiyixuxu @sayakpaul
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @sayakpaul here! let's upgrade?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -264,7 +268,7 @@ def step( | |||
|
|||
# 6. Add noise | |||
def random_variance(): | |||
split_key = jax.random.split(key, num=1) | |||
split_key = jax.random.split(key, ()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case, I'd use:
split_key = jax.random.split(key, ()) | |
split_key = jax.random.split(key, num=(1,)) |
`model_output.shape` may only have rank 1. There are warnings related to use of random keys. ``` tests/schedulers/test_scheduler_flax.py: 13 warnings /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type] ```
Thanks a lot @ppham27! |
…ggingface#7703) `model_output.shape` may only have rank 1. There are warnings related to use of random keys. ``` tests/schedulers/test_scheduler_flax.py: 13 warnings /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type] ```
) `model_output.shape` may only have rank 1. There are warnings related to use of random keys. ``` tests/schedulers/test_scheduler_flax.py: 13 warnings /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype) tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error. u = uniform(key, shape, dtype, lo, hi) # type: ignore[arg-type] ```
model_output.shape
may only have rank 1.There are warnings related to use of random keys.
What does this PR do?
Fixes # (issue)
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.