Skip to content

Merge redundant code across logprob, pytensorf and distributions/transform #6976

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 38 additions & 71 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@
import pymc as pm

from pymc.logprob.transforms import (
ChainedTransform,
CircularTransform,
IntervalTransform,
LogOddsTransform,
LogTransform,
RVTransform,
SimplexTransform,
Transform,
)

__all__ = [
"RVTransform",
"Transform",
"simplex",
"logodds",
"Interval",
Expand All @@ -60,6 +61,10 @@ def __getattr__(name):
warnings.warn(f"{name} has been deprecated, use sum_to_1 instead.", FutureWarning)
return sum_to_1

if name == "RVTransform":
warnings.warn("RVTransform has been renamed to Transform", FutureWarning)
return Transform

raise AttributeError(f"module {__name__} has no attribute {name}")


Expand All @@ -69,7 +74,7 @@ def _default_transform(op: Op, rv: TensorVariable):
return None


class LogExpM1(RVTransform):
class LogExpM1(Transform):
name = "log_exp_m1"

def backward(self, value, *inputs):
Expand All @@ -87,7 +92,7 @@ def log_jac_det(self, value, *inputs):
return -pt.softplus(-value)


class Ordered(RVTransform):
class Ordered(Transform):
name = "ordered"

def __init__(self, ndim_supp=None):
Expand All @@ -110,7 +115,7 @@ def log_jac_det(self, value, *inputs):
return pt.sum(value[..., 1:], axis=-1)


class SumTo1(RVTransform):
class SumTo1(Transform):
"""
Transforms K - 1 dimensional simplex space (k values in [0,1] and that sum to 1) to a K - 1 vector of values in [0,1]
This Transformation operates on the last dimension of the input tensor.
Expand All @@ -134,7 +139,7 @@ def log_jac_det(self, value, *inputs):
return pt.sum(y, axis=-1)


class CholeskyCovPacked(RVTransform):
class CholeskyCovPacked(Transform):
"""
Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the
log scale
Expand Down Expand Up @@ -162,45 +167,7 @@ def log_jac_det(self, value, *inputs):
return pt.sum(value[..., self.diag_idxs], axis=-1)


class Chain(RVTransform):
    """Compose several transforms into a single transform.

    ``forward`` applies the transforms left to right; ``backward`` undoes
    them right to left; ``log_jac_det`` accumulates each step's
    log-Jacobian determinant along the backward pass.
    """

    __slots__ = ("param_extract_fn", "transform_list", "name")

    def __init__(self, transform_list):
        self.transform_list = transform_list
        self.name = "+".join([t.name for t in transform_list])

    def forward(self, value, *inputs):
        # TODO: Needs proper discussion as to what should be passed as
        # inputs here
        out = value
        for t in self.transform_list:
            out = t.forward(out, *inputs)
        return out

    def backward(self, value, *inputs):
        out = value
        for t in reversed(self.transform_list):
            out = t.backward(out, *inputs)
        return out

    def log_jac_det(self, value, *inputs):
        out = pt.as_tensor_variable(value)
        dets = []
        smallest_ndim = out.ndim
        for t in reversed(self.transform_list):
            det = t.log_jac_det(out, *inputs)
            dets.append(det)
            out = t.backward(out, *inputs)
            smallest_ndim = min(smallest_ndim, det.ndim)
        # Reduce every per-transform determinant to the smallest ndim seen
        # before accumulating, so the shapes line up.
        total = 0.0
        for det in dets:
            total = total + (det.sum(axis=-1) if det.ndim > smallest_ndim else det)
        return total

Chain = ChainedTransform

simplex = SimplexTransform()
simplex.__doc__ = """
Expand Down Expand Up @@ -297,7 +264,7 @@ def bounds_fn(*rv_inputs):
super().__init__(args_fn=bounds_fn)


class ZeroSumTransform(RVTransform):
class ZeroSumTransform(Transform):
"""
Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``.

Expand All @@ -314,43 +281,43 @@ class ZeroSumTransform(RVTransform):
def __init__(self, zerosum_axes):
self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes)

@staticmethod
def extend_axis(array, axis):
    # Append the one implied entry along `axis` — presumably so the result
    # sums to zero along that axis (the class constrains samples to sum to
    # zero); TODO confirm against the math of `extend_axis_rev`.
    # NOTE(review): mirrors the module-level `extend_axis` helper — keep in sync.
    n = pm.floatX(array.shape[axis] + 1)
    sum_vals = array.sum(axis, keepdims=True)
    norm = sum_vals / (pt.sqrt(n) + n)
    fill_val = norm - sum_vals / pt.sqrt(n)

    out = pt.concatenate([array, fill_val], axis=axis)
    return out - norm

@staticmethod
def extend_axis_rev(array, axis):
    # Inverse direction of `extend_axis`: drop the last entry along `axis`
    # and fold its contribution back into the remaining entries.
    # NOTE(review): mirrors the module-level `extend_axis_rev` helper — keep in sync.
    normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]

    n = pm.floatX(array.shape[normalized_axis])
    last = pt.take(array, [-1], axis=normalized_axis)

    sum_vals = -last * pt.sqrt(n)
    norm = sum_vals / (pt.sqrt(n) + n)
    # Build an index that slices everything before `normalized_axis` whole,
    # then drops the final entry on the target axis.
    slice_before = (slice(None, None),) * normalized_axis

    return array[slice_before + (slice(None, -1),)] + norm

def forward(self, value, *rv_inputs):
for axis in self.zerosum_axes:
value = extend_axis_rev(value, axis=axis)
value = self.extend_axis_rev(value, axis=axis)
return value

def backward(self, value, *rv_inputs):
for axis in self.zerosum_axes:
value = extend_axis(value, axis=axis)
value = self.extend_axis(value, axis=axis)
return value

def log_jac_det(self, value, *rv_inputs):
    # Presumably the zero-sum map is volume preserving, so the
    # log-Jacobian determinant is identically zero — TODO confirm
    # against the construction in `extend_axis`.
    return pt.constant(0.0)


def extend_axis(array, axis):
    """Append the one implied entry of *array* along *axis*.

    Computes a fill value from the existing sum so the concatenated,
    re-centered result has one more entry on *axis* than the input.
    """
    size = pm.floatX(array.shape[axis] + 1)
    root = pt.sqrt(size)
    total = array.sum(axis, keepdims=True)
    shift = total / (root + size)
    fill = shift - total / root
    return pt.concatenate([array, fill], axis=axis) - shift


def extend_axis_rev(array, axis):
    """Drop the last entry of *array* along *axis*, folding its
    contribution back into the remaining entries.

    Inverse direction of :func:`extend_axis`.
    """
    ax = normalize_axis_tuple(axis, array.ndim)[0]

    size = pm.floatX(array.shape[ax])
    root = pt.sqrt(size)
    tail = pt.take(array, [-1], axis=ax)
    correction = (-tail * root) / (root + size)

    # Slice every leading axis whole, then drop the final entry on `ax`.
    leading = (slice(None, None),) * ax
    return array[leading + (slice(None, -1),)] + correction


log_exp_m1 = LogExpM1()
log_exp_m1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
Expand Down
5 changes: 2 additions & 3 deletions pymc/gp/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
"Kron",
]

from pymc.pytensorf import constant_fold

TensorLike = Union[np.ndarray, TensorVariable]
IntSequence = Union[np.ndarray, Sequence[int]]

Expand Down Expand Up @@ -183,9 +185,6 @@ def n_dims(self) -> int:
def _slice(self, X, Xs=None):
xdims = X.shape[-1]
if isinstance(xdims, Variable):
# Circular dependency
from pymc.pytensorf import constant_fold

