@@ -1049,7 +1049,7 @@ def test_car_moment(self, mu, size, expected):
1049
1049
)
1050
1050
def test_mvstudentt_moment (self , nu , mu , cov , size , expected ):
1051
1051
with pm .Model () as model :
1052
- x = pm .MvStudentT ("x" , nu = nu , mu = mu , cov = cov , size = size )
1052
+ x = pm .MvStudentT ("x" , nu = nu , mu = mu , scale = cov , size = size )
1053
1053
1054
1054
# MvStudentT logp is only impemented for up to 2D variables
1055
1055
assert_moment_is_expected (model , expected , check_finite_logp = x .ndim < 3 )
@@ -1369,28 +1369,28 @@ def test_issue_3706(self):
1369
1369
1370
1370
1371
1371
class TestMvStudentTCov (BaseTestDistributionRandom ):
1372
- def mvstudentt_rng_fn (self , size , nu , mu , cov , rng ):
1373
- mv_samples = rng .multivariate_normal (np .zeros_like (mu ), cov , size = size )
1372
+ def mvstudentt_rng_fn (self , size , nu , mu , scale , rng ):
1373
+ mv_samples = rng .multivariate_normal (np .zeros_like (mu ), scale , size = size )
1374
1374
chi2_samples = rng .chisquare (nu , size = size )
1375
1375
return (mv_samples / np .sqrt (chi2_samples [:, None ] / nu )) + mu
1376
1376
1377
1377
pymc_dist = pm .MvStudentT
1378
1378
pymc_dist_params = {
1379
1379
"nu" : 5 ,
1380
1380
"mu" : np .array ([1.0 , 2.0 ]),
1381
- "cov " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1381
+ "scale " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1382
1382
}
1383
1383
expected_rv_op_params = {
1384
1384
"nu" : 5 ,
1385
1385
"mu" : np .array ([1.0 , 2.0 ]),
1386
- "cov " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1386
+ "scale " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1387
1387
}
1388
1388
sizes_to_check = [None , (1 ), (2 , 3 )]
1389
1389
sizes_expected = [(2 ,), (1 , 2 ), (2 , 3 , 2 )]
1390
1390
reference_dist_params = {
1391
1391
"nu" : 5 ,
1392
1392
"mu" : np .array ([1.0 , 2.0 ]),
1393
- "cov " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1393
+ "scale " : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
1394
1394
}
1395
1395
reference_dist = lambda self : ft .partial (self .mvstudentt_rng_fn , rng = self .get_random_state ())
1396
1396
checks_to_run = [
@@ -1409,29 +1409,29 @@ def check_errors(self):
1409
1409
"mvstudentt" ,
1410
1410
nu = np .array ([1 , 2 ]),
1411
1411
mu = np .ones (2 ),
1412
- cov = np .full ((2 , 2 ), np .ones (2 )),
1412
+ scale = np .full ((2 , 2 ), np .ones (2 )),
1413
1413
)
1414
1414
1415
1415
def check_mu_broadcast_helper (self ):
1416
1416
"""Test that mu is broadcasted to the shape of cov"""
1417
- x = pm .MvStudentT .dist (nu = 4 , mu = 1 , cov = np .eye (3 ))
1417
+ x = pm .MvStudentT .dist (nu = 4 , mu = 1 , scale = np .eye (3 ))
1418
1418
mu = x .owner .inputs [4 ]
1419
1419
assert mu .eval ().shape == (3 ,)
1420
1420
1421
- x = pm .MvStudentT .dist (nu = 4 , mu = np .ones (1 ), cov = np .eye (3 ))
1421
+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones (1 ), scale = np .eye (3 ))
1422
1422
mu = x .owner .inputs [4 ]
1423
1423
assert mu .eval ().shape == (3 ,)
1424
1424
1425
- x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((1 , 1 )), cov = np .eye (3 ))
1425
+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((1 , 1 )), scale = np .eye (3 ))
1426
1426
mu = x .owner .inputs [4 ]
1427
1427
assert mu .eval ().shape == (1 , 3 )
1428
1428
1429
- x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((10 , 1 )), cov = np .eye (3 ))
1429
+ x = pm .MvStudentT .dist (nu = 4 , mu = np .ones ((10 , 1 )), scale = np .eye (3 ))
1430
1430
mu = x .owner .inputs [4 ]
1431
1431
assert mu .eval ().shape == (10 , 3 )
1432
1432
1433
1433
# Cov is artificually limited to being 2D
1434
- # x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1)), cov =np.full((2, 3, 3), np.eye(3)))
1434
+ # x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1)), scale =np.full((2, 3, 3), np.eye(3)))
1435
1435
# mu = x.owner.inputs[4]
1436
1436
# assert mu.eval().shape == (10, 2, 3)
1437
1437
@@ -1446,7 +1446,7 @@ class TestMvStudentTChol(BaseTestDistributionRandom):
1446
1446
expected_rv_op_params = {
1447
1447
"nu" : 5 ,
1448
1448
"mu" : np .array ([1.0 , 2.0 ]),
1449
- "cov " : quaddist_matrix (chol = pymc_dist_params ["chol" ]).eval (),
1449
+ "scale " : quaddist_matrix (chol = pymc_dist_params ["chol" ]).eval (),
1450
1450
}
1451
1451
checks_to_run = ["check_pymc_params_match_rv_op" ]
1452
1452
0 commit comments