
Commit 034a03f

Remove Mean Op
This Op does not really fit the CAReduce API, as it requires an extra bit of information (the number of elements in the axis) during the loop. A better solution would be a fused Elemwise+CAReduce.
1 parent 0824dba commit 034a03f
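
For context: a CAReduce kernel only sees the running accumulator and the next input element, so the final division by the number of reduced elements has to be injected from outside the loop. A minimal NumPy sketch of the decomposition the commit message points to, keeping the reduction (the CAReduce-like part) and the elementwise division as separate steps; mean_via_sum is an illustrative name, not an API introduced by this commit:

import numpy as np

def mean_via_sum(x, axis=None):
    # CAReduce-style part: a pure pairwise reduction needs no extra state
    total = np.add.reduce(x, axis=axis)
    # Elemwise part: scale by the element count, which is known outside the loop
    count = x.size if axis is None else x.shape[axis]
    return total / count

x = np.arange(6.0).reshape(3, 2)
assert np.allclose(mean_via_sum(x, axis=0), x.mean(axis=0))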

File tree

6 files changed: +3 −161 lines

pytensor/link/numba/dispatch/elemwise.py (-6)

@@ -34,7 +34,6 @@
     Add,
     Composite,
     IntDiv,
-    Mean,
     Mul,
     ScalarMaximum,
     ScalarMinimum,
@@ -77,11 +76,6 @@ def scalar_in_place_fn_Sub(op, idx, res, arr):
     return f"{res}[{idx}] -= {arr}"


-@scalar_in_place_fn.register(Mean)
-def scalar_in_place_fn_Mean(op, idx, res, arr):
-    return f"{res}[{idx}] += ({arr} - {res}[{idx}]) / (i + 1)"
-
-
 @scalar_in_place_fn.register(Mul)
 def scalar_in_place_fn_Mul(op, idx, res, arr):
     return f"{res}[{idx}] *= {arr}"
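
The deleted registration emitted a running-mean update: each element pulls the accumulator toward itself by the deviation divided by the count so far, which is exactly why the loop index i had to leak into the CAReduce body. A standalone check of that recurrence (plain Python, not the generated numba code):

import numpy as np

data = np.array([3.0, 4.0, 5.0, 10.0])
res = 0.0
for i, arr in enumerate(data):
    # the same update the removed scalar_in_place_fn_Mean generated
    res += (arr - res) / (i + 1)
assert np.isclose(res, data.mean())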

pytensor/scalar/basic.py (-26)

@@ -1871,32 +1871,6 @@ def L_op(self, inputs, outputs, gout):
 add = Add(upcast_out, name="add")


-class Mean(ScalarOp):
-    identity = 0
-    commutative = True
-    associative = False
-    nfunc_spec = ("mean", 2, 1)
-    nfunc_variadic = "mean"
-
-    def impl(self, *inputs):
-        return sum(inputs) / len(inputs)
-
-    def c_code(self, node, name, inputs, outputs, sub):
-        (z,) = outputs
-        if not inputs:
-            return f"{z} = 0;"
-        else:
-            return f"{z} = ({' + '.join(inputs)}) / ((double) {len(inputs)});"
-
-    def L_op(self, inputs, outputs, gout):
-        (gz,) = gout
-        retval = [gz / len(inputs)] * len(inputs)
-        return retval
-
-
-mean = Mean(float_out, name="mean")
-
-
 class Mul(ScalarOp):
     identity = 1
     commutative = True
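
The deleted scalar Mean averaged its n scalar inputs as sum(inputs) / len(inputs) and propagated a gradient of gz / n to each input. A quick check of both behaviors in plain Python (illustrative, not PyTensor):

inputs = (3.0, 4.0, 5.0)
assert sum(inputs) / len(inputs) == 4.0              # what Mean.impl computed
gz = 1.0                                             # upstream gradient
assert [gz / len(inputs)] * len(inputs) == [1 / 3] * 3  # what Mean.L_op returned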

pytensor/tensor/math.py (+1, -76)

@@ -1316,63 +1316,7 @@ def complex_from_polar(abs, angle):
     """Return complex-valued tensor from polar coordinate specification."""


-class Mean(FixedOpCAReduce):
-    __props__ = ("axis",)
-    nfunc_spec = ("mean", 1, 1)
-
-    def __init__(self, axis=None):
-        super().__init__(ps.mean, axis)
-        assert self.axis is None or len(self.axis) == 1
-
-    def __str__(self):
-        if self.axis is not None:
-            args = ", ".join(str(x) for x in self.axis)
-            return f"Mean{{{args}}}"
-        else:
-            return "Mean"
-
-    def _output_dtype(self, idtype):
-        # we want to protect against overflow
-        return "float64"
-
-    def perform(self, node, inp, out):
-        (input,) = inp
-        (output,) = out
-        if self.axis is None:
-            axis = None
-        else:
-            axis = self.axis[0]
-        # numpy.asarray is needed as otherwise we can end up with a
-        # numpy scalar.
-        output[0] = np.asarray(np.mean(input, dtype="float64", axis=axis))
-
-    def c_code(self, node, name, inames, onames, sub):
-        ret = super().c_code(node, name, inames, onames, sub)
-
-        if self.axis is not None:
-            return ret
-
-        # TODO: c_code perform support only axis is None
-        return (
-            ret
-            + f"""
-            *((double *)PyArray_DATA({onames[0]})) /= PyArray_SIZE({inames[0]});
-            """
-        )
-
-    def clone(self, **kwargs):
-        axis = kwargs.get("axis", self.axis)
-        return type(self)(axis=axis)
-
-
-# TODO: implement the grad. When done and tested, you can make this the default
-# version.
-# def grad(self, (x,), (gout,)):
-#     import pdb;pdb.set_trace()
-#     return grad(mean(x, self.axis, op=False),[x])
-
-
-def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None):
+def mean(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
     """
     Computes the mean value along the given axis(es) of a tensor `input`.

@@ -1397,25 +1341,6 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, acc_dtype=None):
     be in a float type). If None, then we use the same rules as `sum()`.
     """
     input = as_tensor_variable(input)
-    if op:
-        if dtype not in (None, "float64"):
-            raise NotImplementedError(
-                "The Mean op does not support the dtype argument, "
-                "and will always use float64. If you want to specify "
-                "the dtype, call tensor.mean(..., op=False).",
-                dtype,
-            )
-        if acc_dtype not in (None, "float64"):
-            raise NotImplementedError(
-                "The Mean op does not support the acc_dtype argument, "
-                "and will always use float64. If you want to specify "
-                "acc_dtype, call tensor.mean(..., op=False).",
-                dtype,
-            )
-        out = Mean(axis)(input)
-        if keepdims:
-            out = makeKeepDims(input, out, axis)
-        return out

     if dtype is not None:
         # The summation will be done with the specified dtype.
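
With the op= escape hatch gone, pt.mean always builds its graph from a summation followed by an elementwise division, and dtype/acc_dtype keep working through the same rules as sum(). Call sites are otherwise unaffected; a quick sanity sketch (the exact optimized graph may differ by backend):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
f = pytensor.function([x], pt.mean(x, axis=0))  # no op= argument anymore
assert np.allclose(f(np.arange(6.0).reshape(3, 2)), [2.0, 3.0])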

tests/link/numba/test_elemwise.py (+1, -13)

@@ -16,7 +16,7 @@
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
+from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
@@ -256,18 +256,6 @@ def test_Dimshuffle_non_contiguous():
         0,
         set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
     ),
-    (
-        lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-        0,
-        set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
-    ),
-    (
-        lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
-        0,
-        set_test_value(
-            pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
-        ),
-    ),
     (
         lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
            axis=axis, dtype=dtype, acc_dtype=acc_dtype

tests/scalar/test_basic.py (+1, -30)

@@ -43,7 +43,6 @@
     log1p,
     log2,
     log10,
-    mean,
     mul,
     neg,
     neq,
@@ -58,7 +57,7 @@
     true_div,
     uint8,
 )
-from pytensor.tensor.type import fscalar, imatrix, iscalar, matrix
+from pytensor.tensor.type import fscalar, imatrix, matrix
 from tests.link.test_link import make_function


@@ -521,34 +520,6 @@ def test_constant():
     assert c.dtype == "float32"


-@pytest.mark.parametrize("mode", [Mode("py"), Mode("cvm")])
-def test_mean(mode):
-    a = iscalar("a")
-    b = iscalar("b")
-    z = mean(a, b)
-    z_fn = pytensor.function([a, b], z, mode=mode)
-    res = z_fn(1, 1)
-    assert np.allclose(res, 1.0)
-
-    a = fscalar("a")
-    b = fscalar("b")
-    c = fscalar("c")
-
-    z = mean(a, b, c)
-
-    z_fn = pytensor.function([a, b, c], pytensor.grad(z, [a]), mode=mode)
-    res = z_fn(3, 4, 5)
-    assert np.allclose(res, 1 / 3)
-
-    z_fn = pytensor.function([a, b, c], pytensor.grad(z, [b]), mode=mode)
-    res = z_fn(3, 4, 5)
-    assert np.allclose(res, 1 / 3)
-
-    z = mean()
-    z_fn = pytensor.function([], z, mode=mode)
-    assert z_fn() == 0
-
-
 def test_shape():
     a = float32("a")
     assert isinstance(a.type, ScalarType)

tests/tensor/test_math.py (-10)

@@ -40,7 +40,6 @@
     Argmax,
     Dot,
     Max,
-    Mean,
     Prod,
     ProdWithoutZeros,
     Sum,
@@ -2587,15 +2586,6 @@ def test_mod_compile():


 class TestInferShape(utt.InferShapeTester):
-    def test_Mean(self):
-        adtens3 = dtensor3()
-        adtens3_val = random(3, 4, 5)
-        aiscal_val = 2
-        self._compile_and_check([adtens3], [Mean(None)(adtens3)], [adtens3_val], Mean)
-        self._compile_and_check(
-            [adtens3], [Mean(aiscal_val)(adtens3)], [adtens3_val], Mean
-        )
-
     def test_Max(self):
         adtens3 = dtensor3()
         adtens3_val = random(4, 5, 3)
