Skip to content

Commit f695840

Browse files
committed
Reuse output buffer in C-impl of Join
1 parent 3de303d commit f695840

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

pytensor/tensor/basic.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2541,7 +2541,7 @@ def perform(self, node, inputs, output_storage):
25412541
)
25422542

25432543
def c_code_cache_version(self):
2544-
return (6,)
2544+
return (7,)
25452545

25462546
def c_code(self, node, name, inputs, outputs, sub):
25472547
axis, *arrays = inputs
@@ -2580,16 +2580,86 @@ def c_code(self, node, name, inputs, outputs, sub):
25802580
code = f"""
25812581
int axis = {axis_def}
25822582
PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
2583-
PyObject* arrays_tuple = PyTuple_New({n});
2583+
int out_is_valid = {out} != NULL;
25842584
25852585
{axis_check}
25862586
2587-
Py_XDECREF({out});
2588-
{copy_arrays_to_tuple}
2589-
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2590-
Py_DECREF(arrays_tuple);
2591-
if(!{out}){{
2592-
{fail}
2587+
if (out_is_valid) {{
2588+
// Check if we can reuse output
2589+
npy_intp join_size = 0;
2590+
npy_intp out_shape[{ndim}];
2591+
npy_intp *shape = PyArray_SHAPE(arrays[0]);
2592+
2593+
for (int i = 0; i < {n}; i++) {{
2594+
if (PyArray_NDIM(arrays[i]) != {ndim}) {{
2595+
PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2596+
{fail}
2597+
}}
2598+
2599+
join_size += PyArray_SHAPE(arrays[i])[axis];
2600+
2601+
if (i > 0){{
2602+
for (int j = 0; j < {ndim}; j++) {{
2603+
if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2604+
PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2605+
{fail}
2606+
}}
2607+
}}
2608+
}}
2609+
}}
2610+
2611+
memcpy(out_shape, shape, {ndim} * sizeof(npy_intp));
2612+
out_shape[axis] = join_size;
2613+
2614+
for (int i = 0; i < {ndim}; i++) {{
2615+
out_is_valid &= (PyArray_SHAPE({out})[i] == out_shape[i]);
2616+
}}
2617+
}}
2618+
2619+
if (!out_is_valid) {{
2620+
// Use PyArray_Concatenate
2621+
Py_XDECREF({out});
2622+
PyObject* arrays_tuple = PyTuple_New({n});
2623+
{copy_arrays_to_tuple}
2624+
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2625+
Py_DECREF(arrays_tuple);
2626+
if(!{out}){{
2627+
{fail}
2628+
}}
2629+
}}
2630+
else {{
2631+
// Copy the data to the pre-allocated output buffer
2632+
2633+
// Create view into output buffer
2634+
PyArrayObject_fields *view;
2635+
2636+
// PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2637+
Py_INCREF(PyArray_DESCR({out}));
2638+
view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2639+
PyArray_DESCR({out}),
2640+
{ndim},
2641+
PyArray_SHAPE(arrays[0]),
2642+
PyArray_STRIDES({out}),
2643+
PyArray_DATA({out}),
2644+
NPY_ARRAY_WRITEABLE,
2645+
NULL);
2646+
if (view == NULL) {{
2647+
{fail}
2648+
}}
2649+
2650+
// Copy data into output buffer
2651+
for (int i = 0; i < {n}; i++) {{
2652+
view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2653+
2654+
if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2655+
Py_DECREF(view);
2656+
{fail}
2657+
}}
2658+
2659+
view->data += (view->dimensions[axis] * view->strides[axis]);
2660+
}}
2661+
2662+
Py_DECREF(view);
25932663
}}
25942664
"""
25952665
return code

tests/tensor/test_basic.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@
117117
ivector,
118118
lscalar,
119119
lvector,
120+
matrices,
120121
matrix,
121122
row,
122123
scalar,
@@ -1762,7 +1763,7 @@ def test_join_matrixV_negative_axis(self):
17621763
got = f(-2)
17631764
assert np.allclose(got, want)
17641765

1765-
with pytest.raises(IndexError):
1766+
with pytest.raises(ValueError):
17661767
f(-3)
17671768

17681769
@pytest.mark.parametrize("py_impl", (False, True))
@@ -1805,7 +1806,7 @@ def test_join_matrixC_negative_axis(self, py_impl):
18051806
got = f()
18061807
assert np.allclose(got, want)
18071808

1808-
with pytest.raises(IndexError):
1809+
with pytest.raises(ValueError):
18091810
join(-3, a, b)
18101811

18111812
with impl_ctxt:
@@ -2152,6 +2153,32 @@ def test_split_view(self, linker):
21522153
assert np.allclose(r, expected)
21532154
assert r.base is x_test
21542155

2156+
@pytest.mark.parametrize("gc", (True, False), ids=lambda x: f"gc={x}")
@pytest.mark.parametrize("memory_layout", ["C-contiguous", "F-contiguous", "Mixed"])
@pytest.mark.parametrize("axis", (0, 1), ids=lambda x: f"axis={x}")
@pytest.mark.parametrize("ndim", (1, 2), ids=["vector", "matrix"])
@config.change_flags(cmodule__warn_no_version=False)
def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark):
    """Benchmark ``join`` over six inputs and sanity-check the output shape.

    Parameters
    ----------
    ndim
        1 joins six vectors, 2 joins six matrices.
    axis
        Axis along which the inputs are joined.
    memory_layout
        "C-contiguous" uses the arrays as allocated, "F-contiguous"
        transposes every input, "Mixed" transposes every other input.
    gc
        Value assigned to ``fn.vm.allow_gc`` on the compiled function.
    benchmark
        Callable fixture invoked as ``benchmark(fn, *test_values)``.
    """
    # For vectors only the (C-contiguous, axis=0) combination is run: the
    # remaining axis/layout combinations are skipped as redundant.
    if ndim == 1 and not (memory_layout == "C-contiguous" and axis == 0):
        pytest.skip("Redundant parametrization")
    n = 64
    inputs = vectors("abcdef") if ndim == 1 else matrices("abcdef")
    out = join(axis, *inputs)
    fn = pytensor.function(inputs, Out(out, borrow=True), trust_input=True)
    fn.vm.allow_gc = gc
    # Inputs have shape (n,) or (n, n); being square, the matrices keep their
    # shape when transposed, so only the memory layout changes below.
    test_values = [np.zeros((n, n)[:ndim], dtype=inputs[0].dtype) for _ in inputs]
    if memory_layout == "C-contiguous":
        pass
    elif memory_layout == "F-contiguous":
        test_values = [t.T for t in test_values]
    elif memory_layout == "Mixed":
        test_values = [t if i % 2 else t.T for i, t in enumerate(test_values)]
    else:
        raise ValueError(f"Unknown memory_layout: {memory_layout!r}")

    # Build the expected shape first. Written inline, the conditional
    # expression binds looser than ``==``, so for axis=1 the assert would
    # evaluate a non-empty tuple (always truthy) instead of comparing shapes.
    expected_shape = (n * 6, n)[:ndim] if axis == 0 else (n, n * 6)
    assert fn(*test_values).shape == expected_shape
    benchmark(fn, *test_values)
2181+
21552182

21562183
def test_TensorFromScalar():
21572184
s = ps.constant(56)

0 commit comments

Comments
 (0)