Drop Aesara support and fix all currently detected type issues #84

Merged: 3 commits, Feb 4, 2023
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
@@ -31,3 +31,7 @@ repos:
args: [--rcfile=.pylintrc]
exclude: (test_*|mcbackend/meta.py|mcbackend/npproto/)
files: ^mcbackend/
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.991
hooks:
- id: mypy
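
The pinned mypy revision (v0.991) no longer accepts implicit `Optional` defaults, which is what several of the annotation fixes below address. A minimal sketch of the pattern; the function name is hypothetical and not part of mcbackend:

```python
from datetime import datetime
from typing import Optional

# Rejected by the new mypy hook: a None default on a non-Optional annotation.
#     def start_run(created_at: datetime = None): ...
# Accepted: the Optional is spelled out explicitly.
def start_run(created_at: Optional[datetime] = None) -> datetime:
    return created_at if created_at is not None else datetime.now()
```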
10 changes: 2 additions & 8 deletions mcbackend/adapters/pymc.py
@@ -9,16 +9,10 @@

import hagelkorn
import numpy

try:
from pytensor.graph.basic import Constant
from pytensor.tensor.sharedvar import SharedVariable
except ModuleNotFoundError:
from aesara.graph.basic import Constant
from aesara.tensor.sharedvar import SharedVariable

from pymc.backends.base import BaseTrace
from pymc.model import Model
from pytensor.graph.basic import Constant
from pytensor.tensor.sharedvar import SharedVariable

from mcbackend.meta import Coordinate, DataVariable, Variable

37 changes: 26 additions & 11 deletions mcbackend/backends/clickhouse.py
@@ -5,7 +5,18 @@
import logging
import time
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)

import clickhouse_driver
import numpy
@@ -156,7 +167,7 @@ def __init__(
self._client = client
# The following attributes belong to the batched insert mechanism.
# Inserting in batches is much faster than inserting single rows.
self._str_cols = set()
self._str_cols: Set[str] = set()
self._insert_query: str = ""
self._insert_queue: List[Dict[str, Any]] = []
self._last_insert = time.time()
@@ -176,13 +187,16 @@ def append(
self._insert_query = f"INSERT INTO {self.cid} (`_draw_idx`,`{names}`) VALUES"
self._str_cols = {k for k, v in params.items() if "str" in numpy.asarray(v).dtype.name}

# Convert str ndarrays to lists
params_ins: Dict[str, Union[numpy.ndarray, int, float, List[str]]] = {
"_draw_idx": self._draw_idx,
**params,
}
# Convert str-dtyped ndarrays to lists
for col in self._str_cols:
params[col] = params[col].tolist()
params_ins[col] = params[col].tolist()

# Queue up for insertion
params["_draw_idx"] = self._draw_idx
self._insert_queue.append(params)
self._insert_queue.append(params_ins)
self._draw_idx += 1

if (
@@ -242,13 +256,14 @@ def _get_rows(

# Without draws return empty arrays of the correct shape/dtype
if not draws:
if is_rigid(nshape):
return numpy.empty(shape=[0] + nshape, dtype=dtype)
if is_rigid(nshape) and nshape is not None:
return numpy.empty(shape=[0, *nshape], dtype=dtype)
return numpy.array([], dtype=object)

# The unpacking must also account for non-rigid shapes
# and str-dtyped empty arrays default to fixed length 1 strings.
# The [None] list is slower, but more flexible in this regard.
buffer: Union[numpy.ndarray, Sequence]
if is_rigid(nshape) and dtype != "str":
assert nshape is not None
buffer = numpy.empty((draws, *nshape), dtype)
@@ -292,7 +307,7 @@ def __init__(
self,
meta: RunMeta,
*,
created_at: datetime = None,
created_at: Optional[datetime] = None,
client_fn: Callable[[], clickhouse_driver.Client],
) -> None:
self._client_fn = client_fn
@@ -331,8 +346,8 @@ class ClickHouseBackend(Backend):

def __init__(
self,
client: clickhouse_driver.Client = None,
client_fn: Callable[[], clickhouse_driver.Client] = None,
client: Optional[clickhouse_driver.Client] = None,
client_fn: Optional[Callable[[], clickhouse_driver.Client]] = None,
):
"""Create a ClickHouse backend around a database client.

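
For context on the `_str_cols` bookkeeping in the `append` hunk above: str-dtyped draws arrive as NumPy arrays but are queued for insertion as plain lists, and the new `params_ins` dict avoids mutating the caller's `params`. A rough sketch of the idea with hypothetical column names, outside of the backend class:

```python
import numpy

# One draw as handed to append(); column names are made up for illustration.
draw = {"stage": numpy.array(["warmup", "warmup"]), "x": numpy.array([0.1, 0.2])}

# Detect string-typed columns the same way the backend does.
str_cols = {k for k, v in draw.items() if "str" in numpy.asarray(v).dtype.name}

# Build the row without mutating the caller's dict, converting str arrays to lists.
row = {"_draw_idx": 0, **draw}
for col in str_cols:
    row[col] = row[col].tolist()  # ['warmup', 'warmup'] instead of a '<U6' ndarray
```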
39 changes: 15 additions & 24 deletions mcbackend/core.py
@@ -3,31 +3,20 @@
"""
import collections
import logging
from typing import (
TYPE_CHECKING,
Dict,
List,
Mapping,
Optional,
Sequence,
Sized,
TypeVar,
)
from typing import Dict, List, Mapping, Optional, Sequence, Sized, TypeVar, Union, cast

import numpy

from .meta import ChainMeta, RunMeta, Variable
from .npproto.utils import ndarray_to_numpy
from .utils import as_array_from_ragged

InferenceData = TypeVar("InferenceData")
try:
from arviz import from_dict
from arviz import InferenceData, from_dict

if not TYPE_CHECKING:
from arviz import InferenceData
_HAS_ARVIZ = True
except ModuleNotFoundError:
InferenceData = TypeVar("InferenceData") # type: ignore
_HAS_ARVIZ = False

Shape = Sequence[int]
@@ -262,20 +251,22 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
warmup_sample_stats[svar.name].append(stats[tune])
sample_stats[svar.name].append(stats[~tune])

w_pst = cast(Dict[str, Union[Sequence, numpy.ndarray]], warmup_posterior)
w_ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], warmup_sample_stats)
pst = cast(Dict[str, Union[Sequence, numpy.ndarray]], posterior)
ss = cast(Dict[str, Union[Sequence, numpy.ndarray]], sample_stats)
if not equalize_chain_lengths:
# Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically
warmup_posterior = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()}
warmup_sample_stats = {
k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items()
}
posterior = {k: as_array_from_ragged(v) for k, v in posterior.items()}
sample_stats = {k: as_array_from_ragged(v) for k, v in sample_stats.items()}
w_pst = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()}
w_ss = {k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items()}
pst = {k: as_array_from_ragged(v) for k, v in posterior.items()}
ss = {k: as_array_from_ragged(v) for k, v in sample_stats.items()}

idata = from_dict(
warmup_posterior=warmup_posterior,
warmup_sample_stats=warmup_sample_stats,
posterior=posterior,
sample_stats=sample_stats,
warmup_posterior=w_pst,
warmup_sample_stats=w_ss,
posterior=pst,
sample_stats=ss,
coords=self.coords,
dims=self.dims,
attrs=self.meta.attributes,
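
The `as_array_from_ragged` conversion above exists because NumPy 1.24 removed the automatic creation of object arrays from ragged nested sequences. A rough sketch of the behaviour, assuming chains of unequal length; the helper's actual implementation is not shown in this diff:

```python
import numpy

# Two chains with different numbers of draws (ragged input).
chains = [numpy.arange(3), numpy.arange(5)]

# numpy.array(chains) raises ValueError on NumPy >= 1.24 instead of
# silently building an object array, so the object dtype is made explicit:
ragged = numpy.empty(len(chains), dtype=object)
ragged[:] = chains  # 1-D object array holding the per-chain arrays
```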
4 changes: 2 additions & 2 deletions mcbackend/test_core.py
@@ -72,10 +72,10 @@ def test_chain_properties(self):

def test_chain_length(self):
class _TestChain(core.Chain):
def get_draws(self, var_name: str):
def get_draws(self, var_name: str, slc: slice = slice(None)):
return numpy.arange(12)

def get_stats(self, stat_name: str):
def get_stats(self, stat_name: str, slc: slice = slice(None)):
return numpy.arange(42)

rmeta = RunMeta("test", variables=[Variable("v1")])
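
The added `slc` parameter brings the test doubles in line with the slice-taking `get_draws`/`get_stats` signatures that mypy checks against the `Chain` base class. A hedged sketch of how such a parameter is typically used, standalone and not the actual mcbackend implementation:

```python
import numpy

def get_draws(var_name: str, slc: slice = slice(None)) -> numpy.ndarray:
    draws = numpy.arange(12)  # stand-in for the stored draws of `var_name`
    return draws[slc]         # slice(None) selects everything

print(get_draws("v1"))               # all 12 draws
print(get_draws("v1", slice(0, 5)))  # only the first five draws
```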
12 changes: 5 additions & 7 deletions mcbackend/test_utils.py
@@ -1,7 +1,7 @@
import random
import time
from dataclasses import dataclass
from typing import Sequence
from typing import Optional, Sequence

import arviz
import hagelkorn
@@ -78,9 +78,9 @@ def make_draw(variables: Sequence[Variable]):
class BaseBackendTest:
"""Can be used to test different backends in the same way."""

cls_backend = None
cls_run = None
cls_chain = None
cls_backend: Optional[type] = None
cls_run: Optional[type] = None
cls_chain: Optional[type] = None

def setup_method(self, method):
"""Override this when the backend has no parameterless constructor."""
@@ -373,10 +373,8 @@ def run_all_benchmarks(self) -> pandas.DataFrame:
for attr in dir(BackendBenchmark):
meth = getattr(self, attr, None)
if callable(meth) and meth.__name__.startswith("measure_"):
try:
if hasattr(self, "setup_method"):
self.setup_method(meth)
except TypeError:
pass
print(f"Running {meth.__name__}")
speed = meth()
df.loc[meth.__name__[8:], ["bytes_per_draw", "append_speed", "description"]] = (
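
The last hunk swaps a `try`/`except TypeError` around `setup_method` for an explicit `hasattr` check, so an exception raised inside a real `setup_method` is no longer silently swallowed. A small sketch of the pattern; class and method names are hypothetical:

```python
class BenchmarkRunner:
    def run(self, meth):
        # Call the optional per-test setup only when a subclass provides it,
        # rather than catching TypeError, which could also mask bugs raised
        # inside setup_method itself.
        if hasattr(self, "setup_method"):
            self.setup_method(meth)
        return meth()
```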