|
7 | 7 | import pytensor
|
8 | 8 | from pytensor.tensor.basic import arange, as_tensor_variable, constant
|
9 | 9 | from pytensor.tensor.random.op import RandomVariable
|
10 |
| -from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType |
11 | 10 | from pytensor.tensor.random.utils import (
|
12 | 11 | broadcast_params,
|
13 | 12 | )
|
14 |
| -from pytensor.tensor.random.var import ( |
15 |
| - RandomGeneratorSharedVariable, |
16 |
| - RandomStateSharedVariable, |
17 |
| -) |
18 | 13 |
|
19 | 14 |
|
20 | 15 | try:
|
@@ -605,7 +600,7 @@ def __call__(
|
605 | 600 | @classmethod
|
606 | 601 | def rng_fn_scipy(
|
607 | 602 | cls,
|
608 |
| - rng: np.random.Generator | np.random.RandomState, |
| 603 | + rng: np.random.Generator, |
609 | 604 | loc: np.ndarray | float,
|
610 | 605 | scale: np.ndarray | float,
|
611 | 606 | size: list[int] | int | None,
|
@@ -1548,7 +1543,7 @@ def __call__(self, n, p, size=None, **kwargs):
|
1548 | 1543 | binomial = BinomialRV()
|
1549 | 1544 |
|
1550 | 1545 |
|
1551 |
| -class NegBinomialRV(ScipyRandomVariable): |
| 1546 | +class NegBinomialRV(RandomVariable): |
1552 | 1547 | r"""A negative binomial discrete random variable.
|
1553 | 1548 |
|
1554 | 1549 | The probability mass function for `nbinom` for the number :math:`k` of draws
|
@@ -1588,13 +1583,8 @@ def __call__(self, n, p, size=None, **kwargs):
|
1588 | 1583 | """
|
1589 | 1584 | return super().__call__(n, p, size=size, **kwargs)
|
1590 | 1585 |
|
1591 |
| - @classmethod |
1592 |
| - def rng_fn_scipy(cls, rng, n, p, size): |
1593 |
| - return stats.nbinom.rvs(n, p, size=size, random_state=rng) |
1594 |
| - |
1595 | 1586 |
|
1596 |
| -nbinom = NegBinomialRV() |
1597 |
| -negative_binomial = NegBinomialRV() |
| 1587 | +negative_binomial = nbinom = NegBinomialRV() |
1598 | 1588 |
|
1599 | 1589 |
|
1600 | 1590 | class BetaBinomialRV(ScipyRandomVariable):
|
@@ -1842,58 +1832,6 @@ def rng_fn(cls, rng, p, size):
|
1842 | 1832 | categorical = CategoricalRV()
|
1843 | 1833 |
|
1844 | 1834 |
|
1845 |
| -class RandIntRV(RandomVariable): |
1846 |
| - r"""A discrete uniform random variable. |
1847 |
| -
|
1848 |
| - Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s. |
1849 |
| -
|
1850 |
| - """ |
1851 |
| - |
1852 |
| - name = "randint" |
1853 |
| - signature = "(),()->()" |
1854 |
| - dtype = "int64" |
1855 |
| - _print_name = ("randint", "\\operatorname{randint}") |
1856 |
| - |
1857 |
| - def __call__(self, low, high=None, size=None, **kwargs): |
1858 |
| - r"""Draw samples from a discrete uniform distribution. |
1859 |
| -
|
1860 |
| - Signature |
1861 |
| - --------- |
1862 |
| -
|
1863 |
| - `() -> ()` |
1864 |
| -
|
1865 |
| - Parameters |
1866 |
| - ---------- |
1867 |
| - low |
1868 |
| - Lower boundary of the output interval. All values generated will |
1869 |
| - be greater than or equal to `low`, unless `high=None`, in which case |
1870 |
| - all values generated are greater than or equal to `0` and |
1871 |
| - smaller than `low` (exclusive). |
1872 |
| - high |
1873 |
| - Upper boundary of the output interval. All values generated |
1874 |
| - will be smaller than `high` (exclusive). |
1875 |
| - size |
1876 |
| - Sample shape. If the given size is `(m, n, k)`, then `m * n * k` |
1877 |
| - independent, identically distributed samples are |
1878 |
| - returned. Default is `None`, in which case a single |
1879 |
| - sample is returned. |
1880 |
| -
|
1881 |
| - """ |
1882 |
| - if high is None: |
1883 |
| - low, high = 0, low |
1884 |
| - return super().__call__(low, high, size=size, **kwargs) |
1885 |
| - |
1886 |
| - def make_node(self, rng, *args, **kwargs): |
1887 |
| - if not isinstance( |
1888 |
| - getattr(rng, "type", None), RandomStateType | RandomStateSharedVariable |
1889 |
| - ): |
1890 |
| - raise TypeError("`randint` is only available for `RandomStateType`s") |
1891 |
| - return super().make_node(rng, *args, **kwargs) |
1892 |
| - |
1893 |
| - |
1894 |
| -randint = RandIntRV() |
1895 |
| - |
1896 |
| - |
1897 | 1835 | class IntegersRV(RandomVariable):
|
1898 | 1836 | r"""A discrete uniform random variable.
|
1899 | 1837 |
|
@@ -1933,14 +1871,6 @@ def __call__(self, low, high=None, size=None, **kwargs):
|
1933 | 1871 | low, high = 0, low
|
1934 | 1872 | return super().__call__(low, high, size=size, **kwargs)
|
1935 | 1873 |
|
1936 |
| - def make_node(self, rng, *args, **kwargs): |
1937 |
| - if not isinstance( |
1938 |
| - getattr(rng, "type", None), |
1939 |
| - RandomGeneratorType | RandomGeneratorSharedVariable, |
1940 |
| - ): |
1941 |
| - raise TypeError("`integers` is only available for `RandomGeneratorType`s") |
1942 |
| - return super().make_node(rng, *args, **kwargs) |
1943 |
| - |
1944 | 1874 |
|
1945 | 1875 | integers = IntegersRV()
|
1946 | 1876 |
|
@@ -1974,7 +1904,28 @@ def rng_fn(self, *params):
|
1974 | 1904 | p = None
|
1975 | 1905 | else:
|
1976 | 1906 | rng, a, p, replace, size = params
|
1977 |
| - return rng.choice(a, size, replace, p) |
| 1907 | + |
| 1908 | + batch_ndim = a.ndim - self.ndims_params[0] |
| 1909 | + |
| 1910 | + if size is not None: |
| 1911 | + a = np.broadcast_to(a, size + a.shape[-self.ndims_params[0] :]) |
| 1912 | + if p is not None: |
| 1913 | + p = np.broadcast_to(p, size + p.shape[-1:]) |
| 1914 | + elif p is not None: |
| 1915 | + a, p = broadcast_params([a, p], self.ndims_params) |
| 1916 | + |
| 1917 | + if batch_ndim: |
| 1918 | + # rng.choice does not have a concept of batch dimensionn |
| 1919 | + batch_shape = a.shape[:batch_ndim] |
| 1920 | + core_shape = a.shape[batch_ndim:-1] |
| 1921 | + out = np.empty(batch_shape + core_shape, dtype=a.dtype) |
| 1922 | + for idx in np.ndindex(batch_shape): |
| 1923 | + out[idx] = rng.choice( |
| 1924 | + a[idx], size=None, replace=replace, p=None if p is None else p[idx] |
| 1925 | + ) |
| 1926 | + return out |
| 1927 | + else: |
| 1928 | + return rng.choice(a, size=size, replace=replace, p=p) |
1978 | 1929 |
|
1979 | 1930 |
|
1980 | 1931 | def choice(a, size=None, replace=True, p=None, rng=None):
|
@@ -2079,7 +2030,6 @@ def permutation(x, **kwargs):
|
2079 | 2030 | "permutation",
|
2080 | 2031 | "choice",
|
2081 | 2032 | "integers",
|
2082 |
| - "randint", |
2083 | 2033 | "categorical",
|
2084 | 2034 | "multinomial",
|
2085 | 2035 | "betabinom",
|
|
0 commit comments