Skip to content

Commit 6d12203

Browse files
committed
Allow string in MarginalModel.unmarginalize
1 parent f4ae645 commit 6d12203

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pymc_experimental/model/marginal_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,10 @@ def _to_transformed(self):
292292
fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
293293
return fn, transformed_names
294294

295-
def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable]):
295+
def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable | str]):
296296
for rv in rvs_to_unmarginalize:
297+
if isinstance(rv, str):
298+
rv = self[rv]
297299
self.marginalized_rvs.remove(rv)
298300
if rv.name in self._marginalized_named_vars_to_dims:
299301
dims = self._marginalized_named_vars_to_dims.pop(rv.name)

0 commit comments

Comments
 (0)