Skip to content

Commit cb2c40b

Browse files
committed
Replace double negative when checking if a variable has an Apply node
1 parent 37d4c40 commit cb2c40b

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

pytensor/tensor/rewriting/math.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def local_func_inv(fgraph, node):
306306

307307
if not isinstance(node.op, Elemwise):
308308
return
309-
if not x.owner or not isinstance(x.owner.op, Elemwise):
309+
if not (x.owner and isinstance(x.owner.op, Elemwise)):
310310
return
311311

312312
prev_op = x.owner.op.scalar_op
@@ -332,9 +332,7 @@ def local_func_inv(fgraph, node):
332332
def local_exp_log(fgraph, node):
333333
x = node.inputs[0]
334334

335-
if not isinstance(node.op, Elemwise):
336-
return
337-
if not x.owner or not isinstance(x.owner.op, Elemwise):
335+
if not (x.owner and isinstance(x.owner.op, Elemwise)):
338336
return
339337

340338
prev_op = x.owner.op.scalar_op
@@ -375,9 +373,7 @@ def local_exp_log_nan_switch(fgraph, node):
375373
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
376374
x = node.inputs[0]
377375

378-
if not isinstance(node.op, Elemwise):
379-
return
380-
if not x.owner or not isinstance(x.owner.op, Elemwise):
376+
if not (x.owner and isinstance(x.owner.op, Elemwise)):
381377
return
382378

383379
prev_op = x.owner.op.scalar_op
@@ -501,9 +497,11 @@ def local_mul_exp_to_exp_add(fgraph, node):
501497
rest = [
502498
n
503499
for n in node.inputs
504-
if not n.owner
505-
or not hasattr(n.owner.op, "scalar_op")
506-
or not isinstance(n.owner.op.scalar_op, ps.Exp)
500+
if not (
501+
n.owner
502+
and isinstance(n.owner.op, Elemwise)
503+
and isinstance(n.owner.op.scalar_op, ps.Exp)
504+
)
507505
]
508506
if len(rest) > 0:
509507
new_out = orig_op(new_out, *rest)

pytensor/tensor/rewriting/subtensor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def local_subtensor_inc_subtensor(fgraph, node):
695695
"""
696696
if isinstance(node.op, Subtensor):
697697
x = node.inputs[0]
698-
if not x.owner or not isinstance(x.owner.op, IncSubtensor):
698+
if not (x.owner and isinstance(x.owner.op, IncSubtensor)):
699699
return
700700
if not x.owner.op.set_instead_of_inc:
701701
return
@@ -755,7 +755,7 @@ def local_subtensor_make_vector(fgraph, node):
755755

756756
x = node.inputs[0]
757757

758-
if not x.owner or not isinstance(x.owner.op, MakeVector):
758+
if not (x.owner and isinstance(x.owner.op, MakeVector)):
759759
return False
760760

761761
make_vector_op = x.owner.op

0 commit comments

Comments
 (0)