Skip to content

Commit 6fc1362

Browse files
wip
1 parent 234ec0d commit 6fc1362

File tree

6 files changed

+227
-22
lines changed

6 files changed

+227
-22
lines changed

packages/python/plotly/plotly/express/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
box,
3535
strip,
3636
histogram,
37+
ecdf,
38+
kde,
3739
scatter_matrix,
3840
parallel_coordinates,
3941
parallel_categories,
@@ -88,6 +90,8 @@
8890
"box",
8991
"strip",
9092
"histogram",
93+
"ecdf",
94+
"kde",
9195
"choropleth",
9296
"choropleth_mapbox",
9397
"pie",

packages/python/plotly/plotly/express/_chart_types.py

+115
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,121 @@ def histogram(
471471
)
472472

473473

474+
def ecdf(
475+
data_frame=None,
476+
x=None,
477+
y=None,
478+
color=None,
479+
line_dash=None,
480+
facet_row=None,
481+
facet_col=None,
482+
facet_col_wrap=0,
483+
facet_row_spacing=None,
484+
facet_col_spacing=None,
485+
hover_name=None,
486+
hover_data=None,
487+
animation_frame=None,
488+
animation_group=None,
489+
category_orders=None,
490+
labels=None,
491+
color_discrete_sequence=None,
492+
color_discrete_map=None,
493+
line_dash_sequence=None,
494+
line_dash_map=None,
495+
marginal=None,
496+
opacity=None,
497+
orientation=None,
498+
line_shape=None,
499+
norm=None, # TODO use this
500+
complementary=None, # TODO use this
501+
log_x=False,
502+
log_y=False,
503+
range_x=None,
504+
range_y=None,
505+
title=None,
506+
template=None,
507+
width=None,
508+
height=None,
509+
):
510+
"""
511+
In an Empirical Cumulative Distribution Function (ECDF) plot, rows of `data_frame`
512+
are sorted by the value `x` (or `y` if `orientation` is `'h'`) and their cumulative
513+
count (or the cumulative sum of `y` if supplied, or of `x` if `orientation` is `'h'`) is drawn
514+
as a line.
515+
"""
516+
return make_figure(args=locals(), constructor=go.Scatter)
517+
518+
519+
ecdf.__doc__ = make_docstring(
520+
ecdf,
521+
append_dict=dict(
522+
x=[
523+
"If `orientation` is `'h'`, the cumulative sum of this argument is plotted rather than the cumulative count."
524+
]
525+
+ _wide_mode_xy_append,
526+
y=[
527+
"If `orientation` is `'v'`, the cumulative sum of this argument is plotted rather than the cumulative count."
528+
]
529+
+ _wide_mode_xy_append,
530+
),
531+
)
532+
533+
534+
def kde(
535+
data_frame=None,
536+
x=None,
537+
y=None,
538+
color=None,
539+
line_dash=None,
540+
facet_row=None,
541+
facet_col=None,
542+
facet_col_wrap=0,
543+
facet_row_spacing=None,
544+
facet_col_spacing=None,
545+
hover_name=None,
546+
hover_data=None,
547+
animation_frame=None,
548+
animation_group=None,
549+
category_orders=None,
550+
labels=None,
551+
color_discrete_sequence=None,
552+
color_discrete_map=None,
553+
line_dash_sequence=None,
554+
line_dash_map=None,
555+
marginal=None,
556+
opacity=None,
557+
orientation=None,
558+
norm=None, # TODO use this
559+
kernel=None, # TODO use this
560+
bw_method=None, # TODO use this
561+
bw_adjust=None, # TODO use this
562+
log_x=False,
563+
log_y=False,
564+
range_x=None,
565+
range_y=None,
566+
title=None,
567+
template=None,
568+
width=None,
569+
height=None,
570+
):
571+
"""
572+
In a Kernel Density Estimation (KDE) plot, rows of `data_frame`
573+
are used as inputs to a KDE smoothing function which is rendered as a line.
574+
"""
575+
return make_figure(args=locals(), constructor=go.Scatter)
576+
577+
578+
kde.__doc__ = make_docstring(
579+
kde,
580+
append_dict=dict(
581+
x=["If `orientation` is `'h'`, this argument is used as KDE weights."]
582+
+ _wide_mode_xy_append,
583+
y=["If `orientation` is `'v'`, this argument is used as KDE weights."]
584+
+ _wide_mode_xy_append,
585+
),
586+
)
587+
588+
474589
def violin(
475590
data_frame=None,
476591
x=None,

packages/python/plotly/plotly/express/_core.py

+26-10
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,9 @@ def build_dataframe(args, constructor):
12751275
wide_cross_name = None # will likely be "index" in wide_mode
12761276
value_name = None # will likely be "value" in wide_mode
12771277
hist2d_types = [go.Histogram2d, go.Histogram2dContour]
1278+
hist1d_orientation = (
1279+
constructor == go.Histogram or "complementary" in args or "kernel" in args
1280+
)
12781281
if constructor in cartesians:
12791282
if wide_x and wide_y:
12801283
raise ValueError(
@@ -1309,7 +1312,7 @@ def build_dataframe(args, constructor):
13091312
df_provided and var_name in df_input
13101313
):
13111314
var_name = "variable"
1312-
if constructor == go.Histogram:
1315+
if hist1d_orientation:
13131316
wide_orientation = "v" if wide_x else "h"
13141317
else:
13151318
wide_orientation = "v" if wide_y else "h"
@@ -1323,7 +1326,10 @@ def build_dataframe(args, constructor):
13231326
var_name = _escape_col_name(df_input, var_name, [])
13241327

13251328
missing_bar_dim = None
1326-
if constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types:
1329+
if (
1330+
constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types
1331+
and not hist1d_orientation
1332+
):
13271333
if not wide_mode and (no_x != no_y):
13281334
for ax in ["x", "y"]:
13291335
if args.get(ax, None) is None:
@@ -1420,14 +1426,22 @@ def build_dataframe(args, constructor):
14201426
df_output[var_name] = df_output[var_name].astype(str)
14211427
orient_v = wide_orientation == "v"
14221428

1423-
if constructor in [go.Scatter, go.Funnel] + hist2d_types:
1429+
if hist1d_orientation:
1430+
args["x" if orient_v else "y"] = value_name
1431+
if wide_cross_name is None and constructor == go.Scatter:
1432+
args["y" if orient_v else "x"] = count_name
1433+
df_output[count_name] = 1
1434+
else:
1435+
args["y" if orient_v else "x"] = wide_cross_name
1436+
args["color"] = args["color"] or var_name
1437+
elif constructor in [go.Scatter, go.Funnel] + hist2d_types:
14241438
args["x" if orient_v else "y"] = wide_cross_name
14251439
args["y" if orient_v else "x"] = value_name
14261440
if constructor != go.Histogram2d:
14271441
args["color"] = args["color"] or var_name
14281442
if "line_group" in args:
14291443
args["line_group"] = args["line_group"] or var_name
1430-
if constructor == go.Bar:
1444+
elif constructor == go.Bar:
14311445
if _is_continuous(df_output, value_name):
14321446
args["x" if orient_v else "y"] = wide_cross_name
14331447
args["y" if orient_v else "x"] = value_name
@@ -1437,13 +1451,9 @@ def build_dataframe(args, constructor):
14371451
args["y" if orient_v else "x"] = count_name
14381452
df_output[count_name] = 1
14391453
args["color"] = args["color"] or var_name
1440-
if constructor in [go.Violin, go.Box]:
1454+
elif constructor in [go.Violin, go.Box]:
14411455
args["x" if orient_v else "y"] = wide_cross_name or var_name
14421456
args["y" if orient_v else "x"] = value_name
1443-
if constructor == go.Histogram:
1444-
args["x" if orient_v else "y"] = value_name
1445-
args["y" if orient_v else "x"] = wide_cross_name
1446-
args["color"] = args["color"] or var_name
14471457
if no_color:
14481458
args["color"] = None
14491459
args["data_frame"] = df_output
@@ -1925,7 +1935,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19251935
if (
19261936
trace_spec != trace_specs[0]
19271937
and trace_spec.constructor in [go.Violin, go.Box, go.Histogram]
1928-
and m.variable == "symbol"
1938+
and m.variable in ["symbol", "dash"]
19291939
):
19301940
pass
19311941
elif (
@@ -1986,6 +1996,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19861996
):
19871997
trace.update(marker=dict(color=trace.line.color))
19881998

1999+
if "complementary" in args: # ECDF
2000+
base = args["x"] if args["orientation"] == "v" else args["y"]
2001+
var = args["x"] if args["orientation"] == "h" else args["y"]
2002+
group = group.sort_values(by=base)
2003+
group[var] = group[var].cumsum()
2004+
19892005
patch, fit_results = make_trace_kwargs(
19902006
args, trace_spec, group, mapping_labels.copy(), sizeref
19912007
)

packages/python/plotly/plotly/express/_doc.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -541,10 +541,17 @@
541541
"Sets the number of rendered sectors from any given `level`. Set `maxdepth` to -1 to render all the"
542542
"levels in the hierarchy.",
543543
],
544+
norm=["TODO"],
545+
complementary=["TODO"],
546+
kernel=["TODO"],
547+
bw_method=["TODO"],
548+
bw_adjust=["TODO"],
544549
)
545550

