@@ -387,7 +387,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
387
387
relative_buckets += (relative_position > 0 ) * num_buckets
388
388
relative_position = jnp .abs (relative_position )
389
389
else :
390
- relative_position = - jnp .clip (relative_position , max = 0 )
390
+ relative_position = - jnp .clip (relative_position , a_max = 0 )
391
391
# now relative_position is in the range [0, inf)
392
392
393
393
# half of the buckets are for exact increments in positions
@@ -398,7 +398,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
398
398
relative_position_if_large = max_exact + (
399
399
jnp .log (relative_position / max_exact ) / jnp .log (max_distance / max_exact ) * (num_buckets - max_exact )
400
400
)
401
- relative_position_if_large = jnp .clip (relative_position_if_large , max = num_buckets - 1 )
401
+ relative_position_if_large = jnp .clip (relative_position_if_large , a_max = num_buckets - 1 )
402
402
403
403
relative_buckets += jnp .where (is_small , relative_position , relative_position_if_large )
404
404
@@ -672,7 +672,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
672
672
relative_buckets += (relative_position > 0 ) * num_buckets
673
673
relative_position = jnp .abs (relative_position )
674
674
else :
675
- relative_position = - jnp .clip (relative_position , max = 0 )
675
+ relative_position = - jnp .clip (relative_position , a_max = 0 )
676
676
# now relative_position is in the range [0, inf)
677
677
678
678
# half of the buckets are for exact increments in positions
@@ -683,7 +683,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
683
683
relative_position_if_large = max_exact + (
684
684
jnp .log (relative_position / max_exact ) / jnp .log (max_distance / max_exact ) * (num_buckets - max_exact )
685
685
)
686
- relative_position_if_large = jnp .clip (relative_position_if_large , max = num_buckets - 1 )
686
+ relative_position_if_large = jnp .clip (relative_position_if_large , a_max = num_buckets - 1 )
687
687
688
688
relative_buckets += jnp .where (is_small , relative_position , relative_position_if_large )
689
689
@@ -895,7 +895,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
895
895
relative_buckets += (relative_position > 0 ) * num_buckets
896
896
relative_position = jnp .abs (relative_position )
897
897
else :
898
- relative_position = - jnp .clip (relative_position , max = 0 )
898
+ relative_position = - jnp .clip (relative_position , a_max = 0 )
899
899
# now relative_position is in the range [0, inf)
900
900
901
901
# half of the buckets are for exact increments in positions
@@ -906,7 +906,7 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
906
906
relative_position_if_large = max_exact + (
907
907
jnp .log (relative_position / max_exact ) / jnp .log (max_distance / max_exact ) * (num_buckets - max_exact )
908
908
)
909
- relative_position_if_large = jnp .clip (relative_position_if_large , max = num_buckets - 1 )
909
+ relative_position_if_large = jnp .clip (relative_position_if_large , a_max = num_buckets - 1 )
910
910
911
911
relative_buckets += jnp .where (is_small , relative_position , relative_position_if_large )
912
912
0 commit comments