Skip to content

Commit ce5ff15

Browse files
committed
Do not bother with nan in discrete Maximum/Minimum
1 parent 8147011 commit ce5ff15

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

pytensor/scalar/basic.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,8 +1884,14 @@ def c_code(self, node, name, inputs, outputs, sub):
18841884
(z,) = outputs
18851885
if any(i.type in complex_types for i in node.inputs):
18861886
raise NotImplementedError()
1887-
# Test for both y>x and x>=y to detect NaN
1888-
return f'{z} = (({y})>({x})? ({y}): (({x})>=({y})? ({x}): nan("")));'
1887+
if all(i.type in discrete_dtypes for i in node.inputs):
1888+
return f"{z} = (({y})>({x})? ({y}): (({x});"
1889+
else:
1890+
# Test for both y>x and x>=y to detect NaN
1891+
return f'{z} = (({y})>({x})? ({y}): (({x})>=({y})? ({x}): nan("")));'
1892+
1893+
def c_code_cache_version(self):
1894+
return (1,)
18891895

18901896
def L_op(self, inputs, outputs, gout):
18911897
(x, y) = inputs
@@ -1927,7 +1933,14 @@ def c_code(self, node, name, inputs, outputs, sub):
19271933
(z,) = outputs
19281934
if any(i.type in complex_types for i in node.inputs):
19291935
raise NotImplementedError()
1930-
return f'{z} = (({y})<({x})? ({y}): (({x})<=({y})? ({x}): nan("")));'
1936+
if all(i.type in discrete_dtypes for i in node.inputs):
1937+
return f"{z} = (({y})<({x})? ({y}): (({x});"
1938+
else:
1939+
# Second check catches `NAN`s
1940+
return f'{z} = (({y})<({x})? ({y}): (({x})<=({y})? ({x}): nan("")));'
1941+
1942+
def c_code_cache_version(self):
1943+
return (1,)
19311944

19321945
def L_op(self, inputs, outputs, gout):
19331946
(x, y) = inputs

0 commit comments

Comments
 (0)