Skip to content

Update deprecated Jax calls #35919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@
"importlib_metadata",
"ipadic>=1.0.0,<2.0",
"isort>=5.5.4",
"jax>=0.4.1,<=0.4.13",
"jaxlib>=0.4.1,<=0.4.13",
"jax>=0.4.27,<=0.4.38",
"jaxlib>=0.4.27,<=0.4.38",
"jieba",
"jinja2>=3.1.0",
"kenlm",
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/commands/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def run(self):
flax_version = flax.__version__
jax_version = jax.__version__
jaxlib_version = jaxlib.__version__
jax_backend = jax.lib.xla_bridge.get_backend().platform
jax_backend = jax.default_backend()

info = {
"`transformers` version": version,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
"importlib_metadata": "importlib_metadata",
"ipadic": "ipadic>=1.0.0,<2.0",
"isort": "isort>=5.5.4",
"jax": "jax>=0.4.1,<=0.4.13",
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
"jax": "jax>=0.4.27,<=0.4.38",
"jaxlib": "jaxlib>=0.4.27,<=0.4.38",
"jieba": "jieba",
"jinja2": "jinja2>=3.1.0",
"kenlm": "kenlm",
Expand Down
12 changes: 6 additions & 6 deletions src/transformers/models/longt5/modeling_flax_longt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, a_max=0)
relative_position = -jnp.clip(relative_position, max=0)
# now relative_position is in the range [0, inf)

# half of the buckets are for exact increments in positions
Expand All @@ -398,7 +398,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)

relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

Expand Down Expand Up @@ -672,7 +672,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, a_max=0)
relative_position = -jnp.clip(relative_position, max=0)
# now relative_position is in the range [0, inf)

# half of the buckets are for exact increments in positions
Expand All @@ -683,7 +683,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)

relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

Expand Down Expand Up @@ -895,7 +895,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, a_max=0)
relative_position = -jnp.clip(relative_position, max=0)
# now relative_position is in the range [0, inf)

# half of the buckets are for exact increments in positions
Expand All @@ -906,7 +906,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)

relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/t5/modeling_flax_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_buckets += (relative_position > 0) * num_buckets
relative_position = jnp.abs(relative_position)
else:
relative_position = -jnp.clip(relative_position, a_max=0)
relative_position = -jnp.clip(relative_position, max=0)
# now relative_position is in the range [0, inf)

# half of the buckets are for exact increments in positions
Expand All @@ -258,7 +258,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
relative_position_if_large = max_exact + (
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
)
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)

relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

Expand Down