Description
When sampling several chains with SMC, the different chains sometimes run a different number of stages. As a consequence, the beta, accept_rate and log_marginal_likelihood variables in the sample_stats of the inference data are non-square. PyMC currently deals with this by giving them an object data type (see this comment in the code).
While this works for converting to xarray, I get an error when trying to save the InferenceData to netcdf.
The following example is not guaranteed to reproduce the error because I cannot force the two chains to run a different number of stages. However, I tried to pick an example where the number of stages is large, so that it is not very likely that both chains need the same number of stages.
import pymc as pm

with pm.Model() as model:
    # Create a model that leads to many stages in SMC because the posterior is
    # far from the prior.
    bar = pm.Normal("bar", sigma=0.5)
    pm.Normal("foo", mu=bar, observed=100, sigma=0.5)
    trace = pm.sample_smc(cores=1, draws=100)

trace.to_netcdf("test_sample_stats.nc")
Complete error traceback
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
TypeError: float() argument must be a string or a number, not 'list'
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
/tmp/ipykernel_106373/4238811079.py in <module>
----> 1 trace.to_netcdf("test_sample_stats.nc")
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
390 if compress:
391 kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
--> 392 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
393 data.close()
394 mode = "a"
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
1900 from ..backends.api import to_netcdf
1901
-> 1902 return to_netcdf(
1903 self,
1904 path,
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
1070 # TODO: allow this work (setting up the file for writing array data)
1071 # to be parallelized with dask
-> 1072 dump_to_store(
1073 dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
1074 )
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
1117 variables, attrs = encoder(variables, attrs)
1118
-> 1119 store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
1120
1121
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
259 writer = ArrayWriter()
260
--> 261 variables, attributes = self.encode(variables, attributes)
262
263 self.set_attributes(attributes)
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
348 # All NetCDF files get CF encoded by default, without this attempting
349 # to write times, for example, would fail.
--> 350 variables, attributes = cf_encoder(variables, attributes)
351 variables = {k: self.encode_variable(v) for k, v in variables.items()}
352 attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
853 _update_bounds_encoding(variables)
854
--> 855 new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
856
857 # Remove attrs from bounds variables (issue #2921)
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in <dictcomp>(.0)
853 _update_bounds_encoding(variables)
854
--> 855 new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
856
857 # Remove attrs from bounds variables (issue #2921)
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
273 var = maybe_default_fill_value(var)
274 var = maybe_encode_bools(var)
--> 275 var = ensure_dtype_not_object(var, name=name)
276
277 for attr_name in CF_RELATED_DATA:
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in ensure_dtype_not_object(var, name)
231 data[missing] = fill_value
232 else:
--> 233 data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))
234
235 assert data.dtype.kind != "O" or data.dtype.metadata
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in _copy_with_dtype(data, dtype)
189 """
190 result = np.empty(data.shape, dtype)
--> 191 result[...] = data
192 return result
193
ValueError: setting an array element with a sequence.
In my opinion it would be nicer to save the sample stats as a square array, even if that means filling up with NaNs in the case that the chains have a different number of stages. Then, they could be represented properly in xarray with a "stage" dimension (which currently does not exist; the sample_stats falsely have "chain" and "draw" dimensions, even though they do not depend on the draw). I think that filling up with NaNs is not too bad because the number of stages does not differ hugely between the chains anyway.
Alternatively, sample stats could be saved as separate variables for each chain with a separate stage dimension for each chain.
Or would there be other solutions? Please let me know what you think about this and I would be happy to provide a pull request.
Besides, there is one more problem when saving to netcdf: even if the chains do have the same number of stages, I get an error because the tune_steps attribute of the sample stats has a Boolean data type, which does not seem to be supported in netcdf. When I convert it to a string, saving works.
import pymc as pm

with pm.Model() as model:
    # Create a model that leads to many stages in SMC because the posterior is
    # far from the prior.
    bar = pm.Normal("bar", sigma=0.5)
    pm.Normal("foo", mu=bar, observed=3, sigma=0.5)
    trace = pm.sample_smc(cores=1, draws=100)

trace.to_netcdf("test_sample_stats.nc")
Complete error traceback
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_106373/1825203837.py in <module>
8 trace = pm.sample_smc(cores=1, draws=100)
9
---> 10 trace.to_netcdf("test_sample_stats.nc")
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
390 if compress:
391 kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
--> 392 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
393 data.close()
394 mode = "a"
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
1900 from ..backends.api import to_netcdf
1901
-> 1902 return to_netcdf(
1903 self,
1904 path,
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
1070 # TODO: allow this work (setting up the file for writing array data)
1071 # to be parallelized with dask
-> 1072 dump_to_store(
1073 dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
1074 )
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
1117 variables, attrs = encoder(variables, attrs)
1118
-> 1119 store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
1120
1121
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
261 variables, attributes = self.encode(variables, attributes)
262
--> 263 self.set_attributes(attributes)
264 self.set_dimensions(variables, unlimited_dims=unlimited_dims)
265 self.set_variables(
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/common.py in set_attributes(self, attributes)
278 """
279 for k, v in attributes.items():
--> 280 self.set_attribute(k, v)
281
282 def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):
~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/netCDF4_.py in set_attribute(self, key, value)
446 self.ds.setncattr_string(key, value)
447 else:
--> 448 self.ds.setncattr(key, value)
449
450 def encode_variable(self, variable):
src/netCDF4/_netCDF4.pyx in netCDF4._netCDF4.Dataset.setncattr()
src/netCDF4/_netCDF4.pyx in netCDF4._netCDF4._set_att()
TypeError: illegal data type for attribute b'tune_steps', must be one of dict_keys(['S1', 'i1', 'u1', 'i2', 'u2', 'i4', 'u4', 'i8', 'u8', 'f4', 'f8']), got b1
Versions and main components
- PyMC/PyMC3 Version: '4.0.0'
- Aesara/Theano Version: '2.3.1'
- Python Version: 3.9.7
- Operating system: Linux
- How did you install PyMC/PyMC3: conda