|
33 | 33 | from pytensor.tensor.nlinalg import det, eigh, matrix_inverse, trace
|
34 | 34 | from pytensor.tensor.random.basic import dirichlet, multinomial, multivariate_normal
|
35 | 35 | from pytensor.tensor.random.op import RandomVariable, default_supp_shape_from_params
|
36 |
| -from pytensor.tensor.random.utils import broadcast_params |
| 36 | +from pytensor.tensor.random.utils import ( |
| 37 | + broadcast_params, |
| 38 | + supp_shape_from_ref_param_shape, |
| 39 | +) |
37 | 40 | from pytensor.tensor.slinalg import Cholesky, SolveTriangular
|
38 | 41 | from pytensor.tensor.type import TensorType
|
39 | 42 | from scipy import linalg, stats
|
@@ -2229,6 +2232,14 @@ class ICARRV(RandomVariable):
|
2229 | 2232 | def __call__(self, W, node1, node2, N, sigma, zero_sum_stdev, size=None, **kwargs):
|
2230 | 2233 | return super().__call__(W, node1, node2, N, sigma, zero_sum_stdev, size=size, **kwargs)
|
2231 | 2234 |
|
| 2235 | + def _supp_shape_from_params(self, dist_params, param_shapes=None): |
| 2236 | + return supp_shape_from_ref_param_shape( |
| 2237 | + ndim_supp=self.ndim_supp, |
| 2238 | + dist_params=dist_params, |
| 2239 | + param_shapes=param_shapes, |
| 2240 | + ref_param_idx=0, |
| 2241 | + ) |
| 2242 | + |
2232 | 2243 | @classmethod
|
2233 | 2244 | def rng_fn(cls, rng, size, W, node1, node2, N, sigma, zero_sum_stdev):
|
2234 | 2245 | raise NotImplementedError("Cannot sample from ICAR prior")
|
|
0 commit comments