Skip to content

Commit f0d5b2f

Browse files
rasmijakevdp
andauthored
Update deprecated Jax calls (#35919)
* Remove deprecated arguments for jax.numpy.clip. * Remove deprecated arguments for jax.numpy.clip. * Update jax version to 0.4.27 to 0.4.38. * Avoid use of deprecated xla_bridge.get_backend().platform Co-authored-by: Jake Vanderplas <[email protected]> --------- Co-authored-by: Jake Vanderplas <[email protected]>
1 parent 1ddb649 commit f0d5b2f

File tree

5 files changed

+13
-13
lines changed

5 files changed

+13
-13
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@
121121
"importlib_metadata",
122122
"ipadic>=1.0.0,<2.0",
123123
"isort>=5.5.4",
124-
"jax>=0.4.1,<=0.4.13",
125-
"jaxlib>=0.4.1,<=0.4.13",
124+
"jax>=0.4.27,<=0.4.38",
125+
"jaxlib>=0.4.27,<=0.4.38",
126126
"jieba",
127127
"jinja2>=3.1.0",
128128
"kenlm",

src/transformers/commands/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def run(self):
129129
flax_version = flax.__version__
130130
jax_version = jax.__version__
131131
jaxlib_version = jaxlib.__version__
132-
jax_backend = jax.lib.xla_bridge.get_backend().platform
132+
jax_backend = jax.default_backend()
133133

134134
info = {
135135
"`transformers` version": version,

src/transformers/dependency_versions_table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
"importlib_metadata": "importlib_metadata",
2929
"ipadic": "ipadic>=1.0.0,<2.0",
3030
"isort": "isort>=5.5.4",
31-
"jax": "jax>=0.4.1,<=0.4.13",
32-
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
31+
"jax": "jax>=0.4.27,<=0.4.38",
32+
"jaxlib": "jaxlib>=0.4.27,<=0.4.38",
3333
"jieba": "jieba",
3434
"jinja2": "jinja2>=3.1.0",
3535
"kenlm": "kenlm",

src/transformers/models/longt5/modeling_flax_longt5.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
387387
relative_buckets += (relative_position > 0) * num_buckets
388388
relative_position = jnp.abs(relative_position)
389389
else:
390-
relative_position = -jnp.clip(relative_position, a_max=0)
390+
relative_position = -jnp.clip(relative_position, max=0)
391391
# now relative_position is in the range [0, inf)
392392

393393
# half of the buckets are for exact increments in positions
@@ -398,7 +398,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
398398
relative_position_if_large = max_exact + (
399399
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
400400
)
401-
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
401+
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
402402

403403
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
404404

@@ -672,7 +672,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
672672
relative_buckets += (relative_position > 0) * num_buckets
673673
relative_position = jnp.abs(relative_position)
674674
else:
675-
relative_position = -jnp.clip(relative_position, a_max=0)
675+
relative_position = -jnp.clip(relative_position, max=0)
676676
# now relative_position is in the range [0, inf)
677677

678678
# half of the buckets are for exact increments in positions
@@ -683,7 +683,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
683683
relative_position_if_large = max_exact + (
684684
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
685685
)
686-
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
686+
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
687687

688688
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
689689

@@ -895,7 +895,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
895895
relative_buckets += (relative_position > 0) * num_buckets
896896
relative_position = jnp.abs(relative_position)
897897
else:
898-
relative_position = -jnp.clip(relative_position, a_max=0)
898+
relative_position = -jnp.clip(relative_position, max=0)
899899
# now relative_position is in the range [0, inf)
900900

901901
# half of the buckets are for exact increments in positions
@@ -906,7 +906,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
906906
relative_position_if_large = max_exact + (
907907
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
908908
)
909-
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
909+
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
910910

911911
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
912912

src/transformers/models/t5/modeling_flax_t5.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
247247
relative_buckets += (relative_position > 0) * num_buckets
248248
relative_position = jnp.abs(relative_position)
249249
else:
250-
relative_position = -jnp.clip(relative_position, a_max=0)
250+
relative_position = -jnp.clip(relative_position, max=0)
251251
# now relative_position is in the range [0, inf)
252252

253253
# half of the buckets are for exact increments in positions
@@ -258,7 +258,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
258258
relative_position_if_large = max_exact + (
259259
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
260260
)
261-
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
261+
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
262262

263263
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
264264

0 commit comments

Comments
 (0)