546551

547-
def make_docstring(fn, override_dict={}, append_dict={}):
552+
def make_docstring(fn, override_dict=None, append_dict=None):
553+
override_dict = {} if override_dict is None else override_dict
554+
append_dict = {} if append_dict is None else append_dict
548555
tw = TextWrapper(width=75, initial_indent=" ", subsequent_indent=" ")
549556
result = (fn.__doc__ or "") + "\nParameters\n----------\n"
550557
for param in getfullargspec(fn)[0]:

packages/python/plotly/plotly/tests/test_core/test_px/test_facets.py

+48-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import plotly
21
import pandas as pd
32
import plotly.express as px
43
from pytest import approx
@@ -47,6 +46,48 @@ def test_facets():
4746
assert fig.layout.yaxis4.domain[0] - fig.layout.yaxis.domain[1] == approx(0.08)
4847

4948

49+
def test_facets_with_marginals():
50+
df = px.data.tips()
51+
52+
fig = px.histogram(df, x="total_bill", facet_col="sex", marginal="rug")
53+
assert len(fig.data) == 4
54+
# fig = px.histogram(df, x="total_bill", facet_row="sex", marginal="rug")
55+
# assert len(fig.data) == 2 buggy
56+
57+
fig = px.scatter(df, x="total_bill", y="tip", facet_col="sex", marginal_x="rug")
58+
assert len(fig.data) == 4
59+
fig = px.scatter(
60+
df, x="total_bill", y="tip", facet_col="day", facet_col_wrap=2, marginal_x="rug"
61+
)
62+
assert len(fig.data) == 8 # ignore the wrap when marginal is used
63+
fig = px.scatter(df, x="total_bill", y="tip", facet_col="sex", marginal_y="rug")
64+
assert len(fig.data) == 2 # ignore the marginal in the facet direction
65+
66+
fig = px.scatter(df, x="total_bill", y="tip", facet_row="sex", marginal_x="rug")
67+
assert len(fig.data) == 2 # ignore the marginal in the facet direction
68+
fig = px.scatter(df, x="total_bill", y="tip", facet_row="sex", marginal_y="rug")
69+
assert len(fig.data) == 4
70+
71+
fig = px.scatter(
72+
df, x="total_bill", y="tip", facet_row="sex", marginal_y="rug", marginal_x="rug"
73+
)
74+
assert len(fig.data) == 4 # ignore the marginal in the facet direction
75+
fig = px.scatter(
76+
df, x="total_bill", y="tip", facet_col="sex", marginal_y="rug", marginal_x="rug"
77+
)
78+
assert len(fig.data) == 4 # ignore the marginal in the facet direction
79+
fig = px.scatter(
80+
df,
81+
x="total_bill",
82+
y="tip",
83+
facet_row="sex",
84+
facet_col="sex",
85+
marginal_y="rug",
86+
marginal_x="rug",
87+
)
88+
assert len(fig.data) == 2 # ignore all marginals
89+
90+
5091
@pytest.fixture
5192
def bad_facet_spacing_df():
5293
NROWS = 101
@@ -65,25 +106,21 @@ def bad_facet_spacing_df():
65106
def test_bad_facet_spacing_eror(bad_facet_spacing_df):
66107
df = bad_facet_spacing_df
67108
with pytest.raises(
68-
ValueError, match="Use the facet_row_spacing argument to adjust this spacing\."
109+
ValueError, match="Use the facet_row_spacing argument to adjust this spacing."
69110
):
70-
fig = px.scatter(
71-
df, x="x", y="y", facet_row="category", facet_row_spacing=0.01001
72-
)
111+
px.scatter(df, x="x", y="y", facet_row="category", facet_row_spacing=0.01001)
73112
with pytest.raises(
74-
ValueError, match="Use the facet_col_spacing argument to adjust this spacing\."
113+
ValueError, match="Use the facet_col_spacing argument to adjust this spacing."
75114
):
76-
fig = px.scatter(
77-
df, x="x", y="y", facet_col="category", facet_col_spacing=0.01001
78-
)
115+
px.scatter(df, x="x", y="y", facet_col="category", facet_col_spacing=0.01001)
79116
# Check error is not raised when the spacing is OK
80117
try:
81-
fig = px.scatter(df, x="x", y="y", facet_row="category", facet_row_spacing=0.01)
118+
px.scatter(df, x="x", y="y", facet_row="category", facet_row_spacing=0.01)
82119
except ValueError:
83120
# Error shouldn't be raised, so fail if it is
84121
assert False
85122
try:
86-
fig = px.scatter(df, x="x", y="y", facet_col="category", facet_col_spacing=0.01)
123+
px.scatter(df, x="x", y="y", facet_col="category", facet_col_spacing=0.01)
87124
except ValueError:
88125
# Error shouldn't be raised, so fail if it is
89126
assert False
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import plotly.express as px
2+
import pytest
3+
4+
5+
@pytest.mark.parametrize("px_fn", [px.scatter, px.density_heatmap, px.density_contour])
6+
@pytest.mark.parametrize("marginal_x", [None, "histogram", "box", "violin"])
7+
@pytest.mark.parametrize("marginal_y", [None, "rug"])
8+
def test_xy_marginals(px_fn, marginal_x, marginal_y):
9+
df = px.data.tips()
10+
11+
fig = px_fn(
12+
df, x="total_bill", y="tip", marginal_x=marginal_x, marginal_y=marginal_y
13+
)
14+
assert len(fig.data) == 1 + (marginal_x is not None) + (marginal_y is not None)
15+
16+
17+
@pytest.mark.parametrize("px_fn", [px.histogram, px.ecdf, px.kde])
18+
@pytest.mark.parametrize("marginal", [None, "rug", "histogram", "box", "violin"])
19+
@pytest.mark.parametrize("orientation", ["h", "v"])
20+
def test_single_marginals(px_fn, marginal, orientation):
21+
df = px.data.tips()
22+
23+
fig = px_fn(
24+
df, x="total_bill", y="total_bill", marginal=marginal, orientation=orientation
25+
)
26+
assert len(fig.data) == 1 + (marginal is not None)

0 commit comments

Comments
 (0)