Skip to content

Commit 3eccc70

Browse files
authored
ENH: make all of the properties on LineWrapper mappable (#15)
* ENH: make all of the properties on LineWrapper mappable * DOC: update logo unbreaks doc build for me locally * ENH: make the nus functions more filtering and powerful Make it possible to rename as part of the transforms * MNT: remove a print * ENH: show off using multiple inputs + renaming * MNT: add required keys to steps and image * ENH: pass any "extra" data through with assumed unity transform * ENH: being permissive in passing data was a bad idea Introduce "expected keys" * BLD: always run all of the examples when build the docs * MNT: make clear all FormattedText objects share the same functions * STY: fix linting
1 parent c04b2f4 commit 3eccc70

File tree

5 files changed

+674
-13
lines changed

5 files changed

+674
-13
lines changed

data_prototype/wrappers.py

+45-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Dict, Any, Protocol, Tuple, get_type_hints
2+
import inspect
23

34
import numpy as np
45

@@ -46,6 +47,15 @@ class _Aritst(Protocol):
4647
axes: _Axes
4748

4849

50+
def _make_identity(k):
51+
def identity(**kwargs):
52+
(_,) = kwargs.values()
53+
return _
54+
55+
identity.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)])
56+
return identity
57+
58+
4959
def _forwarder(forwards, cls=None):
5060
if cls is None:
5161
return partial(_forwarder, forwards)
@@ -88,6 +98,8 @@ class ProxyWrapperBase:
8898
data: DataContainer
8999
axes: _Axes
90100
stale: bool
101+
required_keys: set = set()
102+
expected_keys: set = set()
91103

92104
@_stale_wrapper
93105
def draw(self, renderer):
@@ -137,18 +149,34 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
137149
# doing the nu work here is nice because we can write it once, but we
138150
# really want to push this computation down a layer
139151
# TODO sort out how this interoperates with the transform stack
140-
data = {k: self.nus.get(k, lambda x: x)(v) for k, v in data.items()}
141-
self._cache[cache_key] = data
142-
return data
152+
transformed_data = {}
153+
for k, (nu, sig) in self._sigs.items():
154+
to_pass = set(sig.parameters)
155+
transformed_data[k] = nu(**{k: data[k] for k in to_pass})
156+
157+
self._cache[cache_key] = transformed_data
158+
return transformed_data
143159

144160
def __init__(self, data, nus, **kwargs):
145161
super().__init__(**kwargs)
146162
self.data = data
147163
self._cache = LFUCache(64)
148164
# TODO make sure mutating this will invalidate the cache!
149-
self.nus = nus or {}
165+
self._nus = nus or {}
166+
for k in self.required_keys:
167+
self._nus.setdefault(k, _make_identity(k))
168+
desc = data.describe()
169+
for k in self.expected_keys:
170+
if k in desc:
171+
self._nus.setdefault(k, _make_identity(k))
172+
self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()}
150173
self.stale = True
151174

175+
# TODO add a setter
176+
@property
177+
def nus(self):
178+
return dict(self._nus)
179+
152180

153181
class ProxyWrapper(ProxyWrapperBase):
154182
_privtized_methods: Tuple[str, ...] = ()
@@ -163,7 +191,7 @@ def __getattr__(self, key):
163191
return getattr(self._wrapped_instance, key)
164192

165193
def __setattr__(self, key, value):
166-
if key in ("_wrapped_instance", "data", "_cache", "nus", "stale"):
194+
if key in ("_wrapped_instance", "data", "_cache", "_nus", "stale", "_sigs"):
167195
super().__setattr__(key, value)
168196
elif hasattr(self, "_wrapped_instance") and hasattr(self._wrapped_instance, key):
169197
setattr(self._wrapped_instance, key, value)
@@ -174,6 +202,7 @@ def __setattr__(self, key, value):
174202
class LineWrapper(ProxyWrapper):
175203
_wrapped_class = _Line2D
176204
_privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data")
205+
required_keys = {"x", "y"}
177206

178207
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
179208
super().__init__(data, nus)
@@ -187,14 +216,16 @@ def draw(self, renderer):
187216
return self._wrapped_instance.draw(renderer)
188217

189218
def _update_wrapped(self, data):
190-
self._wrapped_instance.set_data(data["x"], data["y"])
219+
for k, v in data.items():
220+
k = {"x": "xdata", "y": "ydata"}.get(k, k)
221+
getattr(self._wrapped_instance, f"set_{k}")(v)
191222

192223

193224
class ImageWrapper(ProxyWrapper):
194225
_wrapped_class = _AxesImage
226+
required_keys = {"xextent", "yextent", "image"}
195227

196228
def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwargs):
197-
print(kwargs, nus)
198229
nus = dict(nus or {})
199230
if cmap is not None or norm is not None:
200231
if nus is not None and "image" in nus:
@@ -223,6 +254,7 @@ def _update_wrapped(self, data):
223254
class StepWrapper(ProxyWrapper):
224255
_wrapped_class = _StepPatch
225256
_privtized_methods = () # ("set_data", "get_data")
257+
required_keys = {"edges", "density"}
226258

227259
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
228260
super().__init__(data, nus)
@@ -243,10 +275,9 @@ class FormatedText(ProxyWrapper):
243275
_wrapped_class = _Text
244276
_privtized_methods = ("set_text",)
245277

246-
def __init__(self, data: DataContainer, format_func, nus=None, /, **kwargs):
278+
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
247279
super().__init__(data, nus)
248280
self._wrapped_instance = self._wrapped_class(text="", **kwargs)
249-
self._format_func = format_func
250281

251282
@_stale_wrapper
252283
def draw(self, renderer):
@@ -256,7 +287,8 @@ def draw(self, renderer):
256287
return self._wrapped_instance.draw(renderer)
257288

258289
def _update_wrapped(self, data):
259-
self._wrapped_instance.set_text(self._format_func(**data))
290+
for k, v in data.items():
291+
getattr(self._wrapped_instance, f"set_{k}")(v)
260292

261293

262294
@_forwarder(
@@ -296,6 +328,9 @@ def get_children(self):
296328

297329

298330
class ErrorbarWrapper(MultiProxyWrapper):
331+
required_keys = {"x", "y"}
332+
expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]}
333+
299334
def __init__(self, data: DataContainer, nus=None, /, **kwargs):
300335
super().__init__(data, nus)
301336
# TODO all of the kwarg teasing apart that is needed

0 commit comments

Comments
 (0)