Commit e33b95b

Add flake8-comprehensions plugin
1 parent 9a653c7 commit e33b95b
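
The plugin drives a handful of recurring rewrites across the hunks below. The following sketch is illustrative only: it is not part of the commit, and the rule codes are the plugin's usual numbering as best recalled, not quoted from this diff.

# Illustrative examples of the comprehension patterns rewritten in this commit.
axis = [2, 0, 1]

# "Unnecessary comprehension" (C416): a comprehension that only copies its
# iterable is clearer as a plain constructor call.
copy_old = [x for x in axis]                        # flagged style
copy_new = list(axis)                               # preferred style

# "Unnecessary list/reversed call around sorted()" (C413).
desc_old = list(reversed(sorted(axis)))             # flagged style
desc_new = sorted(axis, reverse=True)               # preferred style

# "Unnecessary list literal passed to dict()" (C406): write the dict literal.
taps_old = dict([("input", axis), ("taps", [0])])   # flagged style
taps_new = {"input": axis, "taps": [0]}             # preferred style

assert copy_old == copy_new
assert desc_old == desc_new
assert taps_old == taps_new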

27 files changed, +55 -61 lines

.pre-commit-config.yaml

+2
@@ -33,6 +33,8 @@ repos:
     rev: 6.0.0
     hooks:
       - id: flake8
+        additional_dependencies:
+          - flake8-comprehensions
   - repo: https://github.com/pycqa/isort
     rev: 5.12.0
     hooks:
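
For context on the hunk above: `additional_dependencies` is pre-commit's mechanism for installing extra packages into the hook's isolated environment, so the flake8 hook can load the plugin without it becoming a project dependency. Assuming pre-commit is installed locally, the updated hook can be exercised with a shell command along these lines:

# Run the flake8 hook (now with flake8-comprehensions) against the whole tree.
pre-commit run flake8 --all-files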

pytensor/compile/builders.py

+1 -3
@@ -969,9 +969,7 @@ def inline_ofg_expansion(fgraph, node):
         return False
     if not op.is_inline:
         return False
-    return clone_replace(
-        op.inner_outputs, {u: v for u, v in zip(op.inner_inputs, node.inputs)}
-    )
+    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))


 # We want to run this before the first merge optimizer

pytensor/gradient.py

+3 -3
@@ -504,7 +504,7 @@ def grad(
     if not isinstance(wrt, Sequence):
         _wrt: List[Variable] = [wrt]
     else:
-        _wrt = [x for x in wrt]
+        _wrt = list(wrt)

     outputs = []
     if cost is not None:
@@ -791,8 +791,8 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):

     pgrads = dict(zip(params, grads))
     # separate wrt from end grads:
-    wrt_grads = list(pgrads[k] for k in wrt)
-    end_grads = list(pgrads[k] for k in end)
+    wrt_grads = [pgrads[k] for k in wrt]
+    end_grads = [pgrads[k] for k in end]

     if details:
         return wrt_grads, end_grads, start_grads, cost_grads

pytensor/graph/basic.py

+1 -1
@@ -1629,7 +1629,7 @@ def as_string(
                 multi.add(op2)
             else:
                 seen.add(input.owner)
-    multi_list = [x for x in multi]
+    multi_list = list(multi)
    done: Set = set()

    def multi_index(x):

pytensor/graph/replace.py

+1 -2
@@ -142,8 +142,7 @@ def toposort_key(fg: FunctionGraph, ts, pair):
            raise ValueError(f"{key} is not a part of graph")

    sorted_replacements = sorted(
-        tuple(fg_replace.items()),
-        # sort based on the fg toposort, if a variable has no owner, it goes first
+        fg_replace.items(),
        key=partial(toposort_key, fg, toposort),
        reverse=True,
    )

pytensor/graph/rewriting/basic.py

+2 -2
@@ -2575,8 +2575,8 @@ def print_profile(cls, stream, prof, level=0):
        for i in range(len(loop_timing)):
            loop_times = ""
            if loop_process_count[i]:
-                d = list(
-                    reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1]))
+                d = sorted(
+                    loop_process_count[i].items(), key=lambda a: a[1], reverse=True
                )
                loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
                if len(d) > 5:

pytensor/link/c/basic.py

+2 -2
@@ -633,11 +633,11 @@ def fetch_variables(self):

        # The orphans field is listified to ensure a consistent order.
        # list(fgraph.orphans.difference(self.outputs))
-        self.orphans = list(
+        self.orphans = [
            r
            for r in self.variables
            if isinstance(r, AtomicVariable) and r not in self.inputs
-        )
+        ]
        # C type constants (pytensor.scalar.ScalarType). They don't request an object
        self.consts = []
        # Move c type from orphans (pytensor.scalar.ScalarType) to self.consts

pytensor/link/c/params_type.py

+1 -1
@@ -810,7 +810,7 @@ def c_support_code(self, **kwargs):
            struct_extract_method=struct_extract_method,
        )

-        return list(sorted(list(c_support_code_set))) + [final_struct_code]
+        return list(c_support_code_set) + [final_struct_code]

    def c_code_cache_version(self):
        return ((3,), tuple(t.c_code_cache_version() for t in self.types))

pytensor/link/jax/dispatch/elemwise.py

+1 -1
@@ -41,7 +41,7 @@ def careduce(x):
        elif scalar_op_name:
            scalar_fn_name = scalar_op_name

-        to_reduce = reversed(sorted(axis))
+        to_reduce = sorted(axis, reverse=True)

        if to_reduce:
            # In this case, we need to use the `jax.lax` function (if there
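
A side note on the hunk above (an observation, not something stated in the commit): both forms produce the same ordering, but the old one returns an iterator while the new one returns a list, and only the list is falsy when empty, which matters for the `if to_reduce:` check that follows. A quick sketch:

# Same ordering, different types; only the list form is falsy when empty.
axis = (0, 2, 1)
assert list(reversed(sorted(axis))) == sorted(axis, reverse=True) == [2, 1, 0]
assert bool(reversed(sorted(())))       # a reversed iterator is truthy even when empty
assert not sorted((), reverse=True)     # an empty list is falsy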

pytensor/link/numba/dispatch/elemwise.py

+1 -1
@@ -361,7 +361,7 @@ def careduce_maximum(input):

    careduce_fn_name = f"careduce_{scalar_op}"
    global_env = {}
-    to_reduce = reversed(sorted(axes))
+    to_reduce = sorted(axes, reverse=True)
    careduce_lines_src = []
    var_name = input_name

pytensor/printing.py

+1 -1
@@ -796,7 +796,7 @@ def grad(self, input, output_gradients):
        return output_gradients

    def R_op(self, inputs, eval_points):
-        return [x for x in eval_points]
+        return list(eval_points)

    def __setstate__(self, dct):
        dct.setdefault("global_fn", _print_fn)

pytensor/scan/basic.py

+2 -2
@@ -492,7 +492,7 @@ def wrap_into_list(x):
    # wrap sequences in a dictionary if they are not already dictionaries
    for i in range(n_seqs):
        if not isinstance(seqs[i], dict):
-            seqs[i] = dict([("input", seqs[i]), ("taps", [0])])
+            seqs[i] = {"input": seqs[i], "taps": [0]}
        elif seqs[i].get("taps", None) is not None:
            seqs[i]["taps"] = wrap_into_list(seqs[i]["taps"])
        elif seqs[i].get("taps", None) is None:
@@ -504,7 +504,7 @@ def wrap_into_list(x):
        if outs_info[i] is not None:
            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
-                outs_info[i] = dict([("initial", outs_info[i]), ("taps", [-1])])
+                outs_info[i] = {"initial": outs_info[i], "taps": [-1]}
            elif (
                outs_info[i].get("initial", None) is None
                and outs_info[i].get("taps", None) is not None

pytensor/scan/op.py

+5 -8
@@ -1718,12 +1718,9 @@ def perform(self, node, inputs, output_storage, params=None):
            arg.shape[0]
            for arg in inputs[self.seqs_arg_offset : self.shared_arg_offset]
        ]
-        store_steps += [
-            arg
-            for arg in inputs[
-                self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot
-            ]
-        ]
+        store_steps += list(
+            inputs[self.nit_sot_arg_offset : self.nit_sot_arg_offset + info.n_nit_sot]
+        )

        # 2.1 Create storage space for outputs
        for idx in range(self.n_outs):
@@ -2270,7 +2267,7 @@ def infer_shape(self, fgraph, node, input_shapes):
        )

        offset = 1 + info.n_seqs
-        scan_outs = [x for x in input_shapes[offset : offset + n_outs]]
+        scan_outs = list(input_shapes[offset : offset + n_outs])
        offset += n_outs
        outs_shape_n = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot
        for x in range(info.n_nit_sot):
@@ -2301,7 +2298,7 @@ def infer_shape(self, fgraph, node, input_shapes):
                shp.append(v_shp_i[0])
            scan_outs.append(tuple(shp))

-        scan_outs += [x for x in input_shapes[offset : offset + info.n_shared_outs]]
+        scan_outs += list(input_shapes[offset : offset + info.n_shared_outs])
        # if we are dealing with a repeat-until, then we do not know the
        # leading dimension so we replace it for every entry with Shape_i
        if info.as_while:

pytensor/scan/rewriting.py

+1 -1
@@ -388,7 +388,7 @@ def add_to_replace(y):
            if out in local_fgraph_outs_set:
                x = node.outputs[local_fgraph_outs_map[out]]
                y = replace_with_out[idx]
-                y_shape = [shp for shp in y.shape]
+                y_shape = list(y.shape)
                replace_with[x] = at.alloc(y, node.inputs[0], *y_shape)

    # We need to add one extra dimension to the outputs

pytensor/tensor/random/op.py

+1 -1
@@ -283,7 +283,7 @@ def infer_shape(self, fgraph, node, input_shapes):

        shape = self._infer_shape(size, dist_params, param_shapes=param_shapes)

-        return [None, [s for s in shape]]
+        return [None, list(shape)]

    def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
        res = super().__call__(rng, size, dtype, *args, **kwargs)

pytensor/tensor/rewriting/math.py

+3 -3
@@ -1555,11 +1555,11 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
                )

                if len(compatible_dims) > 0:
-                    optimized_dimshuffle_order = list(
+                    optimized_dimshuffle_order = [
                        ax
                        for i, ax in enumerate(dimshuffle_order)
                        if (i not in axis) or (ax != "x")
-                    )
+                    ]

                    # Removing leading 'x' (since it will be done automatically)
                    while (
@@ -1644,7 +1644,7 @@ def local_op_of_op(fgraph, node):
                return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])]

            # figure out which axes were in the original sum
-            newaxis = list(tuple(node_inps.owner.op.axis))
+            newaxis = list(node_inps.owner.op.axis)
            for i in node.op.axis:
                new_i = i
                for ii in node_inps.owner.op.axis:

pytensor/tensor/shape.py

+3 -3
@@ -810,7 +810,7 @@ def shape_padleft(t, n_ones=1):
    """
    _t = at.as_tensor_variable(t)

-    pattern = ["x"] * n_ones + [i for i in range(_t.type.ndim)]
+    pattern = ["x"] * n_ones + list(range(_t.type.ndim))
    return _t.dimshuffle(pattern)


@@ -826,7 +826,7 @@ def shape_padright(t, n_ones=1):
    """
    _t = at.as_tensor_variable(t)

-    pattern = [i for i in range(_t.type.ndim)] + ["x"] * n_ones
+    pattern = list(range(_t.type.ndim)) + ["x"] * n_ones
    return _t.dimshuffle(pattern)


@@ -861,7 +861,7 @@ def shape_padaxis(t, axis):
    if axis < 0:
        axis += ndim

-    pattern = [i for i in range(_t.type.ndim)]
+    pattern = list(range(_t.type.ndim))
    pattern.insert(axis, "x")
    return _t.dimshuffle(pattern)

pytensor/tensor/subtensor.py

+1 -1
@@ -2604,7 +2604,7 @@ def infer_shape(self, fgraph, node, ishapes):
            ishapes[0], index_shapes, indices_are_shapes=True
        )
        assert node.outputs[0].ndim == len(res_shape)
-        return [[s for s in res_shape]]
+        return [list(res_shape)]

    def perform(self, node, inputs, out_):
        (out,) = out_

setup.cfg

+1 -1
@@ -1,6 +1,6 @@
 [flake8]
 select = C,E,F,W
-ignore = E203,E231,E501,E741,W503,W504,C901
+ignore = E203,E231,E501,E741,W503,W504,C408,C901
 per-file-ignores =
     **/__init__.py:F401,E402,F403
     pytensor/tensor/linalg.py:F401,F403
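
A guess at the intent of this line, since the commit message does not spell it out: with `select = C,...` the plugin's C4xx checks are enabled, and C408 (the "unnecessary dict/list/tuple call" check, per the plugin's usual numbering) is added to the ignore list so that keyword-style `dict(...)` calls stay permitted. Under that reading, the following hypothetical example would be left alone by the linter:

# Hypothetical example: with C408 ignored, the call form below is not flagged,
# even though the plugin would otherwise suggest the literal form.
opts_call = dict(linker="cvm", optimizer="fast_run")        # allowed
opts_literal = {"linker": "cvm", "optimizer": "fast_run"}   # what C408 suggests
assert opts_call == opts_literal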

tests/graph/test_features.py

+2 -2
@@ -73,7 +73,7 @@ def inputs():

        assert hasattr(g, "get_nodes")
        for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
-            if len([t for t in g.get_nodes(type)]) != num:
+            if len(list(g.get_nodes(type))) != num:
                raise Exception("Expected: %i times %s" % (num, type))
        new_e0 = add(y, z)
        assert e0.owner in g.get_nodes(dot)
@@ -82,7 +82,7 @@ def inputs():
        assert e0.owner not in g.get_nodes(dot)
        assert new_e0.owner in g.get_nodes(add)
        for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
-            if len([t for t in g.get_nodes(type)]) != num:
+            if len(list(g.get_nodes(type))) != num:
                raise Exception("Expected: %i times %s" % (num, type))

tests/graph/test_op.py

+1 -1
@@ -87,7 +87,7 @@ def test_sanity_0(self):
        r1, r2 = MyType(1)(), MyType(2)()
        node = MyOp.make_node(r1, r2)
        # Are the inputs what I provided?
-        assert [x for x in node.inputs] == [r1, r2]
+        assert list(node.inputs) == [r1, r2]
        # Are the outputs what I expect?
        assert [x.type for x in node.outputs] == [MyType(3)]
        assert node.outputs[0].owner is node and node.outputs[0].index == 0

tests/tensor/rewriting/test_elemwise.py

+1 -1
@@ -1123,7 +1123,7 @@ def test_add_mul_fusion_inplace(self):
        out = dot(x, y) + x + y + z

        f = function([x, y, z], out, mode=self.mode)
-        topo = [n for n in f.maker.fgraph.toposort()]
+        topo = list(f.maker.fgraph.toposort())
        assert len(topo) == 2
        assert topo[-1].op.inplace_pattern

tests/tensor/rewriting/test_math.py

+3 -3
@@ -3994,9 +3994,9 @@ def test_is_1pexp(self):
        exp_op = exp
        assert is_1pexp(1 + exp_op(x), False) == (False, x)
        assert is_1pexp(exp_op(x) + 1, False) == (False, x)
-        for neg_, exp_arg in map(
-            lambda x: is_1pexp(x, only_process_constants=False),
-            [(1 + exp_op(-x)), (exp_op(-x) + 1)],
+        for neg_, exp_arg in (
+            is_1pexp(x, only_process_constants=False)
+            for x in [(1 + exp_op(-x)), (exp_op(-x) + 1)]
        ):
            assert not neg_ and is_same_graph(exp_arg, -x)
        assert is_1pexp(1 - exp_op(x), False) is None

tests/tensor/rewriting/test_subtensor.py

+2 -2
@@ -2004,7 +2004,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
    y_val_fn = function(
        [x] + list(s), y, on_unused_input="ignore", mode=no_rewrites_mode
    )
-    y_val = y_val_fn(*([x_val] + [s_ for s_ in s_val]))
+    y_val = y_val_fn(*([x_val] + list(s_val)))

    # This optimization should appear in the canonicalizations
    y_opt = rewrite_graph(y, clone=False)
@@ -2017,7 +2017,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
    assert isinstance(y_opt.owner.op, SpecifyShape)

    y_opt_fn = function([x] + list(s), y_opt, on_unused_input="ignore")
-    y_opt_val = y_opt_fn(*([x_val] + [s_ for s_ in s_val]))
+    y_opt_val = y_opt_fn(*([x_val] + list(s_val)))

    assert np.allclose(y_val, y_opt_val)

tests/tensor/test_blas.py

+3 -3
@@ -2589,10 +2589,10 @@ def test_ger(self):
            op=batched_dot,
            expected=(
                lambda xs, ys: np.asarray(
-                    list(
+                    [
                        x * y if x.ndim == 0 or y.ndim == 0 else np.dot(x, y)
                        for x, y in zip(xs, ys)
-                    ),
+                    ],
                    dtype=aes.upcast(xs.dtype, ys.dtype),
                )
            ),
@@ -2694,7 +2694,7 @@ def check_first_dim(inverted):
            assert x.strides[0] == direction * np.dtype(config.floatX).itemsize
            assert not (x.flags["C_CONTIGUOUS"] or x.flags["F_CONTIGUOUS"])
            result = f(x, w)
-            ref_result = np.asarray(list(np.dot(u, v) for u, v in zip(x, w)))
+            ref_result = np.asarray([np.dot(u, v) for u, v in zip(x, w)])
            utt.assert_allclose(ref_result, result)

        for inverted in (0, 1):

tests/tensor/test_complex.py

+1 -3
@@ -15,9 +15,7 @@ def test_basic(self):
        x = zvector()
        rng = np.random.default_rng(23)
        xval = np.asarray(
-            list(
-                complex(rng.standard_normal(), rng.standard_normal()) for i in range(10)
-            )
+            [complex(rng.standard_normal(), rng.standard_normal()) for i in range(10)]
        )
        assert np.all(xval.real == pytensor.function([x], real(x))(xval))
        assert np.all(xval.imag == pytensor.function([x], imag(x))(xval))
