Skip to content

Commit 483a4f5

Browse files
authored
REF: restore _concat_managers_axis0 (#50401)
* REF: restore _concat_managers_axis0
* CLN
1 parent 3bc2203 commit 483a4f5

File tree

1 file changed

+68
-40
lines changed

1 file changed

+68
-40
lines changed

pandas/core/internals/concat.py

Lines changed: 68 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,21 @@ def concatenate_managers(
193193
if isinstance(mgrs_indexers[0][0], ArrayManager):
194194
return _concatenate_array_managers(mgrs_indexers, axes, concat_axis, copy)
195195

196+
# Assertions disabled for performance
197+
# for tup in mgrs_indexers:
198+
# # caller is responsible for ensuring this
199+
# indexers = tup[1]
200+
# assert concat_axis not in indexers
201+
202+
if concat_axis == 0:
203+
return _concat_managers_axis0(mgrs_indexers, axes, copy)
204+
196205
mgrs_indexers = _maybe_reindex_columns_na_proxy(axes, mgrs_indexers)
197206

198207
concat_plans = [
199208
_get_mgr_concatenation_plan(mgr, indexers) for mgr, indexers in mgrs_indexers
200209
]
201-
concat_plan = _combine_concat_plans(concat_plans, concat_axis)
210+
concat_plan = _combine_concat_plans(concat_plans)
202211
blocks = []
203212

204213
for placement, join_units in concat_plan:
@@ -229,7 +238,7 @@ def concatenate_managers(
229238

230239
fastpath = blk.values.dtype == values.dtype
231240
else:
232-
values = _concatenate_join_units(join_units, concat_axis, copy=copy)
241+
values = _concatenate_join_units(join_units, copy=copy)
233242
fastpath = False
234243

235244
if fastpath:
@@ -242,6 +251,42 @@ def concatenate_managers(
242251
return BlockManager(tuple(blocks), axes)
243252

244253

254+
def _concat_managers_axis0(
    mgrs_indexers, axes: list[Index], copy: bool
) -> BlockManager:
    """
    concat_managers specialized to concat_axis=0.

    NOTE(review): the original docstring said reindexing had "already been
    done in _maybe_reindex_columns_na_proxy", but this function performs that
    reindexing itself below — the note appears stale; confirm against caller.

    Parameters
    ----------
    mgrs_indexers : list of (BlockManager, dict[int, np.ndarray]) tuples
        Managers to concatenate along axis 0, each with its reindexers.
    axes : list[Index]
        Axes for the resulting BlockManager.
    copy : bool
        Whether the caller requested copies of the underlying block values.

    Returns
    -------
    BlockManager
    """
    # Record, per manager, whether any reindexing is pending BEFORE we do it:
    # a reindexed manager already holds fresh arrays, so no further copy is
    # needed for it below.
    had_reindexers = {
        i: len(mgrs_indexers[i][1]) > 0 for i in range(len(mgrs_indexers))
    }
    mgrs_indexers = _maybe_reindex_columns_na_proxy(axes, mgrs_indexers)

    mgrs = [x[0] for x in mgrs_indexers]

    # Stack every block from every manager, shifting row placements by the
    # cumulative item count of the managers already consumed.
    offset = 0
    blocks = []
    for i, mgr in enumerate(mgrs):
        # If we already reindexed, then we definitely don't need another copy
        made_copy = had_reindexers[i]

        for blk in mgr.blocks:
            if made_copy:
                # reindexing produced new arrays; a shallow copy suffices
                nb = blk.copy(deep=False)
            elif copy:
                # caller asked for an actual copy of the values
                nb = blk.copy()
            else:
                # by slicing instead of copy(deep=False), we get a new array
                # object, see test_concat_copy
                nb = blk.getitem_block(slice(None))
            # shift this block's row locations into the concatenated frame
            nb._mgr_locs = nb._mgr_locs.add(offset)
            blocks.append(nb)

        offset += len(mgr.items)
    return BlockManager(tuple(blocks), axes)
288+
289+
245290
def _maybe_reindex_columns_na_proxy(
246291
axes: list[Index], mgrs_indexers: list[tuple[BlockManager, dict[int, np.ndarray]]]
247292
) -> list[tuple[BlockManager, dict[int, np.ndarray]]]:
@@ -252,25 +297,22 @@ def _maybe_reindex_columns_na_proxy(
252297
Columns added in this reindexing have dtype=np.void, indicating they
253298
should be ignored when choosing a column's final dtype.
254299
"""
255-
new_mgrs_indexers = []
300+
new_mgrs_indexers: list[tuple[BlockManager, dict[int, np.ndarray]]] = []
301+
256302
for mgr, indexers in mgrs_indexers:
257-
# We only reindex for axis=0 (i.e. columns), as this can be done cheaply
258-
if 0 in indexers:
259-
new_mgr = mgr.reindex_indexer(
260-
axes[0],
261-
indexers[0],
262-
axis=0,
303+
# For axis=0 (i.e. columns) we use_na_proxy and only_slice, so this
304+
# is a cheap reindexing.
305+
for i, indexer in indexers.items():
306+
mgr = mgr.reindex_indexer(
307+
axes[i],
308+
indexers[i],
309+
axis=i,
263310
copy=False,
264-
only_slice=True,
311+
only_slice=True, # only relevant for i==0
265312
allow_dups=True,
266-
use_na_proxy=True,
313+
use_na_proxy=True, # only relevant for i==0
267314
)
268-
new_indexers = indexers.copy()
269-
del new_indexers[0]
270-
new_mgrs_indexers.append((new_mgr, new_indexers))
271-
else:
272-
new_mgrs_indexers.append((mgr, indexers))
273-
315+
new_mgrs_indexers.append((mgr, {}))
274316
return new_mgrs_indexers
275317

276318

@@ -288,7 +330,9 @@ def _get_mgr_concatenation_plan(mgr: BlockManager, indexers: dict[int, np.ndarra
288330
plan : list of (BlockPlacement, JoinUnit) tuples
289331
290332
"""
291-
# Calculate post-reindex shape , save for item axis which will be separate
333+
assert len(indexers) == 0
334+
335+
# Calculate post-reindex shape, save for item axis which will be separate
292336
# for each block anyway.
293337
mgr_shape_list = list(mgr.shape)
294338
for ax, indexer in indexers.items():
@@ -523,16 +567,10 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
523567
return values
524568

525569

526-
def _concatenate_join_units(
527-
join_units: list[JoinUnit], concat_axis: AxisInt, copy: bool
528-
) -> ArrayLike:
570+
def _concatenate_join_units(join_units: list[JoinUnit], copy: bool) -> ArrayLike:
529571
"""
530-
Concatenate values from several join units along selected axis.
572+
Concatenate values from several join units along axis=1.
531573
"""
532-
if concat_axis == 0 and len(join_units) > 1:
533-
# Concatenating join units along ax0 is handled in _merge_blocks.
534-
raise AssertionError("Concatenating join units along axis0")
535-
536574
empty_dtype = _get_empty_dtype(join_units)
537575

538576
has_none_blocks = any(unit.block.dtype.kind == "V" for unit in join_units)
@@ -573,7 +611,7 @@ def _concatenate_join_units(
573611
concat_values = ensure_block_shape(concat_values, 2)
574612

575613
else:
576-
concat_values = concat_compat(to_concat, axis=concat_axis)
614+
concat_values = concat_compat(to_concat, axis=1)
577615

578616
return concat_values
579617

@@ -701,28 +739,18 @@ def _trim_join_unit(join_unit: JoinUnit, length: int) -> JoinUnit:
701739
return JoinUnit(block=extra_block, indexers=extra_indexers, shape=extra_shape)
702740

703741

704-
def _combine_concat_plans(plans, concat_axis: AxisInt):
742+
def _combine_concat_plans(plans):
705743
"""
706744
Combine multiple concatenation plans into one.
707745
708746
existing_plan is updated in-place.
747+
748+
We only get here with concat_axis == 1.
709749
"""
710750
if len(plans) == 1:
711751
for p in plans[0]:
712752
yield p[0], [p[1]]
713753

714-
elif concat_axis == 0:
715-
offset = 0
716-
for plan in plans:
717-
last_plc = None
718-
719-
for plc, unit in plan:
720-
yield plc.add(offset), [unit]
721-
last_plc = plc
722-
723-
if last_plc is not None:
724-
offset += last_plc.as_slice.stop
725-
726754
else:
727755
# singleton list so we can modify it as a side-effect within _next_or_none
728756
num_ended = [0]

0 commit comments

Comments (0)