@@ -186,6 +186,23 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
186
186
return node_rewriter
187
187
188
188
189
+ def register_scalarize (
190
+ node_rewriter : Union [RewriteDatabase , NodeRewriter , str ], * tags : str , ** kwargs
191
+ ):
192
+ if isinstance (node_rewriter , str ):
193
+
194
+ def register (inner_rewriter : Union [RewriteDatabase , Rewriter ]):
195
+ return register_specialize (inner_rewriter , node_rewriter , * tags , ** kwargs )
196
+
197
+ return register
198
+ else :
199
+ name = kwargs .pop ("name" , None ) or node_rewriter .__name__
200
+ compile .optdb ["scalarize" ].register (
201
+ name , node_rewriter , "fast_run" , "fast_compile" , * tags , ** kwargs
202
+ )
203
+ return node_rewriter
204
+
205
+
189
206
def register_uncanonicalize (
190
207
node_rewriter : Union [RewriteDatabase , NodeRewriter , str ], * tags : str , ** kwargs
191
208
):
@@ -226,30 +243,36 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
226
243
227
244
@register_canonicalize
228
245
@register_specialize
246
+ @register_scalarize
229
247
@node_rewriter ([TensorFromScalar ])
230
248
def local_tensor_scalar_tensor (fgraph , node ):
231
249
"""tensor_from_scalar(scalar_from_tensor(x)) -> x"""
232
- if isinstance (node .op , TensorFromScalar ):
233
- s = node .inputs [0 ]
234
- if s .owner and isinstance (s .owner .op , ScalarFromTensor ):
235
- t = s .owner .inputs [0 ]
250
+ s = node .inputs [0 ]
251
+ if s .owner and isinstance (s .owner .op , ScalarFromTensor ):
252
+ t = s .owner .inputs [0 ]
236
253
237
- # We don't need to copy over any stack traces here
238
- return [t ]
254
+ # We don't need to copy over any stack traces here
255
+ return [t ]
239
256
240
257
241
258
@register_canonicalize
242
259
@register_specialize
260
+ @register_scalarize
243
261
@node_rewriter ([ScalarFromTensor ])
244
262
def local_scalar_tensor_scalar (fgraph , node ):
245
- """scalar_from_tensor(tensor_from_scalar(x)) -> x"""
246
- if isinstance (node .op , ScalarFromTensor ):
247
- t = node .inputs [0 ]
248
- if t .owner and isinstance (t .owner .op , TensorFromScalar ):
249
- s = t .owner .inputs [0 ]
250
-
251
- # We don't need to copy over any stack traces here
252
- return [s ]
263
+ """scalar_from_tensor(tensor_from_scalar(x)) -> x
264
+
265
+ and scalar_from_tensor(TensorConstant(x)) -> x
266
+ """
267
+ t = node .inputs [0 ]
268
+ if t .owner and isinstance (t .owner .op , TensorFromScalar ):
269
+ s = t .owner .inputs [0 ]
270
+
271
+ # We don't need to copy over any stack traces here
272
+ return [s ]
273
+ if isinstance (t , TensorConstant ):
274
+ assert t .ndim == 0
275
+ return [aes .constant (t .value .item (), t .name , t .dtype )]
253
276
254
277
255
278
@register_specialize ("local_alloc_elemwise" )
0 commit comments