|
36 | 36 | from pytensor.utils import (
|
37 | 37 | apply_across_args,
|
38 | 38 | difference,
|
39 |
| - from_return_values, |
40 | 39 | to_return_values,
|
41 | 40 | )
|
42 | 41 |
|
@@ -1081,6 +1080,16 @@ def real_out(type):
|
1081 | 1080 | return (type,)
|
1082 | 1081 |
|
1083 | 1082 |
|
| 1083 | +def _cast_to_promised_scalar_dtype(x, dtype): |
| 1084 | + try: |
| 1085 | + return x.astype(dtype) |
| 1086 | + except AttributeError: |
| 1087 | + if dtype == "bool": |
| 1088 | + return np.bool_(x) |
| 1089 | + else: |
| 1090 | + return getattr(np, dtype)(x) |
| 1091 | + |
| 1092 | + |
1084 | 1093 | class ScalarOp(COp):
|
1085 | 1094 | nin = -1
|
1086 | 1095 | nout = 1
|
@@ -1134,28 +1143,18 @@ def output_types(self, types):
|
1134 | 1143 | else:
|
1135 | 1144 | raise NotImplementedError(f"Cannot calculate the output types for {self}")
|
1136 | 1145 |
|
1137 |
| - @staticmethod |
1138 |
| - def _cast_scalar(x, dtype): |
1139 |
| - if hasattr(x, "astype"): |
1140 |
| - return x.astype(dtype) |
1141 |
| - elif dtype == "bool": |
1142 |
| - return np.bool_(x) |
1143 |
| - else: |
1144 |
| - return getattr(np, dtype)(x) |
1145 |
| - |
1146 | 1146 | def perform(self, node, inputs, output_storage):
|
1147 | 1147 | if self.nout == 1:
|
1148 |
| - dtype = node.outputs[0].dtype |
1149 |
| - output_storage[0][0] = self._cast_scalar(self.impl(*inputs), dtype) |
| 1148 | + output_storage[0][0] = _cast_to_promised_scalar_dtype( |
| 1149 | + self.impl(*inputs), |
| 1150 | + node.outputs[0].dtype, |
| 1151 | + ) |
1150 | 1152 | else:
|
1151 |
| - variables = from_return_values(self.impl(*inputs)) |
1152 |
| - assert len(variables) == len(output_storage) |
1153 | 1153 | # strict=False because we are in a hot loop
|
1154 | 1154 | for out, storage, variable in zip(
|
1155 |
| - node.outputs, output_storage, variables, strict=False |
| 1155 | + node.outputs, output_storage, self.impl(*inputs), strict=False |
1156 | 1156 | ):
|
1157 |
| - dtype = out.dtype |
1158 |
| - storage[0] = self._cast_scalar(variable, dtype) |
| 1157 | + storage[0] = _cast_to_promised_scalar_dtype(variable, out.dtype) |
1159 | 1158 |
|
1160 | 1159 | def impl(self, *inputs):
|
1161 | 1160 | raise MethodNotDefined("impl", type(self), self.__class__.__name__)
|
|
0 commit comments