
Commit 8606498

Armavica authored and ricardoV94 committed
Upgrade pre-commit tools and apply black format
1 parent d7ffee8 · commit 8606498

150 files changed (+59, -502 lines)

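The bulk of the 150 files below change in the same two mechanical ways after the tool upgrade: a blank line at the very start of a function or block body is removed, and redundant parentheses around tuple targets in for loops are dropped. A minimal before/after sketch of both changes (the function and names are hypothetical, not taken from the repository):

# before the upgrade
def summarize(timings):

    for (name, seconds) in timings.items():
        print(f"{name}: {seconds:.3f}s")

# after the upgrade (black 23.1.0 style)
def summarize(timings):
    for name, seconds in timings.items():
        print(f"{name}: {seconds:.3f}s")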

.pre-commit-config.yaml (+2, -2)

@@ -25,7 +25,7 @@ repos:
      - id: pyupgrade
        args: [--py38-plus]
  - repo: https://github.com/psf/black
-    rev: 22.10.0
+    rev: 23.1.0
    hooks:
      - id: black
        language_version: python3
@@ -52,7 +52,7 @@ repos:
        )$
        args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.991
+    rev: v1.0.0
    hooks:
      - id: mypy
        language: python
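These two rev bumps (black 22.10.0 to 23.1.0, mirrors-mypy v0.991 to v1.0.0) are the only hand edits; per the commit message, the rest of the commit is the black reformat that follows from them. A sketch of how such an upgrade is typically reproduced locally with pre-commit (an assumed workflow, not part of this commit; it requires pre-commit to be installed):

import subprocess

# Bump each hook's `rev:` field in .pre-commit-config.yaml to its latest tag.
subprocess.run(["pre-commit", "autoupdate"], check=True)

# Re-run every configured hook (black, mypy, ...) across the whole repository.
# The exit code is nonzero if any hook reformatted or flagged a file.
subprocess.run(["pre-commit", "run", "--all-files"])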

pytensor/breakpoint.py (-2)

@@ -69,7 +69,6 @@ def __init__(self, name):
        self.name = name

    def make_node(self, condition, *monitored_vars):
-
        # Ensure that condition is an PyTensor tensor
        if not isinstance(condition, Variable):
            condition = as_tensor_variable(condition)
@@ -150,7 +149,6 @@ def infer_shape(self, fgraph, inputs, input_shapes):
        return input_shapes[1:]

    def connection_pattern(self, node):
-
        nb_inp = len(node.inputs)
        nb_out = nb_inp - 1

pytensor/compile/builders.py (-1)

@@ -900,7 +900,6 @@ def connection_pattern(self, node):
        return list(map(list, cpmat_self))

    def infer_shape(self, fgraph, node, shapes):
-
        # TODO: Use `fgraph.shape_feature` to do this instead.
        out_shapes = infer_shape(self.inner_outputs, self.inner_inputs, shapes)

pytensor/compile/debugmode.py (+1, -5)

@@ -567,7 +567,6 @@ def _check_viewmap(fgraph, node, storage_map):
    """

    for oi, onode in enumerate(node.outputs):
-
        good_alias, bad_alias = {}, {}
        outstorage = storage_map[onode][0]

@@ -590,13 +589,11 @@ def _check_viewmap(fgraph, node, storage_map):
            if hasattr(inode.type, "may_share_memory") and inode.type.may_share_memory(
                outstorage, in_storage
            ):
-
                nodeid = id(inode)
                bad_alias[nodeid] = ii

                # check that the aliasing was declared in [view|destroy]_map
                if [ii] == view_map.get(oi, None) or [ii] == destroy_map.get(oi, None):
-
                    good_alias[nodeid] = bad_alias.pop(nodeid)

        # TODO: make sure this is correct
@@ -1010,7 +1007,7 @@ def _check_preallocated_output(
            aliased_inputs.add(r)

    _logger.debug("starting preallocated output checking")
-    for (name, out_map) in _get_preallocated_maps(
+    for name, out_map in _get_preallocated_maps(
        node,
        thunk,
        prealloc_modes,
@@ -1697,7 +1694,6 @@ def f():
            sys.stdout.flush()

            if thunk_c:
-
                clobber = True
                if thunk_py:
                    dmap = node.op.destroy_map

pytensor/compile/function/pfunc.py (+1, -2)

@@ -176,7 +176,7 @@ def clone_inputs(i):
    # Fill update_d and update_expr with provided updates
    if updates is None:
        updates = []
-    for (store_into, update_val) in iter_over_pairs(updates):
+    for store_into, update_val in iter_over_pairs(updates):
        if not isinstance(store_into, SharedVariable):
            raise TypeError("update target must be a SharedVariable", store_into)
        if store_into in update_d:
@@ -471,7 +471,6 @@ def construct_pfunc_ins_and_outs(
    )

    if not fgraph:
-
        # Extend the outputs with the updates on input variables so they are
        # also cloned
        additional_outputs = [i.update for i in inputs if i.update]

pytensor/compile/function/types.py (+1, -8)

@@ -596,6 +596,7 @@ def copy(
        pytensor.Function
            Copied pytensor.Function
        """
+
        # helper function
        def checkSV(sv_ori, sv_rpl):
            """
@@ -760,7 +761,6 @@ def checkSV(sv_ori, sv_rpl):
        for in_ori, in_cpy, ori, cpy in zip(
            maker.inputs, f_cpy.maker.inputs, self.input_storage, f_cpy.input_storage
        ):
-
            # Share immutable ShareVariable and constant input's storage
            swapped = swap is not None and in_ori.variable in swap

@@ -910,7 +910,6 @@ def restore_defaults():
                if hasattr(i_var.type, "may_share_memory"):
                    is_aliased = False
                    for j in range(len(args_share_memory)):
-
                        group_j = zip(
                            [
                                self.maker.inputs[k].variable
@@ -928,7 +927,6 @@ def restore_defaults():
                            )
                            for (var, val) in group_j
                        ):
-
                            is_aliased = True
                            args_share_memory[j].append(i)
                            break
@@ -1056,9 +1054,7 @@ def restore_defaults():
        elif self.unpack_single and len(outputs) == 1 and output_subset is None:
            return outputs[0]
        else:
-
            if self.output_keys is not None:
-
                assert len(self.output_keys) == len(outputs)

                if output_subset is None:
@@ -1399,7 +1395,6 @@ def prepare_fgraph(
    mode: "Mode",
    profile,
):
-
    rewriter = mode.optimizer

    try:
@@ -1422,7 +1417,6 @@ def prepare_fgraph(
        # Add deep copy to respect the memory interface
        insert_deepcopy(fgraph, inputs, outputs + additional_outputs)
    finally:
-
        # If the rewriter got interrupted
        if rewrite_time is None:
            end_rewriter = time.perf_counter()
@@ -1598,7 +1592,6 @@ def create(self, input_storage=None, storage_map=None):
        for i, ((input, indices, subinputs), input_storage_i) in enumerate(
            zip(self.indices, input_storage)
        ):
-
            # Replace any default value given as a variable by its
            # container. Note that this makes sense only in the
            # context of shared variables, but for now we avoid
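The single added line in this file, in the first hunk at the top of `copy`, is the other face of the same style update: where black 22.10.0 left a docstring and a following comment-plus-nested-def adjacent, 23.1.0 inserts a separating blank line. A hypothetical illustration (the names are not from pytensor):

# accepted by black 22.10.0
def wrap(fn):
    """Wrap ``fn`` with a pass-through helper."""
    # helper function
    def helper(*args):
        return fn(*args)

    return helper

# after black 23.1.0
def wrap(fn):
    """Wrap ``fn`` with a pass-through helper."""

    # helper function
    def helper(*args):
        return fn(*args)

    return helper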

pytensor/compile/profiling.py (+7, -8)

@@ -65,7 +65,6 @@ def _atexit_print_fn():
    destination_file = config.profiling__destination

    with extended_open(destination_file, mode="w"):
-
        # Reverse sort in the order of compile+exec time
        for ps in sorted(
            _atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time
@@ -358,7 +357,7 @@ def class_impl(self):
        """
        # timing is stored by node, we compute timing by class on demand
        rval = {}
-        for (fgraph, node) in self.apply_callcount:
+        for fgraph, node in self.apply_callcount:
            typ = type(node.op)
            if self.apply_cimpl[node]:
                impl = "C "
@@ -401,7 +400,7 @@ def compute_total_times(self):

        """
        rval = {}
-        for (fgraph, node) in self.apply_time:
+        for fgraph, node in self.apply_time:
            if node not in rval:
                self.fill_node_total_time(fgraph, node, rval)
        return rval
@@ -437,7 +436,7 @@ def op_impl(self):
        """
        # timing is stored by node, we compute timing by Op on demand
        rval = {}
-        for (fgraph, node) in self.apply_callcount:
+        for fgraph, node in self.apply_callcount:
            if self.apply_cimpl[node]:
                rval[node.op] = "C "
            else:
@@ -711,7 +710,7 @@ def summary_nodes(self, file=sys.stderr, N=None):

        atimes.sort(reverse=True, key=lambda t: (t[1], t[3]))
        tot = 0
-        for (f, t, a, nd_id, nb_call) in atimes[:N]:
+        for f, t, a, nd_id, nb_call in atimes[:N]:
            tot += t
            ftot = tot * 100 / local_time
            if nb_call == 0:
@@ -840,7 +839,7 @@ def summary_memory(self, file, N=None):
        var_mem = {}  # variable->size in bytes; don't include input variables
        node_mem = {}  # (fgraph, node)->total outputs size (only dense outputs)

-        for (fgraph, node) in self.apply_callcount:
+        for fgraph, node in self.apply_callcount:
            fct_memory.setdefault(fgraph, {})
            fct_memory[fgraph].setdefault(node, [])
            fct_shapes.setdefault(fgraph, {})
@@ -1610,7 +1609,7 @@ def exp_float32_op(op):
            printed_tip = True

        # tip 4
-        for (fgraph, a) in self.apply_time:
+        for fgraph, a in self.apply_time:
            node = a
            if isinstance(node.op, Dot) and all(
                len(i.type.broadcastable) == 2 for i in node.inputs
@@ -1630,7 +1629,7 @@ def exp_float32_op(op):
        # The tip was about MRG_RandomStream which is removed

        # tip 6
-        for (fgraph, a) in self.apply_time:
+        for fgraph, a in self.apply_time:
            node = a
            if isinstance(node.op, Dot) and len({i.dtype for i in node.inputs}) != 1:
                print(

pytensor/configdefaults.py (-5)

@@ -278,7 +278,6 @@ def short_platform(r=None, p=None):


def add_basic_configvars():
-
    config.add(
        "floatX",
        "Default floating-point precision for python casts.\n"
@@ -388,7 +387,6 @@ def _is_greater_or_equal_0(x):


def add_compile_configvars():
-
    config.add(
        "mode",
        "Default compilation mode",
@@ -631,7 +629,6 @@ def _is_valid_cmp_sloppy(v):


def add_tensor_configvars():
-
    # This flag is used when we import PyTensor to initialize global variables.
    # So changing it after import will not modify these global variables.
    # This could be done differently... but for now we simply prevent it from being
@@ -717,7 +714,6 @@ def add_experimental_configvars():


def add_error_and_warning_configvars():
-
    ###
    # To disable some warning about old bug that are fixed now.
    ###
@@ -1196,7 +1192,6 @@ def add_vm_configvars():


def add_deprecated_configvars():
-
    # TODO: remove this? Agree
    config.add(
        "unittests__rseed",

pytensor/gradient.py (-8)

@@ -225,7 +225,6 @@ def Rop(
    # Check that each element of wrt corresponds to an element
    # of eval_points with the same dimensionality.
    for i, (wrt_elem, eval_point) in enumerate(zip(_wrt, _eval_points)):
-
        try:
            if wrt_elem.type.ndim != eval_point.type.ndim:
                raise ValueError(
@@ -262,7 +261,6 @@ def _traverse(node):
                # arguments, like for example random states
                local_eval_points.append(None)
            elif inp.owner in seen_nodes:
-
                local_eval_points.append(
                    seen_nodes[inp.owner][inp.owner.outputs.index(inp)]
                )
@@ -937,7 +935,6 @@ def account_for(var):
            var_idx = app.outputs.index(var)

            for i, ipt in enumerate(app.inputs):
-
                # don't process ipt if it is not a true
                # parent of var
                if not connection_pattern[i][var_idx]:
@@ -1048,7 +1045,6 @@ def access_term_cache(node):
        """Populates term_dict[node] and returns it"""

        if node not in term_dict:
-
            inputs = node.inputs

            output_grads = [access_grad_cache(var) for var in node.outputs]
@@ -1263,7 +1259,6 @@ def try_to_copy_if_needed(var):
            ]

            for i, term in enumerate(input_grads):
-
                # Disallow Nones
                if term is None:
                    # We don't know what None means. in the past it has been
@@ -1377,7 +1372,6 @@ def access_grad_cache(var):
            node_to_idx = var_to_app_to_idx[var]
            for node in node_to_idx:
                for idx in node_to_idx[node]:
-
                    term = access_term_cache(node)[idx]

                    if not isinstance(term, Variable):
@@ -1858,7 +1852,6 @@ def random_projection():
            )

        if max_abs_err > abs_tol and max_rel_err > rel_tol:
-
            raise GradientError(
                max_arg,
                max_err_pos,
@@ -2042,7 +2035,6 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):

    hessians = []
    for input in wrt:
-
        if not isinstance(input, Variable):
            raise TypeError("hessian expects a (list of) Variable as `wrt`")

pytensor/graph/basic.py (-7)

@@ -823,7 +823,6 @@ def walk(
        node_hash: int = hash_fn(node)

        if node_hash not in rval_set:
-
            rval_set.add(node_hash)

            new_nodes: Optional[Iterable[T]] = expand(node)
@@ -1323,7 +1322,6 @@ def general_toposort(

    """
    if compute_deps_cache is None:
-
        if deps_cache is None:
            deps_cache = {}

@@ -1460,7 +1458,6 @@ def compute_deps_cache(obj):
            return rval

    else:
-
        # the inputs are used only here in the function that decides what
        # 'predecessors' to explore
        def compute_deps(obj):
@@ -1511,7 +1508,6 @@ def io_connection_pattern(inputs, outputs):
    # inputs and, for every node, infer their connection pattern to
    # every input from the connection patterns of their parents.
    for n in inner_nodes:
-
        # Get the connection pattern of the inner node's op. If the op
        # does not define a connection_pattern method, assume that
        # every node output is connected to every node input
@@ -1872,15 +1868,13 @@ def compare_nodes(nd_x, nd_y, common, different):
    # Compare the individual inputs for equality
    for dx, dy in zip(nd_x.inputs, nd_y.inputs):
        if (dx, dy) not in common:
-
            # Equality between the variables is unknown, compare
            # their respective owners, if they have some
            if (
                dx.owner
                and dy.owner
                and dx.owner.outputs.index(dx) == dy.owner.outputs.index(dy)
            ):
-
                nodes_equal = compare_nodes(
                    dx.owner, dy.owner, common, different
                )
@@ -1891,7 +1885,6 @@ def compare_nodes(nd_x, nd_y, common, different):
            # If both variables don't have an owner, then they are
            # inputs and can be directly compared
            elif dx.owner is None and dy.owner is None:
-
                if dx != dy:
                    if isinstance(dx, Constant) and isinstance(dy, Constant):
                        if not dx.equals(dy):
