Description
Describe the issue:
Hi, writing with a niche question. I tried to specify a model where I've flattened multiple coords into a single "flattened" coord, where each value is a particular combination of the original coords. E.g. instead of {"coord1": ["a", "b"], "coord2": [1, 2, 3]}, I set {"c": [("a", 1), ("a", 2), ("a", 3), ("b", 1), ("b", 2), ("b", 3)]}.
Is this supported input? Unfortunately, sampling fails in this case.
Reproducible code example:
import pymc as pm
with pm.Model(coords={"c": [("a", 1), ("a", 3), ("b", 2)]}) as m:
x = pm.Normal("x", mu=0, sigma=1, dims="c")
# m.coords is {'c': (('a', 1), ('a', 3), ('b', 2))}
# x.type is TensorType(float64, (3,))
idata = pm.sample_prior_predictive()
# also fails with:
# idata = pm.sample()
Error message:
ValueError Traceback (most recent call last)
Cell In[23], line 2
1 with m:
----> 2 idata = pm.sample_prior_predictive()
File ...python/site-packages/pymc/sampling/forward.py:440, in sample_prior_predictive(samples, model, var_names, random_seed, return_inferencedata, idata_kwargs, compile_kwargs)
438 if idata_kwargs:
439 ikwargs.update(idata_kwargs)
--> 440 return pm.to_inference_data(prior=prior, **ikwargs)
File ...python/site-packages/pymc/backends/arviz.py:496, in to_inference_data(trace, prior, posterior_predictive, log_likelihood, coords, dims, sample_dims, model, save_warmup, include_transformed)
482 if isinstance(trace, InferenceData):
483 return trace
485 return InferenceDataConverter(
486 trace=trace,
487 prior=prior,
488 posterior_predictive=posterior_predictive,
489 log_likelihood=log_likelihood,
490 coords=coords,
491 dims=dims,
492 sample_dims=sample_dims,
493 model=model,
494 save_warmup=save_warmup,
495 include_transformed=include_transformed,
--> 496 ).to_inference_data()
File ...python/site-packages/pymc/backends/arviz.py:408, in InferenceDataConverter.to_inference_data(self)
396 def to_inference_data(self):
397 """Convert all available data to an InferenceData object.
398
399 Note that if groups can not be created (e.g., there is no `trace`, so
400 the `posterior` and `sample_stats` can not be extracted), then the InferenceData
401 will not have those groups.
402 """
403 id_dict = {
404 "posterior": self.posterior_to_xarray(),
405 "sample_stats": self.sample_stats_to_xarray(),
406 "posterior_predictive": self.posterior_predictive_to_xarray(),
407 "predictions": self.predictions_to_xarray(),
--> 408 **self.priors_to_xarray(),
409 "observed_data": self.observed_data_to_xarray(),
410 }
411 if self.predictions:
412 id_dict["predictions_constant_data"] = self.constant_data_to_xarray()
File ...python/site-packages/pymc/backends/arviz.py:358, in InferenceDataConverter.priors_to_xarray(self)
351 priors_dict = {}
352 for group, var_names in zip(
353 ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
354 ):
355 priors_dict[group] = (
356 None
357 if var_names is None
--> 358 else dict_to_dataset(
359 {k: np.expand_dims(self.prior[k], 0) for k in var_names},
360 library=pymc,
361 coords=self.coords,
362 dims=self.dims,
363 )
364 )
365 return priors_dict
File ...python/site-packages/arviz/data/base.py:306, in dict_to_dataset(data, attrs, library, coords, dims, default_dims, index_origin, skip_event_dims)
303 if dims is None:
304 dims = {}
--> 306 data_vars = {
307 key: numpy_to_data_array(
308 values,
309 var_name=key,
310 coords=coords,
311 dims=dims.get(key),
312 default_dims=default_dims,
313 index_origin=index_origin,
314 skip_event_dims=skip_event_dims,
315 )
316 for key, values in data.items()
317 }
318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
File ...python/site-packages/arviz/data/base.py:307, in <dictcomp>(.0)
303 if dims is None:
304 dims = {}
306 data_vars = {
--> 307 key: numpy_to_data_array(
308 values,
309 var_name=key,
310 coords=coords,
311 dims=dims.get(key),
312 default_dims=default_dims,
313 index_origin=index_origin,
314 skip_event_dims=skip_event_dims,
315 )
316 for key, values in data.items()
317 }
318 return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))
File ...python/site-packages/arviz/data/base.py:254, in numpy_to_data_array(ary, var_name, coords, dims, default_dims, index_origin, skip_event_dims)
251 coords["draw"] = np.arange(index_origin, n_samples + index_origin)
253 # filter coords based on the dims
--> 254 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
255 return xr.DataArray(ary, coords=coords, dims=dims)
File ...python/site-packages/arviz/data/base.py:254, in <dictcomp>(.0)
251 coords["draw"] = np.arange(index_origin, n_samples + index_origin)
253 # filter coords based on the dims
--> 254 coords = {key: xr.IndexVariable((key,), data=np.asarray(coords[key])) for key in dims}
255 return xr.DataArray(ary, coords=coords, dims=dims)
File ...python/site-packages/xarray/core/variable.py:2846, in IndexVariable.__init__(self, dims, data, attrs, encoding, fastpath)
2845 def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
-> 2846 super().__init__(dims, data, attrs, encoding, fastpath)
2847 if self.ndim != 1:
2848 raise ValueError(f"{type(self).__name__} objects must be 1-dimensional")
File ...python/site-packages/xarray/core/variable.py:366, in Variable.__init__(self, dims, data, attrs, encoding, fastpath)
346 """
347 Parameters
348 ----------
(...)
363 unrecognized encoding items.
364 """
365 self._data = as_compatible_data(data, fastpath=fastpath)
--> 366 self._dims = self._parse_dimensions(dims)
367 self._attrs = None
368 self._encoding = None
File ...python/site-packages/xarray/core/variable.py:663, in Variable._parse_dimensions(self, dims)
661 dims = tuple(dims)
662 if len(dims) != self.ndim:
--> 663 raise ValueError(
664 f"dimensions {dims} must have the same length as the "
665 f"number of data dimensions, ndim={self.ndim}"
666 )
667 return dims
ValueError: dimensions ('c',) must have the same length as the number of data dimensions, ndim=2
PyMC version information:
pymc : 5.0.2
arviz : 0.14.0
xarray: 2023.1.0
Context for the issue:
It looks like the error arises because along the way, the coord is converted from a tuple to a numpy array by InferenceDataConverter (see lines 210 to 214 in commit 902b1ec). And numpy is eager to convert the coord vector into a multidimensional array, because it can. But np.array([("a", 1), ("a", 3), ("b", 2)]).ndim is 2, which I think causes the mismatch that makes arviz/xarray fail.
Obviously one workaround would be to stringify the coordinate values (which I currently have as tuples). But it would be nice to preserve the original types/objects.