We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f4ae645 commit 6d12203Copy full SHA for 6d12203
pymc_experimental/model/marginal_model.py
@@ -292,8 +292,10 @@ def _to_transformed(self):
292
fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
293
return fn, transformed_names
294
295
- def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable]):
+ def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable | str]):
296
for rv in rvs_to_unmarginalize:
297
+ if isinstance(rv, str):
298
+ rv = self[rv]
299
self.marginalized_rvs.remove(rv)
300
if rv.name in self._marginalized_named_vars_to_dims:
301
dims = self._marginalized_named_vars_to_dims.pop(rv.name)
0 commit comments