|
14 | 14 | import math
|
15 | 15 | from collections.abc import Callable
|
16 | 16 | from copy import copy
|
| 17 | +from functools import reduce |
17 | 18 | from itertools import chain
|
18 | 19 | from textwrap import dedent
|
19 | 20 | from typing import Any, TypeAlias
|
@@ -1868,99 +1869,116 @@ def c_code(self, node, name, inputs, outputs, sub):
|
1868 | 1869 | ##############
|
1869 | 1870 | # Arithmetic
|
1870 | 1871 | ##############
|
1871 |
| -class Maximum(BinaryScalarOp): |
| 1872 | +class AtLeastUnaryOp(ScalarOp): |
| 1873 | + def make_node(self, *inputs): |
| 1874 | + if len(inputs) == 0: |
| 1875 | + raise TypeError(f"{self} requires at least 1 input: got 0") |
| 1876 | + return super().make_node(*inputs) |
| 1877 | + |
| 1878 | + |
| 1879 | +class Maximum(AtLeastUnaryOp): |
1872 | 1880 | commutative = True
|
1873 | 1881 | associative = True
|
1874 |
| - nfunc_spec = ("maximum", 2, 1) |
1875 |
| - nfunc_variadic = "maximum" |
| 1882 | + nfunc_variadic = "max" |
1876 | 1883 | identity = -np.inf
|
1877 | 1884 |
|
1878 | 1885 | def impl(self, *inputs):
|
1879 | 1886 | # The built-in max function don't support complex type
|
1880 |
| - return np.maximum(*inputs) |
| 1887 | + return reduce(np.maximum, inputs) |
1881 | 1888 |
|
1882 | 1889 | def c_code(self, node, name, inputs, outputs, sub):
|
1883 |
| - (x, y) = inputs |
1884 |
| - (z,) = outputs |
1885 | 1890 | if any(i.type in complex_types for i in node.inputs):
|
1886 | 1891 | raise NotImplementedError()
|
1887 |
| - if all(i.type in discrete_dtypes for i in node.inputs): |
1888 |
| - return f"{z} = (({y})>({x})? ({y}): (({x});" |
| 1892 | + |
| 1893 | + x, *ys = inputs |
| 1894 | + [z] = outputs |
| 1895 | + |
| 1896 | + # We need an intermediate variable in case we are working inplace |
| 1897 | + tmp = f"{z}_tmp" |
| 1898 | + res = f"{node.outputs[0].type.dtype_specs()[1]} {tmp} = ({x});" |
| 1899 | + if all(i.dtype in discrete_dtypes for i in node.inputs): |
| 1900 | + for y in ys: |
| 1901 | + res += f"\n{tmp} = (({y}) > {tmp})? ({y}): {tmp};" |
1889 | 1902 | else:
|
1890 |
| - # Test for both y>x and x>=y to detect NaN |
1891 |
| - return f'{z} = (({y})>({x})? ({y}): (({x})>=({y})? ({x}): nan("")));' |
| 1903 | + # Need to check for nans |
| 1904 | + for y in ys: |
| 1905 | + res += ( |
| 1906 | + f"\n{tmp} = (({y}) > {tmp})? ({y}): (({tmp} >= ({y}))? {tmp}: NAN);" |
| 1907 | + ) |
| 1908 | + res += f"\n{z} = {tmp};" |
| 1909 | + return res |
1892 | 1910 |
|
1893 | 1911 | def c_code_cache_version(self):
|
1894 |
| - return (1,) |
| 1912 | + return (2,) |
1895 | 1913 |
|
1896 | 1914 | def L_op(self, inputs, outputs, gout):
|
1897 |
| - (x, y) = inputs |
1898 |
| - (gz,) = gout |
| 1915 | + [gz] = gout |
1899 | 1916 | if gz.type in complex_types:
|
1900 | 1917 | # max is currently defined for complex_types,
|
1901 | 1918 | # but the gradient for complex is not.
|
1902 | 1919 | raise NotImplementedError()
|
1903 | 1920 |
|
1904 |
| - if outputs[0].type in discrete_types: |
1905 |
| - return [ |
1906 |
| - x.zeros_like(dtype=config.floatX), |
1907 |
| - y.zeros_like(dtype=config.floatX), |
1908 |
| - ] |
1909 |
| - # This form handle the case when both value are the same. |
1910 |
| - # In that case, gx will be gz, gy will be 0. |
1911 |
| - e = eq(outputs[0], x) |
1912 |
| - gx = e * gz |
1913 |
| - gy = (constant(1, dtype=gz.dtype) - e) * gz |
1914 |
| - return (gx, gy) |
| 1921 | + [out] = outputs |
| 1922 | + |
| 1923 | + if out.type in discrete_types: |
| 1924 | + return [inp.zeros_like(dtype=config.floatX) for inp in inputs] |
| 1925 | + |
| 1926 | + # We propagate the gradient to the maximum value(s) in the input |
| 1927 | + return [eq(inp, out) * gz for inp in inputs] |
1915 | 1928 |
|
1916 | 1929 |
|
1917 | 1930 | maximum = Maximum(upcast_out, name="maximum")
|
1918 | 1931 |
|
1919 | 1932 |
|
1920 |
| -class Minimum(BinaryScalarOp): |
| 1933 | +class Minimum(AtLeastUnaryOp): |
1921 | 1934 | commutative = True
|
1922 | 1935 | associative = True
|
1923 |
| - nfunc_spec = ("minimum", 2, 1) |
1924 |
| - nfunc_variadic = "minimum" |
| 1936 | + nfunc_variadic = "min" |
1925 | 1937 | identity = np.inf
|
1926 | 1938 |
|
1927 | 1939 | def impl(self, *inputs):
|
1928 | 1940 | # The built-in min function don't support complex type
|
1929 |
| - return np.minimum(*inputs) |
| 1941 | + return reduce(np.minimum, inputs) |
1930 | 1942 |
|
1931 | 1943 | def c_code(self, node, name, inputs, outputs, sub):
|
1932 |
| - (x, y) = inputs |
1933 |
| - (z,) = outputs |
1934 | 1944 | if any(i.type in complex_types for i in node.inputs):
|
1935 | 1945 | raise NotImplementedError()
|
1936 |
| - if all(i.type in discrete_dtypes for i in node.inputs): |
1937 |
| - return f"{z} = (({y})<({x})? ({y}): (({x});" |
| 1946 | + |
| 1947 | + x, *ys = inputs |
| 1948 | + [z] = outputs |
| 1949 | + |
| 1950 | + # We need an intermediate variable in case we are working inplace |
| 1951 | + tmp = f"{z}_tmp" |
| 1952 | + res = f"{node.outputs[0].type.dtype_specs()[1]} {tmp} = ({x});" |
| 1953 | + if all(i.dtype in discrete_dtypes for i in node.inputs): |
| 1954 | + for y in ys: |
| 1955 | + res += f"\n{tmp} = (({y}) < {tmp})? ({y}): {tmp};" |
1938 | 1956 | else:
|
1939 |
| - # Second check catches `NAN`s |
1940 |
| - return f'{z} = (({y})<({x})? ({y}): (({x})<=({y})? ({x}): nan("")));' |
| 1957 | + # Need to check for nans |
| 1958 | + for y in ys: |
| 1959 | + res += ( |
| 1960 | + f"\n{tmp} = (({y}) < {tmp})? ({y}): (({tmp} <= ({y}))? {tmp}: NAN);" |
| 1961 | + ) |
| 1962 | + res += f"\n{z} = {tmp};" |
| 1963 | + return res |
1941 | 1964 |
|
1942 | 1965 | def c_code_cache_version(self):
|
1943 |
| - return (1,) |
| 1966 | + return (2,) |
1944 | 1967 |
|
1945 | 1968 | def L_op(self, inputs, outputs, gout):
|
1946 |
| - (x, y) = inputs |
1947 |
| - (gz,) = gout |
| 1969 | + [gz] = gout |
1948 | 1970 | if gz.type in complex_types:
|
1949 |
| - # min is currently defined for complex_types, |
| 1971 | + # max is currently defined for complex_types, |
1950 | 1972 | # but the gradient for complex is not.
|
1951 | 1973 | raise NotImplementedError()
|
1952 | 1974 |
|
1953 |
| - if outputs[0].type in discrete_types: |
1954 |
| - return [ |
1955 |
| - x.zeros_like(dtype=config.floatX), |
1956 |
| - y.zeros_like(dtype=config.floatX), |
1957 |
| - ] |
1958 |
| - # This form handle the case when both value are the same. |
1959 |
| - # In that case, gx will be gz, gy will be 0. |
1960 |
| - e = eq(outputs[0], x) |
1961 |
| - gx = e * gz |
1962 |
| - gy = (constant(1, dtype=gz.dtype) - e) * gz |
1963 |
| - return (gx, gy) |
| 1975 | + [out] = outputs |
| 1976 | + |
| 1977 | + if out.type in discrete_types: |
| 1978 | + return [inp.zeros_like(dtype=config.floatX) for inp in inputs] |
| 1979 | + |
| 1980 | + # We propagate the gradient to the minimum value(s) in the input |
| 1981 | + return [eq(inp, out) * gz for inp in inputs] |
1964 | 1982 |
|
1965 | 1983 |
|
1966 | 1984 | minimum = Minimum(upcast_out, name="minimum")
|
|
0 commit comments