Skip to content

Commit e35c402

Browse files
committed
Implement model transform that freezes RV dims and Data
1 parent bad219a commit e35c402

File tree

6 files changed

+160
-24
lines changed

6 files changed

+160
-24
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ jobs:
9696
tests/model/test_fgraph.py
9797
tests/model/transform/test_basic.py
9898
tests/model/transform/test_conditioning.py
99+
tests/model/transform/test_optimization.py
99100
tests/test_model_graph.py
100101
tests/ode/test_ode.py
101102
tests/ode/test_utils.py

docs/source/api/model/conditioning.rst renamed to docs/source/api/model/transform.rst

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@ Model Conditioning
55
.. autosummary::
66
:toctree: generated/
77

8-
change_value_transforms
98
do
109
observe
10+
change_value_transforms
1111
remove_value_transforms
12+
13+
14+
Model Optimization
15+
------------------
16+
.. currentmodule:: pymc.model.transform.optimization
17+
.. autosummary::
18+
:toctree: generated/
19+
20+
freeze_dims_and_data

pymc/model/core.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,8 +1121,7 @@ def set_data(
11211121

11221122
for d, dname in enumerate(dims):
11231123
length_tensor = self.dim_lengths[dname]
1124-
with pytensor.config.change_flags(cxx=""):
1125-
old_length = length_tensor.eval()
1124+
old_length = length_tensor.eval()
11261125
new_length = values.shape[d]
11271126
original_coords = self.coords.get(dname, None)
11281127
new_coords = coords.get(dname, None)
@@ -1404,24 +1403,22 @@ def create_value_var(
14041403
else:
14051404
transform = _default_transform(rv_var.owner.op, rv_var)
14061405

1407-
if value_var is not None:
1408-
if transform is not None:
1409-
raise ValueError("Cannot use transform when providing a pre-defined value_var")
1410-
elif transform is None:
1411-
# Create value variable with the same type as the RV
1412-
value_var = rv_var.type()
1413-
value_var.name = rv_var.name
1414-
if pytensor.config.compute_test_value != "off":
1415-
value_var.tag.test_value = rv_var.tag.test_value
1416-
else:
1417-
# Create value variable with the same type as the transformed RV
1418-
value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
1419-
value_var.name = f"{rv_var.name}_{transform.name}__"
1420-
value_var.tag.transform = transform
1421-
if pytensor.config.compute_test_value != "off":
1422-
value_var.tag.test_value = transform.forward(
1423-
rv_var, *rv_var.owner.inputs
1424-
).tag.test_value
1406+
if value_var is None:
1407+
if transform is None:
1408+
# Create value variable with the same type as the RV
1409+
value_var = rv_var.type()
1410+
value_var.name = rv_var.name
1411+
if pytensor.config.compute_test_value != "off":
1412+
value_var.tag.test_value = rv_var.tag.test_value
1413+
else:
1414+
# Create value variable with the same type as the transformed RV
1415+
value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
1416+
value_var.name = f"{rv_var.name}_{transform.name}__"
1417+
value_var.tag.transform = transform
1418+
if pytensor.config.compute_test_value != "off":
1419+
value_var.tag.test_value = transform.forward(
1420+
rv_var, *rv_var.owner.inputs
1421+
).tag.test_value
14251422

14261423
_add_future_warning_tag(value_var)
14271424
rv_var.tag.value_var = value_var

pymc/model/fgraph.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,9 +321,7 @@ def first_non_model_var(var):
321321
var, value, *dims = model_var.owner.inputs
322322
transform = model_var.owner.op.transform
323323
model.free_RVs.append(var)
324-
# PyMC does not allow setting transform when we pass a value_var. Why?
325-
model.create_value_var(var, transform=None, value_var=value)
326-
model.rvs_to_transforms[var] = transform
324+
model.create_value_var(var, transform=transform, value_var=value)
327325
model.set_initval(var, initval=None)
328326
elif isinstance(model_var.owner.op, ModelObservedRV):
329327
var, value, *dims = model_var.owner.inputs

pymc/model/transform/optimization.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from pytensor import clone_replace
15+
from pytensor.compile import SharedVariable
16+
from pytensor.graph import FunctionGraph
17+
from pytensor.tensor import constant
18+
19+
from pymc import Model
20+
from pymc.model.fgraph import ModelFreeRV, fgraph_from_model, model_from_fgraph
21+
22+
23+
def freeze_dims_and_data(model: Model) -> Model:
    """Recreate a Model with fixed RV dimensions and Data values.

    The dimensions of the pre-existing RVs will no longer follow changes to the coordinates.
    Likewise, it will not be possible to update pre-existing Data in the new model.

    Note that any new RVs and Data created after calling this function will still be "unfrozen".

    This transformation may allow more performant sampling, or compiling model functions to backends that
    are more restrictive about dynamic shapes such as JAX.

    Parameters
    ----------
    model : Model
        The model whose dims and data should be frozen.

    Returns
    -------
    Model
        A new model in which shared dim-lengths and Data variables have been
        replaced by constants, with static shapes propagated to the RVs and
        their value variables.
    """
    fg, memo = fgraph_from_model(model)

    # Replace mutable dim lengths and data by constants.
    # `memo` maps variables of the original model to their clones in `fg`.
    frozen_vars = {
        memo[dim_length]: constant(
            dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype
        )
        for dim_length in model.dim_lengths.values()
        if isinstance(dim_length, SharedVariable)
    }
    frozen_vars |= {
        memo[data_var].owner.inputs[0]: constant(
            data_var.get_value(), name=data_var.name, dtype=data_var.type.dtype
        )
        for data_var in model.named_vars.values()
        if isinstance(data_var, SharedVariable)
    }

    old_outs, coords = fg.outputs, fg._coords  # type: ignore
    # rebuild_strict=False forces the recreation of RV nodes with updated static types
    new_outs = clone_replace(old_outs, replace=frozen_vars, rebuild_strict=False)  # type: ignore
    for old_out, new_out in zip(old_outs, new_outs):
        new_out.name = old_out.name
    fg = FunctionGraph(outputs=new_outs, clone=False)
    fg._coords = coords  # type: ignore

    # Recreate value variables from new RVs to propagate static types to logp graphs
    replacements = {}
    for node in fg.apply_nodes:
        if not isinstance(node.op, ModelFreeRV):
            continue
        rv, old_value, *dims = node.inputs
        # NOTE: a previous `if dims is None: continue` guard was unreachable —
        # a starred unpacking target always binds a (possibly empty) list, never
        # None — so every free RV is processed here, with or without dims.
        transform = node.op.transform
        if transform is None:
            new_value = rv.type()
        else:
            # Transformed value variables live in the transformed space, so the
            # type comes from the forward transform of the (retyped) RV.
            new_value = transform.forward(rv, *rv.owner.inputs).type()  # type: ignore
        new_value.name = old_value.name
        replacements[old_value] = new_value
    fg.replace_all(tuple(replacements.items()), import_missing=True)

    return model_from_fgraph(fg)


__all__ = ("freeze_dims_and_data",)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from pytensor.graph import Constant
15+
16+
from pymc.data import Data
17+
from pymc.distributions import HalfNormal, Normal
18+
from pymc.model import Model
19+
from pymc.model.transform.optimization import freeze_dims_and_data
20+
21+
22+
def test_freeze_existing_rv_dims_and_data():
    """Freezing should turn Data into constants and make RV shapes static."""
    with Model(coords={"test_dim": range(5)}) as model:
        std = Data("std", [1])
        x = HalfNormal("x", std, dims=("test_dim",))
        y = Normal("y", shape=x.shape[0] + 1)

    x_logp, y_logp = model.logp(sum=False)

    # Before freezing everything is dynamic: shared data and unknown shapes.
    assert not isinstance(std, Constant)
    for var in (x, y, x_logp, y_logp):
        assert var.type.shape == (None,)

    frozen_model = freeze_dims_and_data(model)
    std, x, y = frozen_model["std"], frozen_model["x"], frozen_model["y"]
    x_logp, y_logp = frozen_model.logp(sum=False)

    # After freezing, data is constant and shapes are statically known.
    assert isinstance(std, Constant)
    assert x.type.shape == (5,)
    assert x_logp.type.shape == (5,)
    assert y.type.shape == (6,)
    assert y_logp.type.shape == (6,)
44+
45+
46+
def test_freeze_rv_dims_nothing_to_change():
    """Freezing a model with no shared dims/data leaves its logp unchanged."""
    with Model(coords={"test_dim": range(5)}) as model:
        x = HalfNormal("x", shape=(5,))
        y = Normal("y", shape=x.shape[0] + 1)

    frozen_model = freeze_dims_and_data(model)
    assert model.point_logps() == frozen_model.point_logps()

0 commit comments

Comments
 (0)