Skip to content

Check shape and remove deprecated APIs in scheduling_ddpm_flax.py #7703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 8, 2024

Conversation

ppham27
Copy link
Contributor

@ppham27 ppham27 commented Apr 17, 2024

model_output.shape may only have rank 1.

There are warnings related to use of random keys.

tests/schedulers/test_scheduler_flax.py: 13 warnings
  /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)

tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas
  /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    u = uniform(key, shape, dtype, lo, hi)  # type: ignore[arg-type]

What does this PR do?

Fixes # (issue)

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yiyixuxu yiyixuxu requested a review from pcuenca April 29, 2024 22:56

if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
if (
len(model_output.shape) > 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we adding this len() > 1 check here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because input could have shape like x.shape == (10,). Accessing index 1 in the next line is an error in that case.

@@ -264,7 +268,7 @@ def step(

# 6. Add noise
def random_variance():
split_key = jax.random.split(key, num=1)
split_key = jax.random.split(key, ())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we make this change here? num=1 would still work with jax.random.key, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, I'd use:

Suggested change
split_key = jax.random.split(key, ())
split_key = jax.random.split(key, num=(1,))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't fix the batching dimension error. In any case, I have done jax.random.split(key, num=1)[0] to remove the batch dimension.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes you are right!

@ppham27 ppham27 force-pushed the jax branch 2 times, most recently from 3a655d9 to 3e6cde9 Compare April 30, 2024 01:39
Copy link
Member

@pcuenca pcuenca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! However, please note that the project dependencies are defined in

"jax>=0.4.1",
, and then this file is generated.

If we change this code, we need to ensure that it still works with the existing versions (can probably be confirmed with the CI tests), or upgrade the dependencies if necessary.

@@ -13,9 +13,9 @@ RUN apt update && \
ca-certificates \
libsndfile1-dev \
libgl1 \
python3.8 \
python3.10 \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The diffusers codebase is compatible with Python 3.8, and the other Dockerfiles use that version too:

.

Is it a requirement to upgrade?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I reverted the upgrade. But note Python 3.8 is already not supported and even Python 3.10 will lose support EOY, https://scientific-python.org/specs/spec-0000/ and https://jax.readthedocs.io/en/latest/deprecation.html.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, you are right, we should probably upgrade at the project level cc @yiyixuxu @sayakpaul

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @sayakpaul here! let's upgrade?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -264,7 +268,7 @@ def step(

# 6. Add noise
def random_variance():
split_key = jax.random.split(key, num=1)
split_key = jax.random.split(key, ())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, I'd use:

Suggested change
split_key = jax.random.split(key, ())
split_key = jax.random.split(key, num=(1,))

`model_output.shape` may only have rank 1.

There are warnings related to use of random keys.

```
tests/schedulers/test_scheduler_flax.py: 13 warnings
  /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)

tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas
  /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    u = uniform(key, shape, dtype, lo, hi)  # type: ignore[arg-type]
```
@pcuenca
Copy link
Member

pcuenca commented May 8, 2024

Thanks a lot @ppham27!

@pcuenca pcuenca merged commit f29b934 into huggingface:main May 8, 2024
15 checks passed
lawrence-cj pushed a commit to lawrence-cj/diffusers that referenced this pull request May 8, 2024
…ggingface#7703)

`model_output.shape` may only have rank 1.

There are warnings related to use of random keys.

```
tests/schedulers/test_scheduler_flax.py: 13 warnings
  /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)

tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas
  /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    u = uniform(key, shape, dtype, lo, hi)  # type: ignore[arg-type]
```
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
)

`model_output.shape` may only have rank 1.

There are warnings related to use of random keys.

```
tests/schedulers/test_scheduler_flax.py: 13 warnings
  /Users/phillypham/diffusers/src/diffusers/schedulers/scheduling_ddpm_flax.py:268: FutureWarning: normal accepts a single key, but was given a key array of shape (1, 2) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    noise = jax.random.normal(split_key, shape=model_output.shape, dtype=self.dtype)

tests/schedulers/test_scheduler_flax.py::FlaxDDPMSchedulerTest::test_betas
  /Users/phillypham/virtualenv/diffusers/lib/python3.9/site-packages/jax/_src/random.py:731: FutureWarning: uniform accepts a single key, but was given a key array of shape (1,) != (). Use jax.vmap for batching. In a future JAX version, this will be an error.
    u = uniform(key, shape, dtype, lo, hi)  # type: ignore[arg-type]
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants