
Commit 4d96c47

Remove strict=False in hot loops

This is actually slower than just not specifying it

1 parent e98cbbc commit 4d96c47
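The claim in the commit message is easy to check with a micro-benchmark. The sketch below is not part of the commit and the iteration count is illustrative, but on CPython the keyword form is measurably slower because `zip` has to parse the `strict=` argument on every construction:

```python
# Illustrative micro-benchmark (not from the commit): constructing zip()
# with and without the strict= keyword in a tight loop.
import timeit

a = list(range(10))
b = list(range(10))

plain = timeit.timeit(lambda: list(zip(a, b)), number=500_000)
kwarg = timeit.timeit(lambda: list(zip(a, b, strict=False)), number=500_000)

print(f"zip(a, b):               {plain:.3f}s")
print(f"zip(a, b, strict=False): {kwarg:.3f}s")  # typically the slower of the two
```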

File tree

15 files changed: +44 -50 lines changed

pyproject.toml

+1-5
@@ -130,12 +130,8 @@ exclude = ["doc/", "pytensor/_version.py"]
 docstring-code-format = true
 
 [tool.ruff.lint]
-select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
+select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
 ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
-unfixable = [
-    # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
-    "B905",
-]
 
 
 [tool.ruff.lint.isort]
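For context: B905 is the flake8-bugbear "zip() without an explicit strict=" rule that ruff implements. With the rule deselected, the plain `zip` calls restored throughout this commit no longer need `# noqa: B905` suppressions. An illustrative snippet of what the rule used to flag:

```python
# What B905 flags (illustrative): zip() calls with no explicit strict=.
names = ["a", "b"]
values = [1, 2, 3]

pairs = list(zip(names, values))  # B905: the extra value is silently dropped
assert pairs == [("a", 1), ("b", 2)]

try:
    list(zip(names, values, strict=True))  # the fix B905 nudges you toward
except ValueError:
    print("strict=True surfaces the length mismatch")
```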

pytensor/compile/builders.py

+2-3
@@ -873,7 +873,6 @@ def clone(self):
 
     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
-        assert len(variables) == len(outputs)
-        # strict=False because asserted above
-        for output, variable in zip(outputs, variables, strict=False):
+        # strict=None because we are in a hot loop
+        for output, variable in zip(outputs, variables):
             output[0] = variable

pytensor/link/basic.py

+12-7
@@ -373,7 +373,12 @@ def make_all(
 
         # The function that actually runs your program is one of the f's in streamline.
         f = streamline(
-            fgraph, thunks, order, post_thunk_old_storage, no_recycling=no_recycling
+            fgraph,
+            thunks,
+            order,
+            post_thunk_old_storage=post_thunk_old_storage,
+            no_recycling=no_recycling,
+            output_storage=output_storage,
         )
 
         f.allow_gc = (
@@ -539,14 +544,14 @@ def make_thunk(self, **kwargs):
 
         def f():
             for inputs in input_lists[1:]:
-                # strict=False because we are in a hot loop
-                for input1, input2 in zip(inputs0, inputs, strict=False):
+                # strict=None because we are in a hot loop
+                for input1, input2 in zip(inputs0, inputs):
                     input2.storage[0] = copy(input1.storage[0])
             for x in to_reset:
                 x[0] = None
             pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
-            # strict=False because we are in a hot loop
-            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
+            # strict=None because we are in a hot loop
+            for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
                 try:
                     wrapper(self.fgraph, i, node, *thunks)
                 except Exception:
@@ -668,8 +673,8 @@ def thunk(
                 # since the error may come from any of them?
                 raise_with_op(self.fgraph, output_nodes[0], thunk)
 
-            # strict=False because we are in a hot loop
-            for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
+            # strict=None because we are in a hot loop
+            for o_storage, o_val in zip(thunk_outputs, outputs):
                 o_storage[0] = o_val
 
         thunk.inputs = thunk_inputs
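A note on the pattern these hunks share: PyTensor thunks publish results through single-element storage cells, which is why each loop ends in `o_storage[0] = o_val`. A minimal sketch of the convention, independent of PyTensor internals:

```python
# Minimal sketch of the storage-cell convention visible in these loops:
# each output slot is a one-element mutable container shared by reference.
outputs = (3.0, 4.0)              # values a thunk just computed
thunk_outputs = [[None], [None]]  # one cell per output, also held elsewhere

for o_storage, o_val in zip(thunk_outputs, outputs):
    o_storage[0] = o_val          # writing in place makes the new value
                                  # visible to every holder of the cell

assert thunk_outputs == [[3.0], [4.0]]
```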

pytensor/link/numba/dispatch/basic.py

+2-2
@@ -312,10 +312,10 @@ def py_perform_return(inputs):
     else:
 
         def py_perform_return(inputs):
-            # strict=False because we are in a hot loop
+            # strict=None because we are in a hot loop
             return tuple(
                 out_type.filter(out[0])
-                for out_type, out in zip(output_types, py_perform(inputs), strict=False)
+                for out_type, out in zip(output_types, py_perform(inputs))
             )
 
     @numba_njit

pytensor/link/numba/dispatch/cython_support.py

+1-4
@@ -166,10 +166,7 @@ def __wrapper_address__(self):
     def __call__(self, *args, **kwargs):
         # no strict argument because of the JIT
         # TODO: check
-        args = [
-            dtype(arg)
-            for arg, dtype in zip(args, self._signature.arg_dtypes)  # noqa: B905
-        ]
+        args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
         if self.has_pyx_skip_dispatch():
             output = self._pyfunc(*args[:-1], **kwargs)
         else:

pytensor/link/numba/dispatch/extra_ops.py

+1-1
@@ -186,7 +186,7 @@ def ravelmultiindex(*inp):
         new_arr = arr.T.astype(np.float64).copy()
         for i, b in enumerate(new_arr):
             # no strict argument to this zip because numba doesn't support it
-            for j, (d, v) in enumerate(zip(shape, b)):  # noqa: B905
+            for j, (d, v) in enumerate(zip(shape, b)):
                 if v < 0 or v >= d:
                     mode_fn(new_arr, i, j, v, d)
 
pytensor/link/numba/dispatch/slinalg.py

+1-1
@@ -183,7 +183,7 @@ def block_diag(*arrs):
 
         r, c = 0, 0
         # no strict argument because it is incompatible with numba
-        for arr, shape in zip(arrs, shapes):  # noqa: B905
+        for arr, shape in zip(arrs, shapes):
             rr, cc = shape
             out[r : r + rr, c : c + cc] = arr
             r += rr
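The comments in these numba backends all point at the same limitation: numba's nopython-mode `zip` predates Python 3.10's `strict=` keyword and does not accept it, so the argument cannot be used in jitted code regardless of linting preferences. A sketch, assuming numba is installed (the function here is illustrative, not from the diff):

```python
# Sketch: plain zip() compiles fine under numba; adding strict= would fail
# type inference, which is why the jitted loops above omit it.
import numba
import numpy as np

@numba.njit
def pairwise_sum(a, b):
    total = 0.0
    for x, y in zip(a, b):  # zip(a, b, strict=True) would not compile here
        total += x + y
    return total

print(pairwise_sum(np.ones(3), 2 * np.ones(3)))  # 9.0
```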

pytensor/link/numba/dispatch/subtensor.py

+5-5
@@ -219,7 +219,7 @@ def advanced_subtensor_multiple_vector(x, *idxs):
         shape_aft = x_shape[after_last_axis:]
         out_shape = (*shape_bef, *idx_shape, *shape_aft)
         out_buffer = np.empty(out_shape, dtype=x.dtype)
-        for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+        for i, scalar_idxs in enumerate(zip(*vec_idxs)):
             out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
         return out_buffer
 
@@ -253,7 +253,7 @@ def advanced_set_subtensor_multiple_vector(x, y, *idxs):
         y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
 
         for outer in np.ndindex(x_shape[:first_axis]):
-            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):
                 out[(*outer, *scalar_idxs)] = y[(*outer, i)]
         return out
 
@@ -275,7 +275,7 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
         y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
 
         for outer in np.ndindex(x_shape[:first_axis]):
-            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):
                 out[(*outer, *scalar_idxs)] += y[(*outer, i)]
         return out
 
@@ -314,7 +314,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
         if not len(idxs) == len(vals):
             raise ValueError("The number of indices and values must match.")
         # no strict argument because incompatible with numba
-        for idx, val in zip(idxs, vals):  # noqa: B905
+        for idx, val in zip(idxs, vals):
             x[idx] = val
         return x
     else:
@@ -342,7 +342,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
             raise ValueError("The number of indices and values must match.")
         # no strict argument because unsupported by numba
         # TODO: this doesn't come up in tests
-        for idx, val in zip(idxs, vals):  # noqa: B905
+        for idx, val in zip(idxs, vals):
             x[idx] += val
         return x
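All of these subtensor loops rely on the same idiom: `zip(*vec_idxs)` transposes a tuple of index vectors into one coordinate tuple per output element, mirroring NumPy fancy indexing. An illustrative equivalence (not from the diff):

```python
# Illustrative: zip(*vec_idxs) yields one coordinate tuple per row,
# matching what numpy fancy indexing does with multiple index vectors.
import numpy as np

x = np.arange(9).reshape(3, 3)
vec_idxs = (np.array([0, 2]), np.array([1, 0]))

manual = np.array([x[coords] for coords in zip(*vec_idxs)])
assert np.array_equal(manual, x[vec_idxs])  # [1, 6] both ways
```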

pytensor/link/utils.py

+4-4
@@ -190,9 +190,9 @@ def streamline_default_f():
         for x in no_recycling:
             x[0] = None
         try:
-            # strict=False because we are in a hot loop
+            # strict=None because we are in a hot loop
             for thunk, node, old_storage in zip(
-                thunks, order, post_thunk_old_storage, strict=False
+                thunks, order, post_thunk_old_storage
             ):
                 thunk()
                 for old_s in old_storage:
@@ -207,8 +207,8 @@ def streamline_nice_errors_f():
         for x in no_recycling:
             x[0] = None
         try:
-            # strict=False because we are in a hot loop
-            for thunk, node in zip(thunks, order, strict=False):
+            # strict=None because we are in a hot loop
+            for thunk, node in zip(thunks, order):
                 thunk()
         except Exception:
             raise_with_op(fgraph, node, thunk)

pytensor/scalar/basic.py

+2-2
@@ -4416,8 +4416,8 @@ def make_node(self, *inputs):
 
     def perform(self, node, inputs, output_storage):
         outputs = self.py_perform_fn(*inputs)
-        # strict=False because we are in a hot loop
-        for storage, out_val in zip(output_storage, outputs, strict=False):
+        # strict=None because we are in a hot loop
+        for storage, out_val in zip(output_storage, outputs):
             storage[0] = out_val
 
     def grad(self, inputs, output_grads):

pytensor/scalar/loop.py

+2-2
@@ -196,8 +196,8 @@ def perform(self, node, inputs, output_storage):
         for i in range(n_steps):
             carry = inner_fn(*carry, *constant)
 
-        # strict=False because we are in a hot loop
-        for storage, out_val in zip(output_storage, carry, strict=False):
+        # strict=None because we are in a hot loop
+        for storage, out_val in zip(output_storage, carry):
             storage[0] = out_val
 
     @property
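For readers unfamiliar with this `perform`: the inner function returns the next carry, which is fed straight back in on the following step, and only the final carry reaches `output_storage`. A hedged sketch with a hypothetical inner function makes the flow concrete:

```python
# Hedged sketch of the carry pattern (inner_fn here is hypothetical).
def inner_fn(x, acc, step):
    return x + step, acc + x  # next carry

carry, constant, n_steps = (0.0, 0.0), (1.0,), 3
for _ in range(n_steps):
    carry = inner_fn(*carry, *constant)

print(carry)  # (3.0, 3.0): only this final carry is written to storage
```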

pytensor/tensor/random/basic.py

+2-2
@@ -1865,8 +1865,8 @@ def rng_fn(cls, rng, p, size):
         # to `p.shape[:-1]` in the call to `vsearchsorted` below.
         if len(size) < (p.ndim - 1):
             raise ValueError("`size` is incompatible with the shape of `p`")
-        # strict=False because we are in a hot loop
-        for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
+        # strict=None because we are in a hot loop
+        for s, ps in zip(reversed(size), reversed(p.shape[:-1])):
             if s == 1 and ps != 1:
                 raise ValueError("`size` is incompatible with the shape of `p`")
 
pytensor/tensor/random/utils.py

+5-6
@@ -44,8 +44,8 @@ def params_broadcast_shapes(
     max_fn = maximum if use_pytensor else max
 
     rev_extra_dims: list[int] = []
-    # strict=False because we are in a hot loop
-    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
+    # strict=None because we are in a hot loop
+    for ndim_param, param_shape in zip(ndims_params, param_shapes):
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
         extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
@@ -69,7 +69,7 @@ def max_bcast(x, y):
         (extra_dims + tuple(param_shape)[-ndim_param:])
         if ndim_param > 0
         else extra_dims
-        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
+        for ndim_param, param_shape in zip(ndims_params, param_shapes)
     ]
 
     return bcast_shapes
@@ -127,10 +127,9 @@ def broadcast_params(
     )
     broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
 
-    # strict=False because we are in a hot loop
+    # strict=None because we are in a hot loop
     bcast_params = [
-        broadcast_to_fn(param, shape)
-        for shape, param in zip(shapes, params, strict=False)
+        broadcast_to_fn(param, shape) for shape, param in zip(shapes, params)
     ]
 
     return bcast_params
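These helpers treat the rightmost `ndim_param` axes of each parameter as core dimensions and broadcast everything to their left. A NumPy-only sketch of that logic (the function name and structure are mine, not the library's):

```python
# NumPy-only sketch of the batch-dimension broadcasting these hunks touch:
# core dims (the rightmost ndim_param axes) are kept, batch dims broadcast.
import numpy as np

def sketch_broadcast_params(params, ndims_params):
    batch_shapes = [p.shape[: p.ndim - nd] for p, nd in zip(params, ndims_params)]
    batch = np.broadcast_shapes(*batch_shapes)
    return [
        np.broadcast_to(p, batch + p.shape[p.ndim - nd :])
        for p, nd in zip(params, ndims_params)
    ]

mu, cov = np.zeros((5, 1, 3)), np.eye(3)
out = sketch_broadcast_params([mu, cov], [1, 2])
print(out[0].shape, out[1].shape)  # (5, 1, 3) (5, 1, 3, 3)
```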

pytensor/tensor/shape.py

+2-4
@@ -447,10 +447,8 @@ def perform(self, node, inp, out_):
             raise AssertionError(
                 f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
             )
-        # strict=False because we are in a hot loop
-        if not all(
-            xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None
-        ):
+        # strict=None because we are in a hot loop
+        if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
             raise AssertionError(
                 f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
             )

pytensor/tensor/type.py

+2-2
@@ -261,10 +261,10 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:
                 " PyTensor C code does not support that.",
             )
 
-        # strict=False because we are in a hot loop
+        # strict=None because we are in a hot loop
         if not all(
             ds == ts if ts is not None else True
-            for ds, ts in zip(data.shape, self.shape, strict=False)
+            for ds, ts in zip(data.shape, self.shape)
         ):
             raise TypeError(
                 f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
            )

0 commit comments