
BUG: Issues creating InferenceData if there is a coord of type Sequence[Sequence] #6496

Open
@covertg

Description

Describe the issue:

Hi, writing with a niche question. I tried to specify a model where I've flattened multiple coords into a single "flattened" coord, where each value is a particular combination of the original coords. E.g. instead of {"coord1": ["a", "b"], "coord2": [1, 2, 3]}, I set {"c": [("a", 1), ("a", 2), ("a", 3), ("b", 1), ("b", 2), ("b", 3)]}.
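
For concreteness, a flattened coord like this can be built with itertools.product (a minimal sketch; the variable names are illustrative, not from my actual model):

from itertools import product

coord1 = ["a", "b"]
coord2 = [1, 2, 3]
flat = list(product(coord1, coord2))
# flat == [("a", 1), ("a", 2), ("a", 3), ("b", 1), ("b", 2), ("b", 3)]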

Is this supported input? Unfortunately, sampling fails in this case.

Reproducible code example:

import pymc as pm

with pm.Model(coords={"c": [("a", 1), ("a", 3), ("b", 2)]}) as m:
    x = pm.Normal("x", mu=0, sigma=1, dims="c")
    # m.coords is {'c': (('a', 1), ('a', 3), ('b', 2))}
    # x.type is TensorType(float64, (3,))
    idata = pm.sample_prior_predictive()
    # also fails with:
    # idata = pm.sample()

Error message:

ValueError                                Traceback (most recent call last)
Cell In[23], line 2
      1 with m:
----> 2     idata = pm.sample_prior_predictive()

File ...python/site-packages/pymc/sampling/forward.py:440, in sample_prior_predictive(samples, model, var_names, random_seed, return_inferencedata, idata_kwargs, compile_kwargs)
    438 if idata_kwargs:
    439     ikwargs.update(idata_kwargs)
--> 440 return pm.to_inference_data(prior=prior, **ikwargs)

File ...python/site-packages/pymc/backends/arviz.py:496, in to_inference_data(trace, prior, posterior_predictive, log_likelihood, coords, dims, sample_dims, model, save_warmup, include_transformed)
    482 if isinstance(trace, InferenceData):
    483     return trace
    485 return InferenceDataConverter(
    486     trace=trace,
    487     prior=prior,
    488     posterior_predictive=posterior_predictive,
    489     log_likelihood=log_likelihood,
    490     coords=coords,
    491     dims=dims,
    492     sample_dims=sample_dims,
    493     model=model,
    494     save_warmup=save_warmup,
    495     include_transformed=include_transformed,
--> 496 ).to_inference_data()

File ...python/site-packages/pymc/backends/arviz.py:408, in InferenceDataConverter.to_inference_data(self)
    396 def to_inference_data(self):
    397     """Convert all available data to an InferenceData object.
    398 
    399     Note that if groups can not be created (e.g., there is no `trace`, so
    400     the `posterior` and `sample_stats` can not be extracted), then the InferenceData
    401     will not have those groups.
    402     """
    403     id_dict = {
    404         "posterior": self.posterior_to_xarray(),
    405         "sample_stats": self.sample_stats_to_xarray(),
    406         "posterior_predictive": self.posterior_predictive_to_xarray(),
    407         "predictions": self.predictions_to_xarray(),
--> 408         **self.priors_to_xarray(),
    409         "observed_data": self.observed_data_to_xarray(),
    410     }
    411     if self.predictions:
    412         id_dict["predictions_constant_data"] = self.constant_data_to_xarray()

File ...python/site-packages/pymc/backends/arviz.py:358, in InferenceDataConverter.priors_to_xarray(self)
    351 priors_dict = {}
    352 for group, var_names in zip(
    353     ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
    354 ):
    355     priors_dict[group] = (
    356         None
    357         if var_names is None
--> 358         else dict_to_dataset(
    359             {k: np.expand_dims(self.prior[k], 0) for k in var_names},
    360             library=pymc,
    361             coords=self.coords,
    362             dims=self.dims,
    363         )
    364     )
    365 return priors_dict

File ...python/site-packages/arviz/data/base.py:306, in dict_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
    303 if dims is None:
    304     dims = {}
--> 306 data_vars = {
    307     key: numpy_to_data_array(
    308         values,
    309         var_name=key,
    310         coords=coords,
    311         dims=dims.get(key),
    312         default_dims=default_dims,
    313         index_origin=index_origin,
    314         skip_event_dims=skip_event_dims,
    315     )
    316     for key, values in data.items()
    317 }
    318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ...python/site-packages/arviz/data/base.py:307, in <dictcomp>(.0)
    303 if dims is None:
    304     dims = {}
    306 data_vars = {
--> 307     key: numpy_to_data_array(
    308         values,
    309         var_name=key,
    310         coords=coords,
    311         dims=dims.get(key),
    312         default_dims=default_dims,
    313         index_origin=index_origin,
    314         skip_event_dims=skip_event_dims,
    315     )
    316     for key, values in data.items()
    317 }
    318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

File ...python/site-packages/arviz/data/base.py:254, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
    251     coords["draw"] = np.arange(index_origin, n_samples + index_origin)
    253 # filter coords based on the dims
--> 254 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
    255 return xr.DataArray(ary, coords=coords, dims=dims)

File ...python/site-packages/arviz/data/base.py:254, in <dictcomp>(.0)
    251     coords["draw"] = np.arange(index_origin, n_samples + index_origin)
    253 # filter coords based on the dims
--> 254 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
    255 return xr.DataArray(ary, coords=coords, dims=dims)

File ...python/site-packages/xarray/core/variable.py:2846, in IndexVariable.__init__(self, dims, data, attrs, encoding, fastpath)
   2845 def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
-> 2846     super().__init__(dims, data, attrs, encoding, fastpath)
   2847     if self.ndim != 1:
   2848         raise ValueError(f"{type(self).__name__} objects must be 1-dimensional")

File ...python/site-packages/xarray/core/variable.py:366, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
    346 """
    347 Parameters
    348 ----------
   (...)
    363     unrecognized encoding items.
    364 """
    365 self._data = as_compatible_data(data, fastpath=fastpath)
--> 366 self._dims = self._parse_dimensions(dims)
    367 self._attrs = None
    368 self._encoding = None

File ...python/site-packages/xarray/core/variable.py:663, in Variable._parse_dimensions(self, dims)
    661 dims = tuple(dims)
    662 if len(dims) != self.ndim:
--> 663     raise ValueError(
    664         f"dimensions {dims} must have the same length as the "
    665         f"number of data dimensions, ndim={self.ndim}"
    666     )
    667 return dims

ValueError: dimensions ('c',) must have the same length as the number of data dimensions, ndim=2

PyMC version information:

pymc : 5.0.2
arviz : 0.14.0
xarray: 2023.1.0

Context for the issue:

It looks like the error arises because, along the way, InferenceDataConverter converts the coord from a tuple to a numpy array:

pymc/pymc/backends/arviz.py, lines 210 to 214 in 902b1ec:

self.coords = {
    cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
    for cname, cvals in untyped_coords.items()
    if cvals is not None
}

And numpy eagerly converts the coord vector into a multidimensional array, because it can: np.array([("a", 1), ("a", 3), ("b", 2)]).ndim is 2, which I think causes the shape mismatch that makes arviz/xarray fail.
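
To see the dimensionality problem in isolation (a minimal sketch, independent of PyMC):

import numpy as np

c = np.array([("a", 1), ("a", 3), ("b", 2)])
print(c.ndim)   # 2, not 1: numpy treats each tuple as a row of a 2-D array
print(c.shape)  # (3, 2)
print(c.dtype)  # unicode: the ints are silently coerced to strings as well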

Obviously one workaround would be to stringify the coordinate values that I currently have as tuples. But it would be nice to preserve the types/objects.
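
For reference, two directions I can think of (both untested sketches, not confirmed fixes). Stringifying always works but loses structure; a 1-D object array keeps ndim == 1, which suggests where a converter-side fix could apply:

import numpy as np
import xarray as xr

pairs = [("a", 1), ("a", 3), ("b", 2)]

# Option 1: stringify the combinations (loses the original types)
coords_str = {"c": [f"{u}|{v}" for u, v in pairs]}

# Option 2 (speculative): a 1-D object array survives np.asarray with ndim == 1,
# so xarray accepts it; pm.Model may still normalize coords back to tuples, though
c_vals = np.empty(len(pairs), dtype=object)
c_vals[:] = pairs
xr.IndexVariable(("c",), data=c_vals)  # no ValueError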
