Skip to content

BUG: Dirichlet is not tolerant to floatX=float32 #6779

Closed
@ferrine

Description

@ferrine

Describe the issue:

Dirichlet distribution ignores floatX config and creates float64 variables in the graph

Reproducible code example:

import pymc as pm
import pytensor.tensor as pt
import pytensor

def test_dirichlet():
    with pm.Model() as model:
        c = pm.floatX([1, 1, 1])
        print(c, c.dtype)
        d = pm.Dirichlet("a", c)
    print(model.point_logps())
    
with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet()

Error message:

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
Cell In[10], line 2
      1 with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
----> 2     test_dirichlet()

Cell In[8], line 5, in test_dirichlet()
      3     c = pm.floatX([1, 1, 1])
      4     print(c, c.dtype)
----> 5     d = pm.Dirichlet("a", c)
      6 print(model.point_logps())

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/distributions/distribution.py:314, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    310         kwargs["shape"] = tuple(observed.shape)
    312 rv_out = cls.dist(*args, **kwargs)
--> 314 rv_out = model.register_rv(
    315     rv_out,
    316     name,
    317     observed,
    318     total_size,
    319     dims=dims,
    320     transform=transform,
    321     initval=initval,
    322 )
    324 # add in pretty-printing support
    325 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/model.py:1333, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
   1331     raise ValueError("total_size can only be passed to observed RVs")
   1332 self.free_RVs.append(rv_var)
-> 1333 self.create_value_var(rv_var, transform)
   1334 self.add_named_variable(rv_var, dims)
   1335 self.set_initval(rv_var, initval)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/model.py:1526, in Model.create_value_var(self, rv_var, transform, value_var)
   1523         value_var.tag.test_value = rv_var.tag.test_value
   1524 else:
   1525     # Create value variable with the same type as the transformed RV
-> 1526     value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
   1527     value_var.name = f"{rv_var.name}_{transform.name}__"
   1528     value_var.tag.transform = transform

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/logprob/transforms.py:985, in SimplexTransform.forward(self, value, *inputs)
    983 def forward(self, value, *inputs):
    984     log_value = pt.log(value)
--> 985     shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1]
    986     return log_value[..., :-1] - shift

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/var.py:173, in _tensor_py_operators.__truediv__(self, other)
    172 def __truediv__(self, other):
--> 173     return at.math.true_div(self, other)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/op.py:295, in Op.__call__(self, *inputs, **kwargs)
    253 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    254 
    255 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    292 
    293 """
    294 return_list = kwargs.pop("return_list", False)
--> 295 node = self.make_node(*inputs, **kwargs)
    297 if config.compute_test_value != "off":
    298     compute_test_value(node)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/elemwise.py:486, in Elemwise.make_node(self, *inputs)
    484 inputs = [as_tensor_variable(i) for i in inputs]
    485 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
--> 486 outputs = [
    487     TensorType(dtype=dtype, shape=shape)()
    488     for dtype, shape in zip(out_dtypes, out_shapes)
    489 ]
    490 return Apply(self, inputs, outputs)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/elemwise.py:487, in <listcomp>(.0)
    484 inputs = [as_tensor_variable(i) for i in inputs]
    485 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
    486 outputs = [
--> 487     TensorType(dtype=dtype, shape=shape)()
    488     for dtype, shape in zip(out_dtypes, out_shapes)
    489 ]
    490 return Apply(self, inputs, outputs)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/type.py:228, in Type.__call__(self, name)
    219 def __call__(self, name: Optional[str] = None) -> variable_type:
    220     """Return a new `Variable` instance of Type `self`.
    221 
    222     Parameters
   (...)
    226 
    227     """
--> 228     return utils.add_tag_trace(self.make_variable(name))

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/type.py:200, in Type.make_variable(self, name)
    191 def make_variable(self, name: Optional[str] = None) -> variable_type:
    192     """Return a new `Variable` instance of this `Type`.
    193 
    194     Parameters
   (...)
    198 
    199     """
--> 200     return self.variable_type(self, None, name=name)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/var.py:863, in TensorVariable.__init__(self, type, owner, index, name)
    861     warnings.warn(msg, stacklevel=1 + nb_rm)
    862 elif config.warn_float64 == "raise":
--> 863     raise Exception(msg)
    864 elif config.warn_float64 == "pdb":
    865     import pdb

Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

PyMC version information:

master

Context for the issue:

related to pymc-devs/pymc-extras#182

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions