@@ -431,38 +431,37 @@ def local_mul_exp_to_exp_add(fgraph, node):
431
431
This rewrite detects e^x * e^y and converts it to e^(x+y).
432
432
Similarly, e^x / e^y becomes e^(x-y).
433
433
"""
434
- if isinstance (node .op .scalar_op , (aes .Mul , aes .TrueDiv )):
435
- exps = [
436
- n .owner .inputs [0 ]
434
+ exps = [
435
+ n .owner .inputs [0 ]
436
+ for n in node .inputs
437
+ if n .owner
438
+ and hasattr (n .owner .op , "scalar_op" )
439
+ and isinstance (n .owner .op .scalar_op , aes .Exp )
440
+ ]
441
+ # Can only do any rewrite if there are at least two exp-s
442
+ if len (exps ) >= 2 :
443
+ # Mul -> add; TrueDiv -> sub
444
+ orig_op , new_op = mul , add
445
+ if isinstance (node .op .scalar_op , aes .TrueDiv ):
446
+ orig_op , new_op = true_div , sub
447
+ new_out = exp (new_op (* exps ))
448
+ if new_out .dtype != node .outputs [0 ].dtype :
449
+ new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
450
+ # The original Mul may have more than two factors, some of which may not be exp nodes.
451
+ # If so, we keep multiplying them with the new exp(sum) node.
452
+ # E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
453
+ rest = [
454
+ n
437
455
for n in node .inputs
438
- if n .owner
439
- and hasattr (n .owner .op , "scalar_op" )
440
- and isinstance (n .owner .op .scalar_op , aes .Exp )
456
+ if not n .owner
457
+ or not hasattr (n .owner .op , "scalar_op" )
458
+ or not isinstance (n .owner .op .scalar_op , aes .Exp )
441
459
]
442
- # Can only do any rewrite if there are at least two exp-s
443
- if len (exps ) >= 2 :
444
- # Mul -> add; TrueDiv -> sub
445
- orig_op , new_op = mul , add
446
- if isinstance (node .op .scalar_op , aes .TrueDiv ):
447
- orig_op , new_op = true_div , sub
448
- new_out = exp (new_op (* exps ))
460
+ if len (rest ) > 0 :
461
+ new_out = orig_op (new_out , * rest )
449
462
if new_out .dtype != node .outputs [0 ].dtype :
450
463
new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
451
- # The original Mul may have more than two factors, some of which may not be exp nodes.
452
- # If so, we keep multiplying them with the new exp(sum) node.
453
- # E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
454
- rest = [
455
- n
456
- for n in node .inputs
457
- if not n .owner
458
- or not hasattr (n .owner .op , "scalar_op" )
459
- or not isinstance (n .owner .op .scalar_op , aes .Exp )
460
- ]
461
- if len (rest ) > 0 :
462
- new_out = orig_op (new_out , * rest )
463
- if new_out .dtype != node .outputs [0 ].dtype :
464
- new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
465
- return [new_out ]
464
+ return [new_out ]
466
465
467
466
468
467
@register_specialize
@@ -472,52 +471,51 @@ def local_mul_pow_to_pow_add(fgraph, node):
472
471
This rewrite detects a^x * a^y and converts it to a^(x+y).
473
472
Similarly, a^x / a^y becomes a^(x-y).
474
473
"""
475
- if isinstance (node .op .scalar_op , (aes .Mul , aes .TrueDiv )):
476
- # search for pow-s and group them by their bases
477
- pow_nodes = defaultdict (list )
478
- rest = []
479
- for n in node .inputs :
480
- if (
481
- n .owner
482
- and hasattr (n .owner .op , "scalar_op" )
483
- and isinstance (n .owner .op .scalar_op , aes .Pow )
484
- ):
485
- base_node = n .owner .inputs [0 ]
486
- # exponent is at n.owner.inputs[1], but we need to store the full node
487
- # in case this particular power node remains alone and can't be rewritten
488
- pow_nodes [base_node ].append (n )
489
- else :
490
- rest .append (n )
491
-
492
- # Can only do any rewrite if there are at least two pow-s with the same base
493
- can_rewrite = [k for k , v in pow_nodes .items () if len (v ) >= 2 ]
494
- if len (can_rewrite ) >= 1 :
495
- # Mul -> add; TrueDiv -> sub
496
- orig_op , new_op = mul , add
497
- if isinstance (node .op .scalar_op , aes .TrueDiv ):
498
- orig_op , new_op = true_div , sub
499
- pow_factors = []
500
- # Rewrite pow-s having the same base for each different base
501
- # E.g.: a^x * a^y --> a^(x+y)
502
- for base in can_rewrite :
503
- exponents = [n .owner .inputs [1 ] for n in pow_nodes [base ]]
504
- new_node = base ** new_op (* exponents )
505
- if new_node .dtype != node .outputs [0 ].dtype :
506
- new_node = cast (new_node , dtype = node .outputs [0 ].dtype )
507
- pow_factors .append (new_node )
508
- # Don't forget about those sole pow-s that couldn't be rewriten
509
- sole_pows = [v [0 ] for k , v in pow_nodes .items () if k not in can_rewrite ]
510
- # Combine the rewritten pow-s and other, non-pow factors of the original Mul
511
- # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v
512
- if len (pow_factors ) > 1 or len (sole_pows ) > 0 or len (rest ) > 0 :
513
- new_out = orig_op (* pow_factors , * sole_pows , * rest )
514
- if new_out .dtype != node .outputs [0 ].dtype :
515
- new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
516
- else :
517
- # if all factors of the original mul were pows-s with the same base,
518
- # we can get rid of the mul completely.
519
- new_out = pow_factors [0 ]
520
- return [new_out ]
474
+ # search for pow-s and group them by their bases
475
+ pow_nodes = defaultdict (list )
476
+ rest = []
477
+ for n in node .inputs :
478
+ if (
479
+ n .owner
480
+ and hasattr (n .owner .op , "scalar_op" )
481
+ and isinstance (n .owner .op .scalar_op , aes .Pow )
482
+ ):
483
+ base_node = n .owner .inputs [0 ]
484
+ # exponent is at n.owner.inputs[1], but we need to store the full node
485
+ # in case this particular power node remains alone and can't be rewritten
486
+ pow_nodes [base_node ].append (n )
487
+ else :
488
+ rest .append (n )
489
+
490
+ # Can only do any rewrite if there are at least two pow-s with the same base
491
+ can_rewrite = [k for k , v in pow_nodes .items () if len (v ) >= 2 ]
492
+ if len (can_rewrite ) >= 1 :
493
+ # Mul -> add; TrueDiv -> sub
494
+ orig_op , new_op = mul , add
495
+ if isinstance (node .op .scalar_op , aes .TrueDiv ):
496
+ orig_op , new_op = true_div , sub
497
+ pow_factors = []
498
+ # Rewrite pow-s having the same base for each different base
499
+ # E.g.: a^x * a^y --> a^(x+y)
500
+ for base in can_rewrite :
501
+ exponents = [n .owner .inputs [1 ] for n in pow_nodes [base ]]
502
+ new_node = base ** new_op (* exponents )
503
+ if new_node .dtype != node .outputs [0 ].dtype :
504
+ new_node = cast (new_node , dtype = node .outputs [0 ].dtype )
505
+ pow_factors .append (new_node )
506
+ # Don't forget about those sole pow-s that couldn't be rewriten
507
+ sole_pows = [v [0 ] for k , v in pow_nodes .items () if k not in can_rewrite ]
508
+ # Combine the rewritten pow-s and other, non-pow factors of the original Mul
509
+ # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v
510
+ if len (pow_factors ) > 1 or len (sole_pows ) > 0 or len (rest ) > 0 :
511
+ new_out = orig_op (* pow_factors , * sole_pows , * rest )
512
+ if new_out .dtype != node .outputs [0 ].dtype :
513
+ new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
514
+ else :
515
+ # if all factors of the original mul were pows-s with the same base,
516
+ # we can get rid of the mul completely.
517
+ new_out = pow_factors [0 ]
518
+ return [new_out ]
521
519
522
520
523
521
@register_stabilize
0 commit comments