Skip to content

Commit f19d018

Browse files
authored
Revert "Update deprecated Jax calls (#35919)" (#36880)
* Revert "Update deprecated Jax calls (#35919)" This reverts commit f0d5b2f. * Revert "Update deprecated Jax calls (#35919)" This reverts commit f0d5b2f. * udpate
1 parent 62116c9 commit f19d018

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

src/transformers/commands/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def run(self):
129129
flax_version = flax.__version__
130130
jax_version = jax.__version__
131131
jaxlib_version = jaxlib.__version__
132-
jax_backend = jax.default_backend()
132+
jax_backend = jax.lib.xla_bridge.get_backend().platform
133133

134134
info = {
135135
"`transformers` version": version,

src/transformers/models/longt5/modeling_flax_longt5.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
387387
relative_buckets += (relative_position > 0) * num_buckets
388388
relative_position = jnp.abs(relative_position)
389389
else:
390-
relative_position = -jnp.clip(relative_position, max=0)
390+
relative_position = -jnp.clip(relative_position, a_max=0)
391391
# now relative_position is in the range [0, inf)
392392

393393
# half of the buckets are for exact increments in positions
@@ -398,7 +398,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
398398
relative_position_if_large = max_exact + (
399399
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
400400
)
401-
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
401+
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
402402

403403
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
404404

@@ -672,7 +672,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
672672
relative_buckets += (relative_position > 0) * num_buckets
673673
relative_position = jnp.abs(relative_position)
674674
else:
675-
relative_position = -jnp.clip(relative_position, max=0)
675+
relative_position = -jnp.clip(relative_position, a_max=0)
676676
# now relative_position is in the range [0, inf)
677677

678678
# half of the buckets are for exact increments in positions
@@ -683,7 +683,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
683683
relative_position_if_large = max_exact + (
684684
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
685685
)
686-
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
686+
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
687687

688688
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
689689

@@ -895,7 +895,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
895895
relative_buckets += (relative_position > 0) * num_buckets
896896
relative_position = jnp.abs(relative_position)
897897
else:
898-
relative_position = -jnp.clip(relative_position, max=0)
898+
relative_position = -jnp.clip(relative_position, a_max=0)
899899
# now relative_position is in the range [0, inf)
900900

901901
# half of the buckets are for exact increments in positions
@@ -906,7 +906,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
906906
relative_position_if_large = max_exact + (
907907
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
908908
)
909-
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
909+
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
910910

911911
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
912912

src/transformers/models/t5/modeling_flax_t5.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
247247
relative_buckets += (relative_position > 0) * num_buckets
248248
relative_position = jnp.abs(relative_position)
249249
else:
250-
relative_position = -jnp.clip(relative_position, max=0)
250+
relative_position = -jnp.clip(relative_position, a_max=0)
251251
# now relative_position is in the range [0, inf)
252252

253253
# half of the buckets are for exact increments in positions
@@ -258,7 +258,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
258258
relative_position_if_large = max_exact + (
259259
jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
260260
)
261-
relative_position_if_large = jnp.clip(relative_position_if_large, max=num_buckets - 1)
261+
relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)
262262

263263
relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)
264264

0 commit comments

Comments
 (0)