Skip to content

Commit a7ed5d9

Browse files
authored
Merge pull request #4926 from FBruzzesi/patch/deepcopy-figure-fix
patch: deepcopy figure fix
2 parents 3e47cf8 + 444eca4 commit a7ed5d9

File tree

2 files changed

+67
-29
lines changed

2 files changed

+67
-29
lines changed

packages/python/plotly/_plotly_utils/basevalidators.py

+29-29
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,17 @@ def type_str(v):
223223
return "'{module}.{name}'".format(module=v.__module__, name=v.__name__)
224224

225225

226+
def is_typed_array_spec(v):
227+
"""
228+
Return whether a value is considered to be a typed array spec for plotly.js
229+
"""
230+
return isinstance(v, dict) and "bdata" in v and "dtype" in v
231+
232+
233+
def is_none_or_typed_array_spec(v):
234+
return v is None or is_typed_array_spec(v)
235+
236+
226237
# Validators
227238
# ----------
228239
class BaseValidator(object):
@@ -393,8 +404,7 @@ def description(self):
393404

394405
def validate_coerce(self, v):
395406

396-
if v is None:
397-
# Pass None through
407+
if is_none_or_typed_array_spec(v):
398408
pass
399409
elif is_homogeneous_array(v):
400410
v = copy_to_readonly_numpy_array(v)
@@ -591,8 +601,7 @@ def in_values(self, e):
591601
return False
592602

593603
def validate_coerce(self, v):
594-
if v is None:
595-
# Pass None through
604+
if is_none_or_typed_array_spec(v):
596605
pass
597606
elif self.array_ok and is_array(v):
598607
v_replaced = [self.perform_replacemenet(v_el) for v_el in v]
@@ -636,8 +645,7 @@ def description(self):
636645
)
637646

638647
def validate_coerce(self, v):
639-
if v is None:
640-
# Pass None through
648+
if is_none_or_typed_array_spec(v):
641649
pass
642650
elif not isinstance(v, bool):
643651
self.raise_invalid_val(v)
@@ -661,8 +669,7 @@ def description(self):
661669
)
662670

663671
def validate_coerce(self, v):
664-
if v is None:
665-
# Pass None through
672+
if is_none_or_typed_array_spec(v):
666673
pass
667674
elif isinstance(v, str):
668675
pass
@@ -752,8 +759,7 @@ def description(self):
752759
return desc
753760

754761
def validate_coerce(self, v):
755-
if v is None:
756-
# Pass None through
762+
if is_none_or_typed_array_spec(v):
757763
pass
758764
elif self.array_ok and is_homogeneous_array(v):
759765
np = get_module("numpy")
@@ -899,8 +905,7 @@ def description(self):
899905
return desc
900906

901907
def validate_coerce(self, v):
902-
if v is None:
903-
# Pass None through
908+
if is_none_or_typed_array_spec(v):
904909
pass
905910
elif v in self.extras:
906911
return v
@@ -1063,8 +1068,7 @@ def description(self):
10631068
return desc
10641069

10651070
def validate_coerce(self, v):
1066-
if v is None:
1067-
# Pass None through
1071+
if is_none_or_typed_array_spec(v):
10681072
pass
10691073
elif self.array_ok and is_array(v):
10701074

@@ -1365,8 +1369,7 @@ def description(self):
13651369
return valid_color_description
13661370

13671371
def validate_coerce(self, v, should_raise=True):
1368-
if v is None:
1369-
# Pass None through
1372+
if is_none_or_typed_array_spec(v):
13701373
pass
13711374
elif self.array_ok and is_homogeneous_array(v):
13721375
v = copy_to_readonly_numpy_array(v)
@@ -1510,8 +1513,7 @@ def description(self):
15101513

15111514
def validate_coerce(self, v):
15121515

1513-
if v is None:
1514-
# Pass None through
1516+
if is_none_or_typed_array_spec(v):
15151517
pass
15161518
elif is_array(v):
15171519
validated_v = [
@@ -1708,16 +1710,17 @@ def description(self):
17081710
(e.g. 270 is converted to -90).
17091711
""".format(
17101712
plotly_name=self.plotly_name,
1711-
array_ok=", or a list, numpy array or other iterable thereof"
1712-
if self.array_ok
1713-
else "",
1713+
array_ok=(
1714+
", or a list, numpy array or other iterable thereof"
1715+
if self.array_ok
1716+
else ""
1717+
),
17141718
)
17151719

17161720
return desc
17171721

17181722
def validate_coerce(self, v):
1719-
if v is None:
1720-
# Pass None through
1723+
if is_none_or_typed_array_spec(v):
17211724
pass
17221725
elif self.array_ok and is_homogeneous_array(v):
17231726
try:
@@ -1902,8 +1905,7 @@ def vc_scalar(self, v):
19021905
return None
19031906

19041907
def validate_coerce(self, v):
1905-
if v is None:
1906-
# Pass None through
1908+
if is_none_or_typed_array_spec(v):
19071909
pass
19081910
elif self.array_ok and is_array(v):
19091911

@@ -1961,8 +1963,7 @@ def description(self):
19611963
return desc
19621964

19631965
def validate_coerce(self, v):
1964-
if v is None:
1965-
# Pass None through
1966+
if is_none_or_typed_array_spec(v):
19661967
pass
19671968
elif self.array_ok and is_homogeneous_array(v):
19681969
v = copy_to_readonly_numpy_array(v, kind="O")
@@ -2170,8 +2171,7 @@ def validate_element_with_indexed_name(self, val, validator, inds):
21702171
return val
21712172

21722173
def validate_coerce(self, v):
2173-
if v is None:
2174-
# Pass None through
2174+
if is_none_or_typed_array_spec(v):
21752175
return None
21762176
elif not is_array(v):
21772177
self.raise_invalid_val(v)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import copy
2+
import pytest
3+
import plotly.express as px
4+
5+
"""
6+
This test is in the validators folder since copy.deepcopy ends up calling
7+
BaseFigure(*args) which hits `validate_coerce`.
8+
9+
When inputs are dataframes and arrays, then the copied figure is called with
10+
base64 encoded arrays.
11+
"""
12+
13+
14+
@pytest.mark.parametrize("return_type", ["pandas", "polars", "pyarrow"])
15+
@pytest.mark.filterwarnings(
16+
r"ignore:\*scattermapbox\* is deprecated! Use \*scattermap\* instead"
17+
)
18+
def test_deepcopy_dataframe(return_type):
19+
gapminder = px.data.gapminder(return_type=return_type)
20+
fig = px.line(gapminder, x="year", y="gdpPercap", color="country")
21+
fig_copied = copy.deepcopy(fig)
22+
23+
assert fig_copied.to_dict() == fig.to_dict()
24+
25+
26+
@pytest.mark.filterwarnings(
27+
r"ignore:\*scattermapbox\* is deprecated! Use \*scattermap\* instead"
28+
)
29+
def test_deepcopy_array():
30+
gapminder = px.data.gapminder()
31+
x = gapminder["year"].to_numpy()
32+
y = gapminder["gdpPercap"].to_numpy()
33+
color = gapminder["country"].to_numpy()
34+
35+
fig = px.line(x=x, y=y, color=color)
36+
fig_copied = copy.deepcopy(fig)
37+
38+
assert fig_copied.to_dict() == fig.to_dict()

0 commit comments

Comments
 (0)