Description
The requirement to upcast sum(x) to the default floating-point dtype with the default dtype=None currently says (from the sum spec):
If x has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
The rationale given is that the keyword argument "is intended to help prevent data type overflows." This came up again in the review of NEP 56 (numpy/numpy#25542), and it is essentially the only part of the standard that was flagged as problematic and explicitly rejected.
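For floating-point input, a caller who wants a wider accumulator can already request it explicitly through the dtype keyword. A minimal NumPy sketch, assuming NumPy's default real-valued floating-point dtype (float64) as the "default" the spec refers to; the variable names are illustrative only:

```python
import numpy as np

x = np.ones(3, dtype=np.float32)

# NumPy with the default dtype=None: the input dtype is preserved.
default_result = np.sum(x)

# What the current spec text mandates for dtype=None is the default real-valued
# floating-point dtype (float64 in NumPy). A caller can opt into that
# explicitly by passing the dtype keyword.
explicit_result = np.sum(x, dtype=np.float64)

print(default_result.dtype)   # float32
print(explicit_result.dtype)  # float64
```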
I agree that the standard's choice here is problematic, at least from a practical perspective: no array library does this, and none plan to implement it. The rationale is also weak: overflow is far less of a concern for floating-point dtypes than for integer dtypes (and for integers, array libraries do implement the upcasting). Examples:
```python
>>> # NumPy:
>>> import numpy as np
>>> np.sum(np.ones(3, dtype=np.float32)).dtype
dtype('float32')
>>> np.sum(np.ones(3, dtype=np.int32)).dtype
dtype('int64')
>>> # PyTorch:
>>> import torch
>>> torch.sum(torch.ones(2, dtype=torch.bfloat16)).dtype
torch.bfloat16
>>> torch.sum(torch.ones(2, dtype=torch.int16)).dtype
torch.int64
>>> # JAX:
>>> import jax.numpy as jnp
>>> jnp.sum(jnp.ones(4, dtype=jnp.float16)).dtype
dtype('float16')
>>> jnp.sum(jnp.ones(4, dtype=jnp.int16)).dtype
dtype('int32')
>>> # CuPy:
>>> import cupy as cp
>>> cp.sum(cp.ones(5, dtype=cp.float16)).dtype
dtype('float16')
>>> cp.sum(cp.ones(5, dtype=cp.int32)).dtype
dtype('int64')
>>> # Dask:
>>> import dask.array as da
>>> da.sum(da.ones(6, dtype=np.float32)).dtype
dtype('float32')
>>> da.sum(da.ones(6, dtype=np.int32)).dtype
dtype('int64')
```
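The asymmetry between integer and floating-point dtypes can be sketched in NumPy. The input values are illustrative; the int16 case deliberately forces an int16 accumulator via dtype= to show what the default integer upcasting protects against:

```python
import numpy as np

# 1000 copies of 100: the true sum is 100_000.
x_int = np.full(1000, 100, dtype=np.int16)
x_flt = np.full(1000, 100, dtype=np.float32)

# Forcing an int16 accumulator: 100_000 exceeds the int16 maximum (32767),
# so the result silently wraps around.
wrapped = np.sum(x_int, dtype=np.int16)

# Default behaviour for small integers: NumPy accumulates in its default
# platform integer dtype, so the sum is correct.
upcast = np.sum(x_int)

# float32 keeps its own dtype, and nothing overflows: float32 represents
# magnitudes up to about 3.4e38.
flt = np.sum(x_flt)

print(int(wrapped), int(upcast), float(flt), flt.dtype)
print(np.iinfo(np.int16).max, np.finfo(np.float32).max)
```

An integer overflow is reachable with ordinary data sizes, while a float32 overflow needs values within a few orders of magnitude of its maximum; the more common floating-point concern is accumulated rounding error, which is a different problem from the overflow the spec's rationale names.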
The most relevant conversation is #238 (comment). There were some further minor tweaks (without much discussion) in gh-666.
Proposed resolution: align the standard with what all known array libraries implement today.
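Under this resolution, a check for real-valued floating-point inputs reduces to "with dtype=None, the result dtype equals the input dtype". A minimal sketch against NumPy; the helper name check_float_sum_preserves_dtype is hypothetical and not part of any existing test suite:

```python
import numpy as np

def check_float_sum_preserves_dtype(xp, dtypes):
    """Hypothetical check: with dtype=None, sum() of a real-valued
    floating-point array returns the same dtype as its input."""
    for dt in dtypes:
        x = xp.ones(4, dtype=dt)
        result = xp.sum(x)
        assert result.dtype == dt, f"{dt}: got {result.dtype}"

# NumPy preserves each of these floating-point dtypes, including float16.
check_float_sum_preserves_dtype(np, [np.float16, np.float32, np.float64])
print("NumPy preserves float16, float32 and float64 in sum()")
```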