Commit a76172e

Armavica authored and ricardoV94 committed
Fix UP038 (isinstance(..., X | Y))
1 parent 090b844 · commit a76172e
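
This commit applies ruff's UP038 (pyupgrade) rule across the codebase: `isinstance` checks that pass a tuple of classes are rewritten to the PEP 604 union syntax `X | Y`, which `isinstance` accepts on Python 3.10 and later. The two spellings are equivalent at runtime; a minimal sketch (illustrative only, not taken from the commit):

    # Python 3.10+: classes compose with `|` into a types.UnionType,
    # which isinstance() accepts just like a tuple of classes.
    value = [1, 2, 3]
    assert isinstance(value, (list, tuple))  # old spelling, flagged by UP038
    assert isinstance(value, list | tuple)   # new spelling, used throughout this commit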

90 files changed: +367 additions, −402 deletions

pytensor/compile/builders.py

Lines changed: 3 additions & 3 deletions
@@ -275,7 +275,7 @@ def _filter_grad_var(grad, inp):
     #
     # For now, this converts NullType or DisconnectedType into zeros_like.
     # other types are unmodified: overrider_var -> None
-    if isinstance(grad.type, (NullType, DisconnectedType)):
+    if isinstance(grad.type, NullType | DisconnectedType):
         if hasattr(inp, "zeros_like"):
             return inp.zeros_like(), grad
         else:
@@ -535,7 +535,7 @@ def lop_op(inps, grads):
             all_grads_l = list(all_grads_l)
             all_grads_ov_l = list(all_grads_ov_l)
         elif isinstance(lop_op, Variable):
-            if isinstance(lop_op.type, (DisconnectedType, NullType)):
+            if isinstance(lop_op.type, DisconnectedType | NullType):
                 all_grads_l = [inp.zeros_like() for inp in local_inputs]
                 all_grads_ov_l = [lop_op.type() for _ in range(inp_len)]
             else:
@@ -562,7 +562,7 @@ def lop_op(inps, grads):
                 all_grads_l.append(gnext)
                 all_grads_ov_l.append(gnext_ov)
             elif isinstance(fn_gov, Variable):
-                if isinstance(fn_gov.type, (DisconnectedType, NullType)):
+                if isinstance(fn_gov.type, DisconnectedType | NullType):
                     all_grads_l.append(inp.zeros_like())
                     all_grads_ov_l.append(fn_gov.type())
                 else:

pytensor/compile/compiledir.py

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ def cleanup():
             have_npy_abi_version = True
         elif obj.startswith("c_compiler_str="):
             have_c_compiler = True
-        elif isinstance(obj, (Op, CType)) and hasattr(
+        elif isinstance(obj, Op | CType) and hasattr(
             obj, "c_code_cache_version"
         ):
             v = obj.c_code_cache_version()

pytensor/compile/debugmode.py

Lines changed: 4 additions & 4 deletions
@@ -679,7 +679,7 @@ def _lessbroken_deepcopy(a):
     # This logic is also in link.py
     from pytensor.link.c.type import _cdata_type
 
-    if isinstance(a, (np.ndarray, np.memmap)):
+    if isinstance(a, np.ndarray | np.memmap):
         rval = a.copy(order="K")
     elif isinstance(a, _cdata_type):
         # This is not copyable (and should be used for constant data).
@@ -1889,7 +1889,7 @@ def thunk():
                     # HACK TO LOOK LIKE A REAL DESTRUCTIVE ACTION
                     # TOOK PLACE
                     if (
-                        isinstance(dr_vals[r][0], (np.ndarray, np.memmap))
+                        isinstance(dr_vals[r][0], np.ndarray | np.memmap)
                         and (dr_vals[r][0].dtype == storage_map[r][0].dtype)
                         and (dr_vals[r][0].shape == storage_map[r][0].shape)
                     ):
@@ -2019,10 +2019,10 @@ def __init__(
         if outputs is None:
             return_none = True
             outputs = []
-        if not isinstance(outputs, (list, tuple)):
+        if not isinstance(outputs, list | tuple):
             unpack_single = True
             outputs = [outputs]
-        if not isinstance(inputs, (list, tuple)):
+        if not isinstance(inputs, list | tuple):
             inputs = [inputs]
 
         # Wrap them in In or Out instances if needed.

pytensor/compile/function/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -285,15 +285,15 @@ def opt_log1p(node):
 
     if givens is None:
         givens = []
-    if not isinstance(inputs, (list, tuple)):
+    if not isinstance(inputs, list | tuple):
         raise Exception(
             "Input variables of an PyTensor function should be "
             "contained in a list, even when there is a single "
             "input."
         )
 
     # compute some features of the arguments:
-    uses_tuple = any(isinstance(i, (list, tuple)) for i in inputs)
+    uses_tuple = any(isinstance(i, list | tuple) for i in inputs)
     uses_updates = bool(updates)
     uses_givens = bool(givens)

pytensor/compile/function/pfunc.py

Lines changed: 4 additions & 4 deletions
@@ -512,7 +512,7 @@ def construct_pfunc_ins_and_outs(
     if givens is None:
         givens = []
 
-    if not isinstance(params, (list, tuple)):
+    if not isinstance(params, list | tuple):
         raise TypeError("The `params` argument must be a list or a tuple")
 
     if not isinstance(no_default_updates, bool) and not isinstance(
@@ -521,7 +521,7 @@ def construct_pfunc_ins_and_outs(
         raise TypeError("The `no_default_update` argument must be a boolean or list")
 
     if len(updates) > 0 and not all(
-        isinstance(pair, (tuple, list))
+        isinstance(pair, tuple | list)
         and len(pair) == 2
         and isinstance(pair[0], Variable)
         for pair in iter_over_pairs(updates)
@@ -575,7 +575,7 @@ def construct_pfunc_ins_and_outs(
     if outputs is None:
         out_list = []
     else:
-        if isinstance(outputs, (list, tuple)):
+        if isinstance(outputs, list | tuple):
             out_list = list(outputs)
         else:
             out_list = [outputs]
@@ -598,7 +598,7 @@ def construct_pfunc_ins_and_outs(
     if outputs is None:
         new_outputs = []
     else:
-        if isinstance(outputs, (list, tuple)):
+        if isinstance(outputs, list | tuple):
             new_outputs = cloned_extended_outputs[: len(outputs)]
         else:
             new_outputs = cloned_extended_outputs[0]

pytensor/compile/function/types.py

Lines changed: 8 additions & 8 deletions
@@ -1310,7 +1310,7 @@ def wrap_in(input):
         elif isinstance(input, Variable):
             # r -> SymbolicInput(variable=r)
             return SymbolicInput(input)
-        elif isinstance(input, (list, tuple)):
+        elif isinstance(input, list | tuple):
             # (r, u) -> SymbolicInput(variable=r, update=u)
             if len(input) == 2:
                 return SymbolicInput(input[0], update=input[1])
@@ -1495,10 +1495,10 @@ def __init__(
         if outputs is None:
             return_none = True
             outputs = []
-        if not isinstance(outputs, (list, tuple)):
+        if not isinstance(outputs, list | tuple):
             unpack_single = True
             outputs = [outputs]
-        if not isinstance(inputs, (list, tuple)):
+        if not isinstance(inputs, list | tuple):
             inputs = [inputs]
 
         # Wrap them in In or Out instances if needed.
@@ -1734,14 +1734,14 @@ def orig_function(
     inputs = list(map(convert_function_input, inputs))
 
     if outputs is not None:
-        if isinstance(outputs, (list, tuple)):
+        if isinstance(outputs, list | tuple):
             outputs = list(map(FunctionMaker.wrap_out, outputs))
         else:
             outputs = FunctionMaker.wrap_out(outputs)
 
     defaults = [getattr(input, "value", None) for input in inputs]
 
-    if isinstance(mode, (list, tuple)):
+    if isinstance(mode, list | tuple):
         raise ValueError("We do not support the passing of multiple modes")
 
     fn = None
@@ -1797,7 +1797,7 @@ def convert_function_input(input):
         raise TypeError(f"A Constant instance is not a legal function input: {input}")
     elif isinstance(input, Variable):
         return In(input)
-    elif isinstance(input, (list, tuple)):
+    elif isinstance(input, list | tuple):
         orig = input
         if not input:
             raise TypeError(f"Nonsensical input specification: {input}")
@@ -1806,7 +1806,7 @@ def convert_function_input(input):
             input = input[1:]
         else:
             name = None
-        if isinstance(input[0], (list, tuple)):
+        if isinstance(input[0], list | tuple):
             if len(input[0]) != 2 or len(input) != 2:
                 raise TypeError(
                     f"Invalid input syntax: {orig} (check "
@@ -1843,7 +1843,7 @@ def convert_function_input(input):
             raise TypeError(
                 f"Unknown update type: {type(update)}, expected Variable instance"
             )
-        if value is not None and isinstance(value, (Variable, SymbolicInput)):
+        if value is not None and isinstance(value, Variable | SymbolicInput):
             raise TypeError(
                 f"The value for input {variable} should not be a Variable "
                 f"or SymbolicInput instance (got: {value})"

pytensor/compile/mode.py

Lines changed: 1 addition & 1 deletion
@@ -565,7 +565,7 @@ def get_target_language(mode=None) -> tuple[Literal["py", "c", "numba", "jax"],
     if isinstance(linker, CLinker):
         return ("c",)
 
-    if isinstance(linker, (VMLinker, OpWiseCLinker)):
+    if isinstance(linker, VMLinker | OpWiseCLinker):
         return ("c", "py") if config.cxx else ("py",)
 
     raise Exception(f"Unsupported Linker: {linker}")

pytensor/compile/monitormode.py

Lines changed: 1 addition & 1 deletion
@@ -105,7 +105,7 @@ def detect_nan(fgraph, i, node, fn):
 
     for output in fn.outputs:
         if (
-            not isinstance(output[0], (np.random.RandomState, np.random.Generator))
+            not isinstance(output[0], np.random.RandomState | np.random.Generator)
             and np.isnan(output[0]).any()
         ):
             print("*** NaN detected ***")

pytensor/compile/nanguardmode.py

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ def _is_numeric_value(arr, var):
 
     if isinstance(arr, _cdata_type):
         return False
-    elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)):
+    elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator):
         return False
     elif var and isinstance(var.type, RandomType):
         return False
@@ -62,10 +62,10 @@ def flatten(l):
         A flattened list of objects.
 
     """
-    if isinstance(l, (list, tuple, ValuesView)):
+    if isinstance(l, list | tuple | ValuesView):
         rval = []
         for elem in l:
-            if isinstance(elem, (list, tuple)):
+            if isinstance(elem, list | tuple):
                 rval.extend(flatten(elem))
             else:
                 rval.append(elem)

pytensor/compile/ops.py

Lines changed: 3 additions & 3 deletions
@@ -256,7 +256,7 @@ def __str__(self):
 
     def perform(self, node, inputs, outputs):
         outs = self.__fn(*inputs)
-        if not isinstance(outs, (list, tuple)):
+        if not isinstance(outs, list | tuple):
             outs = (outs,)
         assert len(outs) == len(outputs)
         for i in range(len(outs)):
@@ -308,11 +308,11 @@ def numpy_dot(a, b):
            return numpy.dot(a, b)
 
    """
-    if not isinstance(itypes, (list, tuple)):
+    if not isinstance(itypes, list | tuple):
         itypes = [itypes]
     if not all(isinstance(t, CType) for t in itypes):
         raise TypeError("itypes has to be a list of PyTensor types")
-    if not isinstance(otypes, (list, tuple)):
+    if not isinstance(otypes, list | tuple):
         otypes = [otypes]
     if not all(isinstance(t, CType) for t in otypes):
         raise TypeError("otypes has to be a list of PyTensor types")

pytensor/d3viz/formatting.py

Lines changed: 2 additions & 2 deletions
@@ -132,7 +132,7 @@ def __call__(self, fct, graph=None):
             fct = [fct]
         elif isinstance(fct, Apply):
             fct = fct.outputs
-        assert isinstance(fct, (list, tuple))
+        assert isinstance(fct, list | tuple)
         assert all(isinstance(v, Variable) for v in fct)
         fgraph = FunctionGraph(inputs=graph_inputs(fct), outputs=fct)
 
@@ -281,7 +281,7 @@ def var_tag(var):
     """Parse tag attribute of variable node."""
     tag = var.tag
     if hasattr(tag, "trace") and len(tag.trace) and len(tag.trace[0]) == 4:
-        if isinstance(tag.trace[0][0], (tuple, list)):
+        if isinstance(tag.trace[0][0], tuple | list):
             path, line, _, src = tag.trace[0][-1]
         else:
             path, line, _, src = tag.trace[0]

pytensor/gradient.py

Lines changed: 12 additions & 12 deletions
@@ -193,17 +193,17 @@ def Rop(
     If `f` is a list/tuple, then return a list/tuple with the results.
     """
 
-    if not isinstance(wrt, (list, tuple)):
+    if not isinstance(wrt, list | tuple):
         _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
     else:
         _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
 
-    if not isinstance(eval_points, (list, tuple)):
+    if not isinstance(eval_points, list | tuple):
         _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
     else:
         _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
 
-    if not isinstance(f, (list, tuple)):
+    if not isinstance(f, list | tuple):
         _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
     else:
         _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
@@ -381,19 +381,19 @@ def Lop(
     coordinates of the tensor elements.
     If `f` is a list/tuple, then return a list/tuple with the results.
     """
-    if not isinstance(eval_points, (list, tuple)):
+    if not isinstance(eval_points, list | tuple):
         _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
     else:
         _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
 
-    if not isinstance(f, (list, tuple)):
+    if not isinstance(f, list | tuple):
         _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
     else:
         _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
 
     grads = list(_eval_points)
 
-    if not isinstance(wrt, (list, tuple)):
+    if not isinstance(wrt, list | tuple):
         _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
     else:
         _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
@@ -547,7 +547,7 @@ def grad(
             )
 
             if not isinstance(
-                g_var.type, (NullType, DisconnectedType)
+                g_var.type, NullType | DisconnectedType
             ) and "float" not in str(g_var.type.dtype):
                 raise TypeError(
                     "Gradients must always be NullType, "
@@ -1276,7 +1276,7 @@ def try_to_copy_if_needed(var):
                     f"of shape {i_shape}"
                 )
 
-            if not isinstance(term.type, (NullType, DisconnectedType)):
+            if not isinstance(term.type, NullType | DisconnectedType):
                 if term.type.dtype not in pytensor.tensor.type.float_dtypes:
                     raise TypeError(
                         str(node.op) + ".grad illegally "
@@ -1500,7 +1500,7 @@ def prod(inputs):
         return rval
 
     packed_pt = False
-    if not isinstance(pt, (list, tuple)):
+    if not isinstance(pt, list | tuple):
         pt = [pt]
         packed_pt = True
 
@@ -1732,7 +1732,7 @@ def verify_grad(
     from pytensor.compile.function import function
     from pytensor.compile.sharedvalue import shared
 
-    if not isinstance(pt, (list, tuple)):
+    if not isinstance(pt, list | tuple):
         raise TypeError("`pt` should be a list or tuple")
 
     pt = [np.array(p) for p in pt]
@@ -1933,7 +1933,7 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
     using_list = isinstance(wrt, list)
     using_tuple = isinstance(wrt, tuple)
 
-    if isinstance(wrt, (list, tuple)):
+    if isinstance(wrt, list | tuple):
         wrt = list(wrt)
     else:
         wrt = [wrt]
@@ -2015,7 +2015,7 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
     using_list = isinstance(wrt, list)
     using_tuple = isinstance(wrt, tuple)
 
-    if isinstance(wrt, (list, tuple)):
+    if isinstance(wrt, list | tuple):
         wrt = list(wrt)
     else:
         wrt = [wrt]

pytensor/graph/basic.py

Lines changed: 2 additions & 2 deletions
@@ -252,7 +252,7 @@ def clone_with_new_inputs(
         """
         from pytensor.graph.op import HasInnerGraph
 
-        assert isinstance(inputs, (list, tuple))
+        assert isinstance(inputs, list | tuple)
         remake_node = False
         new_inputs: list["Variable"] = list(inputs)
 
@@ -1368,7 +1368,7 @@ def _compute_deps_cache_(io):
             d = deps(io)
 
             if d:
-                if not isinstance(d, (list, OrderedSet)):
+                if not isinstance(d, list | OrderedSet):
                     raise TypeError(
                         "Non-deterministic collections found; make"
                         " toposort non-deterministic."

pytensor/graph/destroyhandler.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def _contains_cycle(fgraph, orderings):
 
     # this is performance-critical code. it is the largest single-function
     # bottleneck when compiling large graphs.
-    assert isinstance(outputs, (tuple, list, deque))
+    assert isinstance(outputs, tuple | list | deque)
 
     # TODO: For more speed - use a defaultdict for the orderings
     # (defaultdict runs faster than dict in the case where the key
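
Since `X | Y` evaluates to a `types.UnionType` object, the rewrite is purely syntactic and both targets perform the same check. A quick equivalence sketch (assumes Python 3.10+ with NumPy installed; not part of the commit):

    import types

    import numpy as np

    union_target = np.ndarray | np.memmap  # runtime classes compose into a types.UnionType
    assert isinstance(union_target, types.UnionType)

    # The union target and the tuple target agree for any probe value:
    for probe in (np.zeros(3), [1, 2], "text"):
        assert isinstance(probe, union_target) == isinstance(probe, (np.ndarray, np.memmap))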
