|
21 | 21 | import numpy as np
|
22 | 22 | import tensorflow.compat.v2 as tf
|
23 | 23 |
|
24 |
| -if JAX_MODE or NUMPY_MODE: |
25 |
| - numpy_ops = np |
| 24 | +if NUMPY_MODE: |
| 25 | + take_along_axis = np.take_along_axis |
| 26 | +elif JAX_MODE: |
| 27 | + from jax.numpy import take_along_axis |
26 | 28 | else:
|
27 |
| - from tensorflow.python.ops import numpy_ops |
| 29 | + from tensorflow.python.ops.numpy_ops import take_along_axis |
28 | 30 |
|
29 | 31 | from tensorflow_probability.python.internal import assert_util
|
30 | 32 | from tensorflow_probability.python.internal import distribution_util
|
@@ -800,10 +802,10 @@ def windowed_variance(
|
800 | 802 | def index_for_cumulative(indices):
|
801 | 803 | return tf.maximum(indices - 1, 0)
|
802 | 804 | cum_sums = tf.cumsum(x, axis=axis)
|
803 |
| - sums = numpy_ops.take_along_axis( |
| 805 | + sums = take_along_axis( |
804 | 806 | cum_sums, index_for_cumulative(indices), axis=axis)
|
805 | 807 | cum_variances = cumulative_variance(x, sample_axis=axis)
|
806 |
| - variances = numpy_ops.take_along_axis( |
| 808 | + variances = take_along_axis( |
807 | 809 | cum_variances, index_for_cumulative(indices), axis=axis)
|
808 | 810 |
|
809 | 811 | # This formula is the binary accurate variance merge from [1],
|
@@ -904,8 +906,7 @@ def windowed_mean(
|
904 | 906 | paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
|
905 | 907 | (rank, 2))
|
906 | 908 | cum_sums = ps.pad(raw_cumsum, paddings)
|
907 |
| - sums = numpy_ops.take_along_axis(cum_sums, indices, |
908 |
| - axis=axis) |
| 909 | + sums = take_along_axis(cum_sums, indices, axis=axis) |
909 | 910 | counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
|
910 | 911 | return tf.math.divide_no_nan(sums[1] - sums[0], counts)
|
911 | 912 |
|
|
0 commit comments