@@ -2541,7 +2541,7 @@ def perform(self, node, inputs, output_storage):
2541
2541
)
2542
2542
2543
2543
def c_code_cache_version(self):
    """Return the version tuple that keys the cached compiled C code.

    The value is bumped to ``(7,)`` together with the change to the C
    template emitted by ``c_code`` (the pre-allocated output buffer is now
    reused when its shape matches), so that binaries compiled from the
    previous template are not picked up from the cache.
    """
    return (7,)
2545
2545
2546
2546
def c_code (self , node , name , inputs , outputs , sub ):
2547
2547
axis , * arrays = inputs
@@ -2580,16 +2580,86 @@ def c_code(self, node, name, inputs, outputs, sub):
2580
2580
code = f"""
2581
2581
int axis = { axis_def }
2582
2582
PyArrayObject* arrays[{ n } ] = {{{ ',' .join (arrays )} }};
2583
- PyObject* arrays_tuple = PyTuple_New( { n } ) ;
2583
+ int out_is_valid = { out } != NULL ;
2584
2584
2585
2585
{ axis_check }
2586
2586
2587
- Py_XDECREF({ out } );
2588
- { copy_arrays_to_tuple }
2589
- { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2590
- Py_DECREF(arrays_tuple);
2591
- if(!{ out } ){{
2592
- { fail }
2587
+ if (out_is_valid) {{
2588
+ // Check if we can reuse output
2589
+ npy_intp join_size = 0;
2590
+ npy_intp out_shape[{ ndim } ];
2591
+ npy_intp *shape = PyArray_SHAPE(arrays[0]);
2592
+
2593
+ for (int i = 0; i < { n } ; i++) {{
2594
+ if (PyArray_NDIM(arrays[i]) != { ndim } ) {{
2595
+ PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2596
+ { fail }
2597
+ }}
2598
+
2599
+ join_size += PyArray_SHAPE(arrays[i])[axis];
2600
+
2601
+ if (i > 0){{
2602
+ for (int j = 0; j < { ndim } ; j++) {{
2603
+ if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2604
+ PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2605
+ { fail }
2606
+ }}
2607
+ }}
2608
+ }}
2609
+ }}
2610
+
2611
+ memcpy(out_shape, shape, { ndim } * sizeof(npy_intp));
2612
+ out_shape[axis] = join_size;
2613
+
2614
+ for (int i = 0; i < { ndim } ; i++) {{
2615
+ out_is_valid &= (PyArray_SHAPE({ out } )[i] == out_shape[i]);
2616
+ }}
2617
+ }}
2618
+
2619
+ if (!out_is_valid) {{
2620
+ // Use PyArray_Concatenate
2621
+ Py_XDECREF({ out } );
2622
+ PyObject* arrays_tuple = PyTuple_New({ n } );
2623
+ { copy_arrays_to_tuple }
2624
+ { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2625
+ Py_DECREF(arrays_tuple);
2626
+ if(!{ out } ){{
2627
+ { fail }
2628
+ }}
2629
+ }}
2630
+ else {{
2631
+ // Copy the data to the pre-allocated output buffer
2632
+
2633
+ // Create view into output buffer
2634
+ PyArrayObject_fields *view;
2635
+
2636
+ // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2637
+ Py_INCREF(PyArray_DESCR({ out } ));
2638
+ view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2639
+ PyArray_DESCR({ out } ),
2640
+ { ndim } ,
2641
+ PyArray_SHAPE(arrays[0]),
2642
+ PyArray_STRIDES({ out } ),
2643
+ PyArray_DATA({ out } ),
2644
+ NPY_ARRAY_WRITEABLE,
2645
+ NULL);
2646
+ if (view == NULL) {{
2647
+ { fail }
2648
+ }}
2649
+
2650
+ // Copy data into output buffer
2651
+ for (int i = 0; i < { n } ; i++) {{
2652
+ view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2653
+
2654
+ if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2655
+ Py_DECREF(view);
2656
+ { fail }
2657
+ }}
2658
+
2659
+ view->data += (view->dimensions[axis] * view->strides[axis]);
2660
+ }}
2661
+
2662
+ Py_DECREF(view);
2593
2663
}}
2594
2664
"""
2595
2665
return code
0 commit comments