1
1
"""Test cases for GroupBy.plot"""
2
2
3
+ import matplotlib .pyplot as plt
3
4
import numpy as np
4
5
import pytest
5
6
@@ -156,10 +157,6 @@ def test_groupby_hist_series_with_legend_raises(self):
156
157
def test_groupby_scatter_colors_differ (self ):
157
158
# GH 59846 - Test that scatter plots use different colors for different groups
158
159
# similar to how line plots do
159
- from matplotlib .collections import PathCollection
160
- import matplotlib .pyplot as plt
161
-
162
- # Create test data with distinct groups
163
160
df = DataFrame (
164
161
{
165
162
"x" : [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ],
@@ -168,36 +165,20 @@ def test_groupby_scatter_colors_differ(self):
168
165
}
169
166
)
170
167
171
- # Set up a figure with both line and scatter plots
172
168
fig , (ax1 , ax2 ) = plt .subplots (1 , 2 )
173
-
174
- # Plot line chart (known to use different colors for different groups)
175
169
df .groupby ("group" ).plot (x = "x" , y = "y" , ax = ax1 , kind = "line" )
176
-
177
- # Plot scatter chart (should also use different colors for different groups)
178
170
df .groupby ("group" ).plot (x = "x" , y = "y" , ax = ax2 , kind = "scatter" )
179
171
180
- # Get the colors used in the line plot and scatter plot
181
172
line_colors = [line .get_color () for line in ax1 .get_lines ()]
173
+ scatter_colors = [
174
+ tuple (tuple (fc ) for fc in scatter .get_facecolor ())
175
+ for scatter in ax2 .collections
176
+ ]
182
177
183
- # Get scatter colors
184
- scatter_colors = []
185
- for collection in ax2 .collections :
186
- if isinstance (collection , PathCollection ): # This is a scatter plot
187
- # Get the face colors (might be array of RGBA values)
188
- face_colors = collection .get_facecolor ()
189
- # If multiple points with same color, we get the first one
190
- if face_colors .ndim > 1 :
191
- scatter_colors .append (tuple (face_colors [0 ]))
192
- else :
193
- scatter_colors .append (tuple (face_colors ))
194
-
195
- # Assert that we have the right number of colors (one per group)
196
178
assert len (line_colors ) == 3
197
179
assert len (scatter_colors ) == 3
198
180
199
- # Assert that the colors are all different
181
+ assert len ( set ( line_colors )) == 3
200
182
assert len (set (scatter_colors )) == 3
201
- assert len (line_colors ) == 3
202
183
203
184
plt .close (fig )
0 commit comments