Skip to content

Commit f03ded4

Browse files
committed
Stack eager optimization for single tensor
1 parent c0860f8 commit f03ded4

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

pytensor/tensor/basic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2943,6 +2943,8 @@ def stack(tensors: Sequence["TensorLike"], axis: int = 0):
29432943
):
29442944
# In case there is direct scalar
29452945
tensors = list(map(as_tensor_variable, tensors))
2946+
if len(tensors) == 1:
2947+
return atleast_1d(tensors[0])
29462948
dtype = ps.upcast(*[i.dtype for i in tensors])
29472949
return MakeVector(dtype)(*tensors)
29482950
return join(axis, *[shape_padaxis(t, axis) for t in tensors])

0 commit comments

Comments
 (0)