Skip to content

ENH: make all of the properties on LineWrapper mappable #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 10, 2022
55 changes: 45 additions & 10 deletions data_prototype/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Dict, Any, Protocol, Tuple, get_type_hints
import inspect

import numpy as np

Expand Down Expand Up @@ -46,6 +47,15 @@ class _Aritst(Protocol):
axes: _Axes


def _make_identity(k):
def identity(**kwargs):
(_,) = kwargs.values()
return _

identity.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)])
return identity


def _forwarder(forwards, cls=None):
if cls is None:
return partial(_forwarder, forwards)
Expand Down Expand Up @@ -88,6 +98,8 @@ class ProxyWrapperBase:
data: DataContainer
axes: _Axes
stale: bool
required_keys: set = set()
expected_keys: set = set()

@_stale_wrapper
def draw(self, renderer):
Expand Down Expand Up @@ -137,18 +149,34 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
# doing the nu work here is nice because we can write it once, but we
# really want to push this computation down a layer
# TODO sort out how this interoperates with the transform stack
data = {k: self.nus.get(k, lambda x: x)(v) for k, v in data.items()}
self._cache[cache_key] = data
return data
transformed_data = {}
for k, (nu, sig) in self._sigs.items():
to_pass = set(sig.parameters)
transformed_data[k] = nu(**{k: data[k] for k in to_pass})

self._cache[cache_key] = transformed_data
return transformed_data

def __init__(self, data, nus, **kwargs):
super().__init__(**kwargs)
self.data = data
self._cache = LFUCache(64)
# TODO make sure mutating this will invalidate the cache!
self.nus = nus or {}
self._nus = nus or {}
for k in self.required_keys:
self._nus.setdefault(k, _make_identity(k))
desc = data.describe()
for k in self.expected_keys:
if k in desc:
self._nus.setdefault(k, _make_identity(k))
self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()}
self.stale = True

# TODO add a setter
@property
def nus(self):
return dict(self._nus)


class ProxyWrapper(ProxyWrapperBase):
_privtized_methods: Tuple[str, ...] = ()
Expand All @@ -163,7 +191,7 @@ def __getattr__(self, key):
return getattr(self._wrapped_instance, key)

def __setattr__(self, key, value):
if key in ("_wrapped_instance", "data", "_cache", "nus", "stale"):
if key in ("_wrapped_instance", "data", "_cache", "_nus", "stale", "_sigs"):
super().__setattr__(key, value)
elif hasattr(self, "_wrapped_instance") and hasattr(self._wrapped_instance, key):
setattr(self._wrapped_instance, key, value)
Expand All @@ -174,6 +202,7 @@ def __setattr__(self, key, value):
class LineWrapper(ProxyWrapper):
_wrapped_class = _Line2D
_privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data")
required_keys = {"x", "y"}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
Expand All @@ -187,14 +216,16 @@ def draw(self, renderer):
return self._wrapped_instance.draw(renderer)

def _update_wrapped(self, data):
self._wrapped_instance.set_data(data["x"], data["y"])
for k, v in data.items():
k = {"x": "xdata", "y": "ydata"}.get(k, k)
getattr(self._wrapped_instance, f"set_{k}")(v)


class ImageWrapper(ProxyWrapper):
_wrapped_class = _AxesImage
required_keys = {"xextent", "yextent", "image"}

def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwargs):
print(kwargs, nus)
nus = dict(nus or {})
if cmap is not None or norm is not None:
if nus is not None and "image" in nus:
Expand Down Expand Up @@ -223,6 +254,7 @@ def _update_wrapped(self, data):
class StepWrapper(ProxyWrapper):
_wrapped_class = _StepPatch
_privtized_methods = () # ("set_data", "get_data")
required_keys = {"edges", "density"}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
Expand All @@ -243,10 +275,9 @@ class FormatedText(ProxyWrapper):
_wrapped_class = _Text
_privtized_methods = ("set_text",)

def __init__(self, data: DataContainer, format_func, nus=None, /, **kwargs):
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
self._wrapped_instance = self._wrapped_class(text="", **kwargs)
self._format_func = format_func

@_stale_wrapper
def draw(self, renderer):
Expand All @@ -256,7 +287,8 @@ def draw(self, renderer):
return self._wrapped_instance.draw(renderer)

def _update_wrapped(self, data):
self._wrapped_instance.set_text(self._format_func(**data))
for k, v in data.items():
getattr(self._wrapped_instance, f"set_{k}")(v)


@_forwarder(
Expand Down Expand Up @@ -296,6 +328,9 @@ def get_children(self):


class ErrorbarWrapper(MultiProxyWrapper):
required_keys = {"x", "y"}
expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
# TODO all of the kwarg teasing apart that is needed
Expand Down
Loading