[xdims] = constant_fold([xdims])
if self.input_dim != xdims:
warnings.warn(
Expand Down
5 changes: 3 additions & 2 deletions pymc/gp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import pytensor.tensor as pt

from pytensor.compile import SharedVariable
from pytensor.graph import ancestors
from pytensor.tensor.variable import TensorConstant
from scipy.cluster.vq import kmeans

# Avoid circular dependency when importing modelcontext
from pymc.distributions.distribution import Distribution
from pymc.model import modelcontext
from pymc.pytensorf import compile_pymc, walk_model
from pymc.pytensorf import compile_pymc

_ = Distribution # keep both pylint and black happy

Expand All @@ -48,7 +49,7 @@ def replace_with_values(vars_needed, replacements=None, model=None):
model = modelcontext(model)

inputs, input_names = [], []
for rv in walk_model(vars_needed):
for rv in ancestors(vars_needed):
if rv in model.named_vars.values() and not isinstance(rv, SharedVariable):
inputs.append(rv)
input_names.append(rv.name)
Expand Down
4 changes: 2 additions & 2 deletions pymc/initial_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.transforms import RVTransform
from pymc.logprob.transforms import Transform
from pymc.pytensorf import compile_pymc, find_rng_nodes, replace_rng_nodes, reseed_rngs
from pymc.util import get_transformed_name, get_untransformed_name, is_transformed_name

Expand Down Expand Up @@ -177,7 +177,7 @@ def inner(seed, *args, **kwargs):
def make_initial_point_expression(
*,
free_rvs: Sequence[TensorVariable],
rvs_to_transforms: Dict[TensorVariable, RVTransform],
rvs_to_transforms: Dict[TensorVariable, Transform],
initval_strategies: Dict[TensorVariable, Optional[Union[np.ndarray, Variable, str]]],
jitter_rvs: Set[TensorVariable] = None,
default_strategy: str = "moment",
Expand Down
20 changes: 9 additions & 11 deletions pymc/logprob/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@
)
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
from pymc.logprob.transform_value import TransformValuesRewrite
from pymc.logprob.transforms import RVTransform
from pymc.logprob.utils import find_rvs_in_graph, rvs_to_value_vars
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import rvs_in_graph
from pymc.pytensorf import replace_vars_in_graphs

TensorLike: TypeAlias = Union[Variable, float, np.ndarray]

Expand All @@ -76,7 +77,7 @@ def _find_unallowed_rvs_in_graph(graph):

return {
rv
for rv in find_rvs_in_graph(graph)
for rv in rvs_in_graph(graph)
if not isinstance(rv.owner.op, (SimulatorRV, MinibatchIndexRV))
}

Expand Down Expand Up @@ -530,11 +531,9 @@ def conditional_logp(
continue

# Replace `RandomVariable`s in the inputs with value variables.
# Also, store the results in the `replacements` map for the nodes
# that follow.
remapped_vars, _ = rvs_to_value_vars(
q_values + list(node.inputs),
initial_replacements=replacements,
remapped_vars = replace_vars_in_graphs(
graphs=q_values + list(node.inputs),
replacements=replacements,
)
q_values = remapped_vars[: len(q_values)]
q_rv_inputs = remapped_vars[len(q_values) :]
Expand Down Expand Up @@ -562,8 +561,7 @@ def conditional_logp(

logprob_vars[q_value_var] = q_logprob_var

# Recompute test values for the changes introduced by the
# replacements above.
# Recompute test values for the changes introduced by the replacements above.
if config.compute_test_value != "off":
for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars):
compute_test_value(node)
Expand All @@ -589,7 +587,7 @@ def transformed_conditional_logp(
rvs: Sequence[TensorVariable],
*,
rvs_to_values: Dict[TensorVariable, TensorVariable],
rvs_to_transforms: Dict[TensorVariable, RVTransform],
rvs_to_transforms: Dict[TensorVariable, Transform],
jacobian: bool = True,
**kwargs,
) -> List[TensorVariable]:
Expand Down
3 changes: 1 addition & 2 deletions pymc/logprob/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import replace_rvs_by_values


class MeasurableSpecifyShape(SpecifyShape):
Expand Down Expand Up @@ -107,8 +108,6 @@ class MeasurableCheckAndRaise(CheckAndRaise):

@_logprob.register(MeasurableCheckAndRaise)
def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):
from pymc.pytensorf import replace_rvs_by_values

(value,) = values
# transfer assertion from rv to value
assertions = replace_rvs_by_values(assertions, rvs_to_values={inner_rv: value})
Expand Down
10 changes: 2 additions & 8 deletions pymc/logprob/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@
measurable_ir_rewrites_db,
subtensor_ops,
)
from pymc.logprob.utils import check_potential_measurability
from pymc.logprob.utils import check_potential_measurability, replace_rvs_by_values
from pymc.pytensorf import constant_fold


def is_newaxis(x):
Expand Down Expand Up @@ -255,9 +256,6 @@ def get_stack_mixture_vars(
mixture_rvs = joined_rvs.owner.inputs

elif isinstance(joined_rvs.owner.op, Join):
# TODO: Find better solution to avoid this circular dependency
from pymc.pytensorf import constant_fold

join_axis = joined_rvs.owner.inputs[0]
# TODO: Support symbolic join axes. This will raise ValueError if it's not a constant
(join_axis,) = constant_fold((join_axis,), raise_not_constant=False)
Expand Down Expand Up @@ -351,9 +349,6 @@ def logprob_MixtureRV(
comp_rvs = [comp[None] for comp in comp_rvs]
original_shape = (len(comp_rvs),)
else:
# TODO: Find better solution to avoid this circular dependency
from pymc.pytensorf import constant_fold

join_axis_val = constant_fold((join_axis,))[0].item()
original_shape = shape_tuple(comp_rvs[0])

Expand Down Expand Up @@ -544,7 +539,6 @@ def find_measurable_ifelse_mixture(fgraph, node):
@_logprob.register(MeasurableIfElse)
def logprob_ifelse(op, values, if_var, *base_rvs, **kwargs):
"""Compute the log-likelihood graph for an `IfElse`."""
from pymc.pytensorf import replace_rvs_by_values

assert len(values) * 2 == len(base_rvs)

Expand Down
2 changes: 1 addition & 1 deletion pymc/logprob/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
logprob_rewrites_db,
measurable_ir_rewrites_db,
)
from pymc.pytensorf import replace_rvs_by_values
from pymc.logprob.utils import replace_rvs_by_values


class MeasurableScan(Scan):
Expand Down
Loading