
Commit 196564f

Make Maximum and Minimum variadic
1 parent ce5ff15 commit 196564f

File tree

6 files changed, +126 -57 lines

pytensor/link/jax/dispatch/scalar.py (+18)

@@ -14,6 +14,8 @@
     Composite,
     Identity,
     IntDiv,
+    Maximum,
+    Minimum,
     Mod,
     Mul,
     ScalarOp,
@@ -172,6 +174,22 @@ def elemwise(x, y):
     return elemwise


+@jax_funcify.register(Maximum)
+def jax_funcify_scalar_Maximum(op, **kwargs):
+    def elemwise(*inputs):
+        return functools.reduce(jnp.maximum, inputs[1:], inputs[0])
+
+    return elemwise
+
+
+@jax_funcify.register(Minimum)
+def jax_funcify_scalar_Minimum(op, **kwargs):
+    def elemwise(*inputs):
+        return functools.reduce(jnp.minimum, inputs[1:], inputs[0])
+
+    return elemwise
+
+
 @jax_funcify.register(Cast)
 def jax_funcify_Cast(op, **kwargs):
     def cast(x):
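The two JAX registrations above share one pattern: fold the binary jnp primitive over however many inputs the node carries. A minimal standalone sketch of that reduction, with made-up array values, to show the variadic semantics:

import functools

import jax.numpy as jnp

def variadic_maximum(*inputs):
    # Left fold of the binary jnp.maximum, seeded with the first input,
    # mirroring the dispatch function registered above.
    return functools.reduce(jnp.maximum, inputs[1:], inputs[0])

a = jnp.array([1.0, 5.0, 3.0])
b = jnp.array([4.0, 2.0, 3.0])
c = jnp.array([0.0, 0.0, 9.0])
print(variadic_maximum(a, b, c))  # [4. 5. 9.]

jnp.maximum keeps NumPy's NaN-propagating semantics, so this backend needs no explicit NaN check.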

pytensor/link/numba/dispatch/scalar.py (+35)

@@ -9,19 +9,23 @@
     create_numba_signature,
     generate_fallback_impl,
     numba_funcify,
+    numba_njit,
 )
 from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
 from pytensor.link.utils import (
     compile_function_src,
     get_name_for_object,
     unique_name_generator,
 )
+from pytensor.scalar import discrete_dtypes
 from pytensor.scalar.basic import (
     Add,
     Cast,
     Clip,
     Composite,
     Identity,
+    Maximum,
+    Minimum,
     Mul,
     Reciprocal,
     ScalarOp,
@@ -186,6 +190,37 @@ def numba_funcify_Mul(op, node, **kwargs):
     return numba_basic.numba_njit(signature)(nary_add_fn)


+@numba_funcify.register(Maximum)
+@numba_funcify.register(Minimum)
+def numba_funcify_Extremum(op, node, **kwargs):
+    input_names = [f"x{i}" for i in range(len(node.inputs))]
+    input_signature = ", ".join(input_names)
+    assert len(input_names) > 0
+
+    inner_code = f"res = {input_names[0]}\n"
+
+    if isinstance(op, Maximum):
+        op = ">"
+        func_name = "maximum"
+    else:
+        op = "<"
+        func_name = "minimum"
+
+    if all(inp.dtype in discrete_dtypes for inp in node.inputs):
+        for x in input_names[1:]:
+            inner_code += f"    res = {x} if {x} {op} res else res\n"
+    else:
+        for x in input_names[1:]:
+            inner_code += f"    res = {x} if {x} {op} res else (res if res {op}= {x} else np.nan)\n"
+    inner_code += "    return res"
+
+    src = f"""
+def {func_name}({input_signature}):
+    {inner_code}
+"""
+    return numba_njit(compile_function_src(src, func_name, globals() | {"np": np}))
+
+
 @numba_funcify.register(Cast)
 def numba_funcify_Cast(op, node, **kwargs):
     dtype = np.dtype(op.o_type.dtype)
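To make the string-based codegen concrete: for a Maximum node over three float inputs, compile_function_src would receive source along these lines (hand-expanded here for illustration; the exact whitespace the generator emits may differ). The parenthesized fallback re-tests the comparison in the opposite direction, so a NaN operand fails both tests and poisons the running result:

def maximum(x0, x1, x2):
    res = x0
    res = x1 if x1 > res else (res if res >= x1 else np.nan)
    res = x2 if x2 > res else (res if res >= x2 else np.nan)
    return res

For all-discrete inputs the simpler branch without the NaN fallback is generated, since integer dtypes cannot hold NaN.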

pytensor/scalar/basic.py (+67 -49)

@@ -14,6 +14,7 @@
 import math
 from collections.abc import Callable
 from copy import copy
+from functools import reduce
 from itertools import chain
 from textwrap import dedent
 from typing import Any, TypeAlias
@@ -1868,99 +1869,116 @@ def c_code(self, node, name, inputs, outputs, sub):
 ##############
 # Arithmetic
 ##############
-class Maximum(BinaryScalarOp):
+class AtLeastUnaryOp(ScalarOp):
+    def make_node(self, *inputs):
+        if len(inputs) == 0:
+            raise TypeError(f"{self} requires at least 1 input: got 0")
+        return super().make_node(*inputs)
+
+
+class Maximum(AtLeastUnaryOp):
     commutative = True
     associative = True
-    nfunc_spec = ("maximum", 2, 1)
-    nfunc_variadic = "maximum"
+    nfunc_variadic = "max"
     identity = -np.inf

     def impl(self, *inputs):
         # The built-in max function don't support complex type
-        return np.maximum(*inputs)
+        return reduce(np.maximum, inputs)

     def c_code(self, node, name, inputs, outputs, sub):
-        (x, y) = inputs
-        (z,) = outputs
         if any(i.type in complex_types for i in node.inputs):
             raise NotImplementedError()
-        if all(i.type in discrete_dtypes for i in node.inputs):
-            return f"{z} = (({y})>({x})? ({y}): (({x});"
+
+        x, *ys = inputs
+        [z] = outputs
+
+        # We need an intermediate variable in case we are working inplace
+        tmp = f"{z}_tmp"
+        res = f"{node.outputs[0].type.dtype_specs()[1]} {tmp} = ({x});"
+        if all(i.dtype in discrete_dtypes for i in node.inputs):
+            for y in ys:
+                res += f"\n{tmp} = (({y}) > {tmp})? ({y}): {tmp};"
         else:
-            # Test for both y>x and x>=y to detect NaN
-            return f'{z} = (({y})>({x})? ({y}): (({x})>=({y})? ({x}): nan("")));'
+            # Need to check for nans
+            for y in ys:
+                res += (
+                    f"\n{tmp} = (({y}) > {tmp})? ({y}): (({tmp} >= ({y}))? {tmp}: NAN);"
+                )
+        res += f"\n{z} = {tmp};"
+        return res

     def c_code_cache_version(self):
-        return (1,)
+        return (2,)

     def L_op(self, inputs, outputs, gout):
-        (x, y) = inputs
-        (gz,) = gout
+        [gz] = gout
         if gz.type in complex_types:
             # max is currently defined for complex_types,
             # but the gradient for complex is not.
             raise NotImplementedError()

-        if outputs[0].type in discrete_types:
-            return [
-                x.zeros_like(dtype=config.floatX),
-                y.zeros_like(dtype=config.floatX),
-            ]
-        # This form handle the case when both value are the same.
-        # In that case, gx will be gz, gy will be 0.
-        e = eq(outputs[0], x)
-        gx = e * gz
-        gy = (constant(1, dtype=gz.dtype) - e) * gz
-        return (gx, gy)
+        [out] = outputs
+
+        if out.type in discrete_types:
+            return [inp.zeros_like(dtype=config.floatX) for inp in inputs]
+
+        # We propagate the gradient to the maximum value(s) in the input
+        return [eq(inp, out) * gz for inp in inputs]


 maximum = Maximum(upcast_out, name="maximum")


-class Minimum(BinaryScalarOp):
+class Minimum(AtLeastUnaryOp):
     commutative = True
     associative = True
-    nfunc_spec = ("minimum", 2, 1)
-    nfunc_variadic = "minimum"
+    nfunc_variadic = "min"
     identity = np.inf

     def impl(self, *inputs):
         # The built-in min function don't support complex type
-        return np.minimum(*inputs)
+        return reduce(np.minimum, inputs)

     def c_code(self, node, name, inputs, outputs, sub):
-        (x, y) = inputs
-        (z,) = outputs
         if any(i.type in complex_types for i in node.inputs):
             raise NotImplementedError()
-        if all(i.type in discrete_dtypes for i in node.inputs):
-            return f"{z} = (({y})<({x})? ({y}): (({x});"
+
+        x, *ys = inputs
+        [z] = outputs
+
+        # We need an intermediate variable in case we are working inplace
+        tmp = f"{z}_tmp"
+        res = f"{node.outputs[0].type.dtype_specs()[1]} {tmp} = ({x});"
+        if all(i.dtype in discrete_dtypes for i in node.inputs):
+            for y in ys:
+                res += f"\n{tmp} = (({y}) < {tmp})? ({y}): {tmp};"
         else:
-            # Second check catches `NAN`s
-            return f'{z} = (({y})<({x})? ({y}): (({x})<=({y})? ({x}): nan("")));'
+            # Need to check for nans
+            for y in ys:
+                res += (
+                    f"\n{tmp} = (({y}) < {tmp})? ({y}): (({tmp} <= ({y}))? {tmp}: NAN);"
+                )
+        res += f"\n{z} = {tmp};"
+        return res

     def c_code_cache_version(self):
-        return (1,)
+        return (2,)

     def L_op(self, inputs, outputs, gout):
-        (x, y) = inputs
-        (gz,) = gout
+        [gz] = gout
         if gz.type in complex_types:
-            # min is currently defined for complex_types,
+            # max is currently defined for complex_types,
             # but the gradient for complex is not.
             raise NotImplementedError()

-        if outputs[0].type in discrete_types:
-            return [
-                x.zeros_like(dtype=config.floatX),
-                y.zeros_like(dtype=config.floatX),
-            ]
-        # This form handle the case when both value are the same.
-        # In that case, gx will be gz, gy will be 0.
-        e = eq(outputs[0], x)
-        gx = e * gz
-        gy = (constant(1, dtype=gz.dtype) - e) * gz
-        return (gx, gy)
+        [out] = outputs
+
+        if out.type in discrete_types:
+            return [inp.zeros_like(dtype=config.floatX) for inp in inputs]
+
+        # We propagate the gradient to the minimum value(s) in the input
+        return [eq(inp, out) * gz for inp in inputs]


 minimum = Minimum(upcast_out, name="minimum")
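One behavioral point in the rewritten L_op deserves a note: eq(inp, out) is true for every input that attains the extremum, so when several inputs tie, each tied input receives the full upstream gradient (the old binary form routed a tie's gradient to x alone). A plain NumPy analogue of the float branch, with made-up values, to illustrate:

import numpy as np

def maximum_grad(inputs, gz):
    # Mirrors the new L_op: gradient flows to every input equal to the output.
    out = np.maximum.reduce(inputs)
    return [(inp == out).astype(gz.dtype) * gz for inp in inputs]

x = np.array([1.0, 3.0])
y = np.array([3.0, 3.0])
gx, gy = maximum_grad([x, y], np.ones(2))
print(gx)  # [0. 1.]
print(gy)  # [1. 1.]  <- the tie at index 1 sends gz to both inputs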

pytensor/tensor/inplace.py (+2 -2)

@@ -357,12 +357,12 @@ def second_inplace(a):
 pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="]))


-@scalar_elemwise(symbolname="scalar_maximum_inplace")
+@scalar_elemwise
 def maximum_inplace(a, b):
     """elementwise addition (inplace on `a`)"""


-@scalar_elemwise(symbolname="scalar_minimum_inplace")
+@scalar_elemwise
 def minimum_inplace(a, b):
     """elementwise addition (inplace on `a`)"""

pytensor/tensor/subtensor.py (+3 -5)

@@ -384,7 +384,7 @@ def analyze(x):
         )
         is_stop_length = (
             stop is None
-            or stop in [length, sys.maxsize]
+            or stop == length
             or (is_stop_constant and is_length_constant and stop >= length)
         )
         if is_start_0:
@@ -1036,7 +1036,7 @@ def infer_shape(self, fgraph, node, shapes):
                         b_pos,
                         # The [-a: b] is peculiar, the slice length actually decreases for larger arrays
                         # The branch -a is useless when b - a / 2 <= -a. Similar for the branch b
-                        minimum(minimum(xl, b - a - xl), minimum(-a, b)),  # [-a: b]
+                        minimum(xl, b - a - xl, -a, b),  # [-a: b]
                         minimum(b - a, xl + b),  # [-a: -b]
                     ),
                 )
@@ -1046,9 +1046,7 @@ def infer_shape(self, fgraph, node, shapes):
                     _eager_switch(
                         b_pos,
                         minimum(a - b, xl - b - one),  # [a: b]
-                        minimum(
-                            minimum(xl, a - (xl + b)), minimum(a + one, -b - one)
-                        ),  # [a: -b]
+                        minimum(xl, a - (xl + b), a + one, -b - one),  # [a: -b]
                     ),
                     _eager_switch(
                         b_pos,
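These call sites are the payoff of the variadic op: each nested tree of two-input minimums collapses into a single node over four inputs, which is also why the node-count bound in the test below tightens from 9 to 8. A hypothetical snippet sketching the difference (assuming the tensor-level minimum exposes the same variadic signature used here):

import pytensor.tensor as pt

a, b, c, d = pt.scalars("a", "b", "c", "d")

nested = pt.minimum(pt.minimum(a, b), pt.minimum(c, d))  # three Minimum nodes
flat = pt.minimum(a, b, c, d)  # one variadic Minimum node after this commit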

tests/tensor/test_subtensor.py (+1 -1)

@@ -1466,7 +1466,7 @@ def test_constant_params(self, a, b, step):
         fg = FunctionGraph(outputs=[y], clone=False)
         rewrite_graph(fg, include=("ShapeOpt", "canonicalize"), clone=False)
         assert not any(isinstance(node.op, Subtensor) for node in fg.apply_nodes)
-        assert len(fg.apply_nodes) <= 9
+        assert len(fg.apply_nodes) <= 8

         fn = pytensor.function(
             [x],
