Skip to content

Commit 0b94be0

Browse files
committed
Reduce overhead of Scalar python implementation
1 parent 0b07727 commit 0b94be0

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

pytensor/scalar/basic.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from pytensor.utils import (
3737
apply_across_args,
3838
difference,
39-
from_return_values,
4039
to_return_values,
4140
)
4241

@@ -1081,6 +1080,16 @@ def real_out(type):
10811080
return (type,)
10821081

10831082

1083+
def _cast_to_promised_scalar_dtype(x, dtype):
1084+
try:
1085+
return x.astype(dtype)
1086+
except AttributeError:
1087+
if dtype == "bool":
1088+
return np.bool_(x)
1089+
else:
1090+
return getattr(np, dtype)(x)
1091+
1092+
10841093
class ScalarOp(COp):
10851094
nin = -1
10861095
nout = 1
@@ -1134,28 +1143,18 @@ def output_types(self, types):
11341143
else:
11351144
raise NotImplementedError(f"Cannot calculate the output types for {self}")
11361145

1137-
@staticmethod
1138-
def _cast_scalar(x, dtype):
1139-
if hasattr(x, "astype"):
1140-
return x.astype(dtype)
1141-
elif dtype == "bool":
1142-
return np.bool_(x)
1143-
else:
1144-
return getattr(np, dtype)(x)
1145-
11461146
def perform(self, node, inputs, output_storage):
11471147
if self.nout == 1:
1148-
dtype = node.outputs[0].dtype
1149-
output_storage[0][0] = self._cast_scalar(self.impl(*inputs), dtype)
1148+
output_storage[0][0] = _cast_to_promised_scalar_dtype(
1149+
self.impl(*inputs),
1150+
node.outputs[0].dtype,
1151+
)
11501152
else:
1151-
variables = from_return_values(self.impl(*inputs))
1152-
assert len(variables) == len(output_storage)
11531153
# strict=False because we are in a hot loop
11541154
for out, storage, variable in zip(
1155-
node.outputs, output_storage, variables, strict=False
1155+
node.outputs, output_storage, self.impl(*inputs), strict=False
11561156
):
1157-
dtype = out.dtype
1158-
storage[0] = self._cast_scalar(variable, dtype)
1157+
storage[0] = _cast_to_promised_scalar_dtype(variable, out.dtype)
11591158

11601159
def impl(self, *inputs):
11611160
raise MethodNotDefined("impl", type(self), self.__class__.__name__)

0 commit comments

Comments
 (0)