Skip to content

Commit 3a304d6

Browse files
authored
Remove intX and floatX from distributions (#7114)
1 parent 0d8ddba commit 3a304d6

12 files changed

+144
-146
lines changed

pymc/distributions/bound.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from pymc.distributions.transforms import _default_transform
2828
from pymc.logprob.basic import logp
2929
from pymc.model import modelcontext
30-
from pymc.pytensorf import floatX, intX
3130
from pymc.util import check_dist_not_registered
3231

3332
__all__ = ["Bound"]
@@ -206,7 +205,7 @@ def __new__(
206205
res = _ContinuousBounded(
207206
name,
208207
[dist, lower, upper],
209-
initval=floatX(initval),
208+
initval=initval.astype("float"),
210209
size=size,
211210
shape=shape,
212211
**kwargs,
@@ -215,7 +214,7 @@ def __new__(
215214
res = _DiscreteBounded(
216215
name,
217216
[dist, lower, upper],
218-
initval=intX(initval),
217+
initval=initval.astype("int"),
219218
size=size,
220219
shape=shape,
221220
**kwargs,
@@ -241,15 +240,15 @@ def dist(
241240
shape=shape,
242241
**kwargs,
243242
)
244-
res.tag.test_value = floatX(initval)
243+
res.tag.test_value = initval
245244
else:
246245
res = _DiscreteBounded.dist(
247246
[dist, lower, upper],
248247
size=size,
249248
shape=shape,
250249
**kwargs,
251250
)
252-
res.tag.test_value = intX(initval)
251+
res.tag.test_value = initval.astype("int")
253252
return res
254253

255254
@classmethod
@@ -286,9 +285,9 @@ def _set_values(cls, lower, upper, size, shape, initval):
286285
size = shape
287286

288287
lower = np.asarray(lower)
289-
lower = floatX(np.where(lower == None, -np.inf, lower)) # noqa E711
288+
lower = np.where(lower == None, -np.inf, lower) # noqa E711
290289
upper = np.asarray(upper)
291-
upper = floatX(np.where(upper == None, np.inf, upper)) # noqa E711
290+
upper = np.where(upper == None, np.inf, upper) # noqa E711
292291

293292
if initval is None:
294293
_size = np.broadcast_shapes(to_tuple(size), np.shape(lower), np.shape(upper))
@@ -303,7 +302,6 @@ def _set_values(cls, lower, upper, size, shape, initval):
303302
np.where(_upper == np.inf, _lower + 1, (_lower + _upper) / 2),
304303
),
305304
)
306-
307-
lower = as_tensor_variable(floatX(lower))
308-
upper = as_tensor_variable(floatX(upper))
305+
lower = as_tensor_variable(lower, dtype="floatX")
306+
upper = as_tensor_variable(upper, dtype="floatX")
309307
return lower, upper, initval

0 commit comments

Comments
 (0)