Skip to content

Commit 4002d8b

Browse files
committed
Fix take_along_axis import for jax backend
1 parent 8d20563 commit 4002d8b

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

tensorflow_probability/python/stats/sample_stats.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
import numpy as np
2222
import tensorflow.compat.v2 as tf
2323

24-
if JAX_MODE or NUMPY_MODE:
25-
numpy_ops = np
24+
if NUMPY_MODE:
25+
take_along_axis = np.take_along_axis
26+
elif JAX_MODE:
27+
from jax.numpy import take_along_axis
2628
else:
27-
from tensorflow.python.ops import numpy_ops
29+
from tensorflow.python.ops.numpy_ops import take_along_axis
2830

2931
from tensorflow_probability.python.internal import assert_util
3032
from tensorflow_probability.python.internal import distribution_util
@@ -800,10 +802,10 @@ def windowed_variance(
800802
def index_for_cumulative(indices):
801803
return tf.maximum(indices - 1, 0)
802804
cum_sums = tf.cumsum(x, axis=axis)
803-
sums = numpy_ops.take_along_axis(
805+
sums = take_along_axis(
804806
cum_sums, index_for_cumulative(indices), axis=axis)
805807
cum_variances = cumulative_variance(x, sample_axis=axis)
806-
variances = numpy_ops.take_along_axis(
808+
variances = take_along_axis(
807809
cum_variances, index_for_cumulative(indices), axis=axis)
808810

809811
# This formula is the binary accurate variance merge from [1],
@@ -904,8 +906,7 @@ def windowed_mean(
904906
paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
905907
(rank, 2))
906908
cum_sums = ps.pad(raw_cumsum, paddings)
907-
sums = numpy_ops.take_along_axis(cum_sums, indices,
908-
axis=axis)
909+
sums = take_along_axis(cum_sums, indices, axis=axis)
909910
counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
910911
return tf.math.divide_no_nan(sums[1] - sums[0], counts)
911912

0 commit comments

Comments
 (0)