2
2
from operator import mul
3
3
from math import sqrt
4
4
import itertools
5
- from typing import Tuple
5
+ from typing import Tuple , Optional , List
6
6
7
7
from hypothesis import assume
8
8
from hypothesis .strategies import (lists , integers , sampled_from ,
9
9
shared , floats , just , composite , one_of ,
10
- none , booleans )
11
- from hypothesis .strategies ._internal .strategies import SearchStrategy
10
+ none , booleans , SearchStrategy )
12
11
13
12
from .pytest_helpers import nargs
14
13
from .array_helpers import ndindex
14
+ from .typing import DataType , Shape
15
15
from . import dtype_helpers as dh
16
16
from ._array_module import (full , float32 , float64 , bool as bool_dtype ,
17
17
_UndefinedStub , eye , broadcast_to )
50
50
_dtype_categories = [(xp .bool ,), dh .uint_dtypes , dh .int_dtypes , dh .float_dtypes ]
51
51
_sorted_dtypes = [d for category in _dtype_categories for d in category ]
52
52
53
- def _dtypes_sorter (dtype_pair ):
53
+ def _dtypes_sorter (dtype_pair : Tuple [ DataType , DataType ] ):
54
54
dtype1 , dtype2 = dtype_pair
55
55
if dtype1 == dtype2 :
56
56
return _sorted_dtypes .index (dtype1 )
@@ -67,7 +67,7 @@ def _dtypes_sorter(dtype_pair):
67
67
key += 1
68
68
return key
69
69
70
- promotable_dtypes = sorted (dh .promotion_table .keys (), key = _dtypes_sorter )
70
+ promotable_dtypes : List [ Tuple [ DataType , DataType ]] = sorted (dh .promotion_table .keys (), key = _dtypes_sorter )
71
71
72
72
if FILTER_UNDEFINED_DTYPES :
73
73
promotable_dtypes = [
@@ -77,10 +77,34 @@ def _dtypes_sorter(dtype_pair):
77
77
]
78
78
79
79
80
- def mutually_promotable_dtypes (dtype_objs = dh .all_dtypes ):
81
- return sampled_from (
82
- [(i , j ) for i , j in promotable_dtypes if i in dtype_objs and j in dtype_objs ]
83
- )
80
+ def mutually_promotable_dtypes (
81
+ max_size : Optional [int ] = 2 ,
82
+ * ,
83
+ dtypes : Tuple [DataType , ...] = dh .all_dtypes ,
84
+ ) -> SearchStrategy [Tuple [DataType , ...]]:
85
+ if max_size == 2 :
86
+ return sampled_from (
87
+ [(i , j ) for i , j in promotable_dtypes if i in dtypes and j in dtypes ]
88
+ )
89
+ if isinstance (max_size , int ) and max_size < 2 :
90
+ raise ValueError (f'{ max_size = } should be >=2' )
91
+ strats = []
92
+ category_samples = {
93
+ category : [d for d in dtypes if d in category ] for category in _dtype_categories
94
+ }
95
+ for samples in category_samples .values ():
96
+ if len (samples ) > 0 :
97
+ strat = lists (sampled_from (samples ), min_size = 2 , max_size = max_size )
98
+ strats .append (strat )
99
+ if len (category_samples [dh .uint_dtypes ]) > 0 and len (category_samples [dh .int_dtypes ]) > 0 :
100
+ mixed_samples = category_samples [dh .uint_dtypes ] + category_samples [dh .int_dtypes ]
101
+ strat = lists (sampled_from (mixed_samples ), min_size = 2 , max_size = max_size )
102
+ if xp .uint64 in mixed_samples :
103
+ strat = strat .filter (
104
+ lambda l : not (xp .uint64 in l and any (d in dh .int_dtypes for d in l ))
105
+ )
106
+ return one_of (strats ).map (tuple )
107
+
84
108
85
109
# shared() allows us to draw either the function or the function name and they
86
110
# will both correspond to the same function.
@@ -113,15 +137,19 @@ def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False)
113
137
114
138
# Use this to avoid memory errors with NumPy.
115
139
# See https://github.com/numpy/numpy/issues/15753
116
- shapes = xps .array_shapes (min_dims = 0 , min_side = 0 ).filter (
117
- lambda shape : prod (i for i in shape if i ) < MAX_ARRAY_SIZE
118
- )
140
+ def shapes (** kw ):
141
+ kw .setdefault ('min_dims' , 0 )
142
+ kw .setdefault ('min_side' , 0 )
143
+ return xps .array_shapes (** kw ).filter (
144
+ lambda shape : prod (i for i in shape if i ) < MAX_ARRAY_SIZE
145
+ )
146
+
119
147
120
148
one_d_shapes = xps .array_shapes (min_dims = 1 , max_dims = 1 , min_side = 0 , max_side = SQRT_MAX_ARRAY_SIZE )
121
149
122
150
# Matrix shapes assume stacks of matrices
123
151
@composite
124
- def matrix_shapes (draw , stack_shapes = shapes ):
152
+ def matrix_shapes (draw , stack_shapes = shapes () ):
125
153
stack_shape = draw (stack_shapes )
126
154
mat_shape = draw (xps .array_shapes (max_dims = 2 , min_dims = 2 ))
127
155
shape = stack_shape + mat_shape
@@ -135,9 +163,11 @@ def matrix_shapes(draw, stack_shapes=shapes):
135
163
elements = dict (allow_nan = False ,
136
164
allow_infinity = False ))
137
165
138
- def mutually_broadcastable_shapes (num_shapes : int ) -> SearchStrategy [Tuple [Tuple ]]:
166
+ def mutually_broadcastable_shapes (
167
+ num_shapes : int , ** kw
168
+ ) -> SearchStrategy [Tuple [Shape , ...]]:
139
169
return (
140
- xps .mutually_broadcastable_shapes (num_shapes )
170
+ xps .mutually_broadcastable_shapes (num_shapes , ** kw )
141
171
.map (lambda BS : BS .input_shapes )
142
172
.filter (lambda shapes : all (
143
173
prod (i for i in s if i > 0 ) < MAX_ARRAY_SIZE for s in shapes
@@ -164,13 +194,13 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
164
194
# using something like
165
195
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
166
196
n = draw (integers (0 ))
167
- shape = draw (shapes ) + (n , n )
197
+ shape = draw (shapes () ) + (n , n )
168
198
assume (prod (i for i in shape if i ) < MAX_ARRAY_SIZE )
169
199
dtype = draw (dtypes )
170
200
return broadcast_to (eye (n , dtype = dtype ), shape )
171
201
172
202
@composite
173
- def invertible_matrices (draw , dtypes = xps .floating_dtypes (), stack_shapes = shapes ):
203
+ def invertible_matrices (draw , dtypes = xps .floating_dtypes (), stack_shapes = shapes () ):
174
204
# For now, just generate stacks of diagonal matrices.
175
205
n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ),)
176
206
stack_shape = draw (stack_shapes )
@@ -318,9 +348,10 @@ def multiaxis_indices(draw, shapes):
318
348
319
349
320
350
def two_mutual_arrays (
321
- dtype_objs = dh .all_dtypes , two_shapes = two_mutually_broadcastable_shapes
322
- ):
323
- mutual_dtypes = shared (mutually_promotable_dtypes (dtype_objs ))
351
+ dtypes : Tuple [DataType , ...] = dh .all_dtypes ,
352
+ two_shapes : SearchStrategy [Tuple [Shape , Shape ]] = two_mutually_broadcastable_shapes ,
353
+ ) -> SearchStrategy :
354
+ mutual_dtypes = shared (mutually_promotable_dtypes (dtypes = dtypes ))
324
355
mutual_shapes = shared (two_shapes )
325
356
arrays1 = xps .arrays (
326
357
dtype = mutual_dtypes .map (lambda pair : pair [0 ]),
0 commit comments