@@ -1884,8 +1884,14 @@ def c_code(self, node, name, inputs, outputs, sub):
1884
1884
(z ,) = outputs
1885
1885
if any (i .type in complex_types for i in node .inputs ):
1886
1886
raise NotImplementedError ()
1887
- # Test for both y>x and x>=y to detect NaN
1888
- return f'{ z } = (({ y } )>({ x } )? ({ y } ): (({ x } )>=({ y } )? ({ x } ): nan("")));'
1887
+ if all (i .type in discrete_dtypes for i in node .inputs ):
1888
+ return f"{ z } = (({ y } )>({ x } )? ({ y } ): (({ x } );"
1889
+ else :
1890
+ # Test for both y>x and x>=y to detect NaN
1891
+ return f'{ z } = (({ y } )>({ x } )? ({ y } ): (({ x } )>=({ y } )? ({ x } ): nan("")));'
1892
+
1893
+ def c_code_cache_version (self ):
1894
+ return (1 ,)
1889
1895
1890
1896
def L_op (self , inputs , outputs , gout ):
1891
1897
(x , y ) = inputs
@@ -1927,7 +1933,14 @@ def c_code(self, node, name, inputs, outputs, sub):
1927
1933
(z ,) = outputs
1928
1934
if any (i .type in complex_types for i in node .inputs ):
1929
1935
raise NotImplementedError ()
1930
- return f'{ z } = (({ y } )<({ x } )? ({ y } ): (({ x } )<=({ y } )? ({ x } ): nan("")));'
1936
+ if all (i .type in discrete_dtypes for i in node .inputs ):
1937
+ return f"{ z } = (({ y } )<({ x } )? ({ y } ): (({ x } );"
1938
+ else :
1939
+ # Second check catches `NAN`s
1940
+ return f'{ z } = (({ y } )<({ x } )? ({ y } ): (({ x } )<=({ y } )? ({ x } ): nan("")));'
1941
+
1942
+ def c_code_cache_version (self ):
1943
+ return (1 ,)
1931
1944
1932
1945
def L_op (self , inputs , outputs , gout ):
1933
1946
(x , y ) = inputs
0 commit comments