Skip to content

SMC sample_stats cannot be saved to netcdf #5263

Open
@astoeriko

Description

@astoeriko

When sampling several chains with SMC, the different chains sometimes run a different number of stages. As a consequence, the beta, accept_rate and log_marginal_likelihood variables in the sample_stats of the inference data are ragged (each chain has a different length). PyMC currently deals with this by giving them an object data type (see this comment in the code).
While this works for converting to xarray I get an error when trying to save the InferenceData to netcdf.

The following example is not guaranteed to reproduce the error because I cannot force the two chains to run a different number of stages. However, I tried to pick an example where the number of stages is large, so that it is not very likely that both chains need the same number of stages.

import pymc as pm

with pm.Model() as model:
    # Create a model that leads to many stages in SMC because the posterior is
    # far from the prior.
    bar = pm.Normal("bar", sigma=0.5)
    pm.Normal("foo", mu=bar, observed=100, sigma=0.5)
    trace = pm.sample_smc(cores=1, draws=100)
trace.to_netcdf("test_sample_stats.nc")
Complete error traceback
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
TypeError: float() argument must be a string or a number, not 'list'

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_106373/4238811079.py in <module>
----> 1 trace.to_netcdf("test_sample_stats.nc")

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
    390                 if compress:
    391                     kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
--> 392                 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
    393                 data.close()
    394                 mode = "a"

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
   1900         from ..backends.api import to_netcdf
   1901 
-> 1902         return to_netcdf(
   1903             self,
   1904             path,

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
   1070         # TODO: allow this work (setting up the file for writing array data)
   1071         # to be parallelized with dask
-> 1072         dump_to_store(
   1073             dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
   1074         )

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
   1117         variables, attrs = encoder(variables, attrs)
   1118 
-> 1119     store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
   1120 
   1121 

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
    259             writer = ArrayWriter()
    260 
--> 261         variables, attributes = self.encode(variables, attributes)
    262 
    263         self.set_attributes(attributes)

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
    348         # All NetCDF files get CF encoded by default, without this attempting
    349         # to write times, for example, would fail.
--> 350         variables, attributes = cf_encoder(variables, attributes)
    351         variables = {k: self.encode_variable(v) for k, v in variables.items()}
    352         attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
    853     _update_bounds_encoding(variables)
    854 
--> 855     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    856 
    857     # Remove attrs from bounds variables (issue #2921)

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in <dictcomp>(.0)
    853     _update_bounds_encoding(variables)
    854 
--> 855     new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
    856 
    857     # Remove attrs from bounds variables (issue #2921)

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
    273     var = maybe_default_fill_value(var)
    274     var = maybe_encode_bools(var)
--> 275     var = ensure_dtype_not_object(var, name=name)
    276 
    277     for attr_name in CF_RELATED_DATA:

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in ensure_dtype_not_object(var, name)
    231             data[missing] = fill_value
    232         else:
--> 233             data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))
    234 
    235         assert data.dtype.kind != "O" or data.dtype.metadata

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/conventions.py in _copy_with_dtype(data, dtype)
    189     """
    190     result = np.empty(data.shape, dtype)
--> 191     result[...] = data
    192     return result
    193 

ValueError: setting an array element with a sequence.

In my opinion it would be nicer to save the sample stats as a rectangular array even if that means filling up with NaNs in the case that the chains have a different number of stages. Then, it could be represented properly in xarray with a stage dimension (which currently does not exist; the sample_stats falsely have "chain" and "draw" dimensions, even though they do not depend on the draw). I think that filling up with NaNs is not too bad because the number of stages does not differ hugely between the chains anyway.
Alternatively, sample stats could be saved as separate variables for each chain with a separate stage dimension for each chain.
Or would there be other solutions? Please let me know what you think about this and I would be happy to provide a pull request.

Besides, there is one more problem when saving to netcdf: even if the chains do have the same number of stages, I get an error because the tune_steps attribute of the sample stats has a Boolean data type, which does not seem to be supported in netCDF. When I convert it to a string, saving works.

import pymc as pm

with pm.Model() as model:
    # Create a model that leads to many stages in SMC because the posterior is
    # far from the prior.
    bar = pm.Normal("bar", sigma=0.5)
    pm.Normal("foo", mu=bar, observed=3, sigma=0.5)
    trace = pm.sample_smc(cores=1, draws=100)

trace.to_netcdf("test_sample_stats.nc")
Complete error traceback
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_106373/1825203837.py in <module>
      8     trace = pm.sample_smc(cores=1, draws=100)
      9 
---> 10 trace.to_netcdf("test_sample_stats.nc")

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
    390                 if compress:
    391                     kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
--> 392                 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
    393                 data.close()
    394                 mode = "a"

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
   1900         from ..backends.api import to_netcdf
   1901 
-> 1902         return to_netcdf(
   1903             self,
   1904             path,

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
   1070         # TODO: allow this work (setting up the file for writing array data)
   1071         # to be parallelized with dask
-> 1072         dump_to_store(
   1073             dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
   1074         )

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
   1117         variables, attrs = encoder(variables, attrs)
   1118 
-> 1119     store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
   1120 
   1121 

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
    261         variables, attributes = self.encode(variables, attributes)
    262 
--> 263         self.set_attributes(attributes)
    264         self.set_dimensions(variables, unlimited_dims=unlimited_dims)
    265         self.set_variables(

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/common.py in set_attributes(self, attributes)
    278         """
    279         for k, v in attributes.items():
--> 280             self.set_attribute(k, v)
    281 
    282     def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None):

~/mambaforge/envs/pymc4-dev/lib/python3.9/site-packages/xarray/backends/netCDF4_.py in set_attribute(self, key, value)
    446             self.ds.setncattr_string(key, value)
    447         else:
--> 448             self.ds.setncattr(key, value)
    449 
    450     def encode_variable(self, variable):

src/netCDF4/_netCDF4.pyx in netCDF4._netCDF4.Dataset.setncattr()

src/netCDF4/_netCDF4.pyx in netCDF4._netCDF4._set_att()

TypeError: illegal data type for attribute b'tune_steps', must be one of dict_keys(['S1', 'i1', 'u1', 'i2', 'u2', 'i4', 'u4', 'i8', 'u8', 'f4', 'f8']), got b1
**Please provide any additional information below.**

Versions and main components

  • PyMC/PyMC3 Version: '4.0.0'
  • Aesara/Theano Version: '2.3.1'
  • Python Version: 3.9.7
  • Operating system: Linux
  • How did you install PyMC/PyMC3: conda

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions