Skip to content

Fix float utils #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{% set version = "1.3.1" %}
{% set version = "1.3.3" %}
{% set buildnumber = 0 %}

package:
Expand Down
48 changes: 26 additions & 22 deletions mkl_fft/_float_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from numpy import (half, float32, asarray, ndarray,
longdouble, float64, longcomplex, complex_, float128, complex256)
import numpy as np

__all__ = ['__upcast_float16_array', '__downcast_float128_array', '__supported_array_or_not_implemented']

Expand All @@ -35,18 +34,18 @@ def __upcast_float16_array(x):
instead of float64, as mkl_fft would do"""
if hasattr(x, "dtype"):
xdt = x.dtype
if xdt == half:
if xdt == np.half:
# no half-precision routines, so convert to single precision
return asarray(x, dtype=float32)
if xdt == longdouble and not xdt == float64:
return np.asarray(x, dtype=np.float32)
if xdt == np.longdouble and not xdt == np.float64:
raise ValueError("type %s is not supported" % xdt)
if not isinstance(x, ndarray):
__x = asarray(x)
if not isinstance(x, np.ndarray):
__x = np.asarray(x)
xdt = __x.dtype
if xdt == half:
if xdt == np.half:
# no half-precision routines, so convert to single precision
return asarray(__x, dtype=float32)
if xdt == longdouble and not xdt == float64:
return np.asarray(__x, dtype=np.float32)
if xdt == np.longdouble and not xdt == np.float64:
raise ValueError("type %s is not supported" % xdt)
return __x
return x
Expand All @@ -58,17 +57,17 @@ def __downcast_float128_array(x):
complex128, instead of raising an error"""
if hasattr(x, "dtype"):
xdt = x.dtype
if xdt == longdouble and not xdt == float64:
return asarray(x, dtype=float64)
elif xdt == longcomplex and not xdt == complex_:
return asarray(x, dtype=complex_)
if not isinstance(x, ndarray):
__x = asarray(x)
if xdt == np.longdouble and not xdt == np.float64:
return np.asarray(x, dtype=np.float64)
elif xdt == np.longcomplex and not xdt == np.complex_:
return np.asarray(x, dtype=np.complex_)
if not isinstance(x, np.ndarray):
__x = np.asarray(x)
xdt = __x.dtype
if xdt == longdouble and not xdt == float64:
return asarray(x, dtype=float64)
elif xdt == longcomplex and not xdt == complex_:
return asarray(x, dtype=complex_)
if xdt == np.longdouble and not xdt == np.float64:
return np.asarray(x, dtype=np.float64)
elif xdt == np.longcomplex and not xdt == np.complex_:
return np.asarray(x, dtype=np.complex_)
return __x
return x

Expand All @@ -78,7 +77,12 @@ def __supported_array_or_not_implemented(x):
Used in _scipy_fft_backend to convert array to float32,
float64, complex64, or complex128 type or return NotImplemented
"""
__x = asarray(x)
if __x.dtype in [half, float128, complex256]:
__x = np.asarray(x)
black_list = [np.half]
if hasattr(np, 'float128'):
black_list.append(np.float128)
if hasattr(np, 'complex256'):
black_list.append(np.complex256)
if __x.dtype in black_list:
return NotImplemented
return __x
2 changes: 1 addition & 1 deletion mkl_fft/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.3.1'
__version__ = '1.3.3'
10 changes: 9 additions & 1 deletion mkl_fft/tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,15 @@ def test_numpy_rftn(norm, dtype):
assert np.allclose(x, xx, atol=tol, rtol=tol)


@pytest.mark.parametrize('dtype', [np.float16, np.float128, np.complex256])
def _get_blacklisted_dtypes():
bl_list = []
for dt in ['float16', 'float128', 'complex256']:
if hasattr(np, dt):
bl_list.append(getattr(np, dt))
return bl_list


@pytest.mark.parametrize('dtype', _get_blacklisted_dtypes())
def test_scipy_no_support_for(dtype):
x = np.ones(16, dtype=dtype)
w = mfi.scipy_fft.fft(x)
Expand Down