Commit 14f23e0
Merge branch 'update-pre-commit' into py312
2 parents: e9efa5d + ea1617f

33 files changed: +104 -70 lines changed

.pre-commit-config.yaml

Lines changed: 6 additions & 6 deletions
@@ -7,7 +7,7 @@ exclude: |
     )$
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.5.0
     hooks:
       - id: debug-statements
         exclude: |
@@ -20,23 +20,23 @@ repos:
           )$
       - id: check-merge-conflict
   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.3.1
+    rev: v3.15.0
     hooks:
       - id: pyupgrade
         args: [--py39-plus]
   - repo: https://github.com/psf/black
-    rev: 23.1.0
+    rev: 23.12.1
     hooks:
       - id: black
         language_version: python3
   - repo: https://github.com/pycqa/flake8
-    rev: 6.0.0
+    rev: 7.0.0
     hooks:
       - id: flake8
         additional_dependencies:
           - flake8-comprehensions
   - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
+    rev: 5.13.2
     hooks:
       - id: isort
   - repo: https://github.com/humitos/mirrors-autoflake.git
@@ -54,7 +54,7 @@ repos:
           )$
         args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.0.0
+    rev: v1.8.0
     hooks:
       - id: mypy
         language: python
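
Note: besides the hook version bumps, the config keeps passing --py39-plus to pyupgrade. A hypothetical before/after sketch (not code from this repository) of the kind of rewrite that flag applies, replacing typing aliases with PEP 585 builtin generics:

# Before: typing aliases that pyupgrade --py39-plus rewrites.
from typing import Dict, List


def tally(names: List[str]) -> Dict[str, int]:
    counts: Dict[str, int] = {}
    for name in names:
        counts[name] = counts.get(name, 0) + 1
    return counts


# After: builtin generics, valid on Python 3.9+.
def tally_py39(names: list[str]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for name in names:
        counts[name] = counts.get(name, 0) + 1
    return counts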

pytensor/compile/debugmode.py

Lines changed: 2 additions & 2 deletions
@@ -689,7 +689,7 @@ def _lessbroken_deepcopy(a):
     else:
         rval = copy.deepcopy(a)

-    assert type(rval) == type(a), (type(rval), type(a))
+    assert type(rval) is type(a), (type(rval), type(a))

     if isinstance(rval, np.ndarray):
         assert rval.dtype == a.dtype
@@ -1156,7 +1156,7 @@ def __str__(self):
         return str(self.__dict__)

     def __eq__(self, other):
-        rval = type(self) == type(other)
+        rval = type(self) is type(other)
         if rval:
             # nodes are not compared because this comparison is
             # supposed to be true for corresponding events that happen
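
Most of the Python changes in this commit swap type(a) == type(b) for type(a) is type(b). Newer pycodestyle releases (bundled by the flake8 7.0.0 pinned above) flag the == form as E721, which is presumably what motivated these edits. A minimal sketch of the difference, using hypothetical Base/Sub classes:

class Base:
    pass


class Sub(Base):
    pass


a, b = Base(), Sub()

# Type objects are singletons, so identity is the exact, cheap comparison.
assert type(a) is Base
assert type(a) is not type(b)

# isinstance() remains the right tool when subclasses should also match.
assert isinstance(b, Base)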

pytensor/compile/ops.py

Lines changed: 1 addition & 1 deletion
@@ -246,7 +246,7 @@ def __init__(self, fn, itypes, otypes, infer_shape):
         self.infer_shape = self._infer_shape

     def __eq__(self, other):
-        return type(self) == type(other) and self.__fn == other.__fn
+        return type(self) is type(other) and self.__fn == other.__fn

     def __hash__(self):
         return hash(type(self)) ^ hash(self.__fn)

pytensor/compile/profiling.py

Lines changed: 2 additions & 2 deletions
@@ -1084,8 +1084,8 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
            viewof_change = []
            # Use to track view_of changes

-           viewedby_add = defaultdict(lambda: [])
-           viewedby_remove = defaultdict(lambda: [])
+           viewedby_add = defaultdict(list)
+           viewedby_remove = defaultdict(list)
            # Use to track viewed_by changes

            for var in node.outputs:
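
Several files in this commit swap defaultdict(lambda: []) for defaultdict(list). A small sketch of the equivalence; passing the list constructor directly is the idiomatic spelling, is picklable, and avoids a throwaway lambda call on every cache miss (the viewed/legacy names are illustrative only):

from collections import defaultdict

# Both defaultdicts produce a fresh empty list for a missing key...
viewed = defaultdict(list)
viewed["a"].append(1)
assert viewed["a"] == [1]

# ...but list as the default_factory expresses that intent directly.
legacy = defaultdict(lambda: [])
assert legacy["a"] == viewed["b"] == []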

pytensor/graph/basic.py

Lines changed: 44 additions & 12 deletions
@@ -23,6 +23,7 @@
     TypeVar,
     Union,
     cast,
+    overload,
 )

 import numpy as np
@@ -718,7 +719,7 @@ def __eq__(self, other):
             return True

         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.id == other.id
             and self.type == other.type
         )
@@ -1301,9 +1302,31 @@ def clone_get_equiv(
     return memo


+@overload
+def general_toposort(
+    outputs: Iterable[T],
+    deps: None,
+    compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
+    deps_cache: Optional[dict[T, list[T]]],
+    clients: Optional[dict[T, list[T]]],
+) -> list[T]:
+    ...
+
+
+@overload
 def general_toposort(
     outputs: Iterable[T],
     deps: Callable[[T], Union[OrderedSet, list[T]]],
+    compute_deps_cache: None,
+    deps_cache: None,
+    clients: Optional[dict[T, list[T]]],
+) -> list[T]:
+    ...
+
+
+def general_toposort(
+    outputs: Iterable[T],
+    deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]],
     compute_deps_cache: Optional[
         Callable[[T], Optional[Union[OrderedSet, list[T]]]]
     ] = None,
@@ -1345,7 +1368,7 @@ def general_toposort(
         if deps_cache is None:
             deps_cache = {}

-        def _compute_deps_cache(io):
+        def _compute_deps_cache_(io):
             if io not in deps_cache:
                 d = deps(io)

@@ -1363,6 +1386,8 @@ def _compute_deps_cache(io):
             else:
                 return deps_cache[io]

+        _compute_deps_cache = _compute_deps_cache_
+
    else:
        _compute_deps_cache = compute_deps_cache

@@ -1451,15 +1476,14 @@ def io_toposort(
         )
         return order

-    compute_deps = None
-    compute_deps_cache = None
     iset = set(inputs)
-    deps_cache: dict = {}

     if not orderings:  # ordering can be None or empty dict
         # Specialized function that is faster when no ordering.
         # Also include the cache in the function itself for speed up.

+        deps_cache: dict = {}
+
         def compute_deps_cache(obj):
             if obj in deps_cache:
                 return deps_cache[obj]
@@ -1478,6 +1502,14 @@ def compute_deps_cache(obj):
             deps_cache[obj] = rval
             return rval

+        topo = general_toposort(
+            outputs,
+            deps=None,
+            compute_deps_cache=compute_deps_cache,
+            deps_cache=deps_cache,
+            clients=clients,
+        )
+
     else:
         # the inputs are used only here in the function that decides what
         # 'predecessors' to explore
@@ -1494,13 +1526,13 @@ def compute_deps(obj):
             assert not orderings.get(obj, None)
             return rval

-    topo = general_toposort(
-        outputs,
-        deps=compute_deps,
-        compute_deps_cache=compute_deps_cache,
-        deps_cache=deps_cache,
-        clients=clients,
-    )
+        topo = general_toposort(
+            outputs,
+            deps=compute_deps,
+            compute_deps_cache=None,
+            deps_cache=None,
+            clients=clients,
+        )
     return [o for o in topo if isinstance(o, Apply)]

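The new @overload stubs for general_toposort above tell the type checker that callers supply either deps or compute_deps_cache/deps_cache, never both, while a single runtime implementation handles both cases. A minimal self-contained sketch of the same typing pattern; the fetch function and its parameters are hypothetical, not part of pytensor:

from typing import Optional, overload


@overload
def fetch(key: str, default: None = ...) -> Optional[bytes]:
    ...


@overload
def fetch(key: str, default: bytes) -> bytes:
    ...


def fetch(key: str, default: Optional[bytes] = None) -> Optional[bytes]:
    # Single runtime implementation; the overloads above exist only for the
    # type checker, which picks the return type from the arguments it sees.
    store = {"a": b"1"}
    return store.get(key, default)


assert fetch("a") == b"1"          # checker infers Optional[bytes]
assert fetch("missing", b"0") == b"0"  # checker infers bytes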

pytensor/graph/null_type.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def values_eq(self, a, b, force_same_dtype=True):
         raise ValueError("NullType has no values to compare")

     def __eq__(self, other):
-        return type(self) == type(other)
+        return type(self) is type(other)

     def __hash__(self):
         return hash(type(self))

pytensor/graph/rewriting/basic.py

Lines changed: 6 additions & 4 deletions
@@ -951,8 +951,8 @@ class MetaNodeRewriter(NodeRewriter):

     def __init__(self):
         self.verbose = config.metaopt__verbose
-        self.track_dict = defaultdict(lambda: [])
-        self.tag_dict = defaultdict(lambda: [])
+        self.track_dict = defaultdict(list)
+        self.tag_dict = defaultdict(list)
         self._tracks = []
         self.rewriters = []

@@ -2406,13 +2406,15 @@ def importer(node):
             if node is not current_node:
                 q.append(node)

-        chin = None
+        chin: Optional[Callable] = None
         if self.tracks_on_change_inputs:

-            def chin(node, i, r, new_r, reason):
+            def chin_(node, i, r, new_r, reason):
                 if node is not current_node and not isinstance(node, str):
                     q.append(node)

+            chin = chin_
+
         u = self.attach_updater(
             fgraph, importer, None, chin=chin, name=getattr(self, "name", None)
         )
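
The chin_/chin = chin_ rename above (and the matching _compute_deps_cache_ rename in pytensor/graph/basic.py) is a common way to keep mypy happy when a name annotated as Optional[Callable] would otherwise be redefined by a def statement, presumably relevant to the mypy 1.8.0 bump in this commit. A minimal sketch of the pattern, with a hypothetical handler name:

from typing import Callable, Optional

use_handler = True

handler: Optional[Callable[[int], int]] = None
if use_handler:
    # Writing "def handler(...)" here would make mypy complain that the name
    # is redefined with an incompatible (non-Optional) type, so the function
    # gets a private name and is then bound to the annotated variable.
    def handler_(x: int) -> int:
        return x + 1

    handler = handler_

if handler is not None:
    assert handler(1) == 2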

pytensor/graph/rewriting/unify.py

Lines changed: 2 additions & 2 deletions
@@ -58,8 +58,8 @@ def __new__(cls, constraint, token=None, prefix=""):
         return obj

     def __eq__(self, other):
-        if type(self) == type(other):
-            return self.token == other.token and self.constraint == other.constraint
+        if type(self) is type(other):
+            return self.token is other.token and self.constraint == other.constraint
         return NotImplemented

     def __hash__(self):

pytensor/graph/utils.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def __hash__(self):
         if "__eq__" not in dct:

             def __eq__(self, other):
-                return type(self) == type(other) and tuple(
+                return type(self) is type(other) and tuple(
                     getattr(self, a) for a in props
                 ) == tuple(getattr(other, a) for a in props)


pytensor/ifelse.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def __init__(self, n_outs, as_view=False, name=None):
         self.name = name

     def __eq__(self, other):
-        if type(self) != type(other):
+        if type(self) is not type(other):
             return False
         if self.as_view != other.as_view:
             return False

pytensor/link/c/params_type.py

Lines changed: 2 additions & 2 deletions
@@ -297,7 +297,7 @@ def __hash__(self):

     def __eq__(self, other):
         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.__params_type__ == other.__params_type__
             and all(
                 # NB: Params object should have been already filtered.
@@ -432,7 +432,7 @@ def __repr__(self):

     def __eq__(self, other):
         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.fields == other.fields
             and self.types == other.types
         )

pytensor/link/c/type.py

Lines changed: 1 addition & 1 deletion
@@ -519,7 +519,7 @@ def __hash__(self):

     def __eq__(self, other):
         return (
-            type(self) == type(other)
+            type(self) is type(other)
             and self.ctype == other.ctype
             and len(self) == len(other)
             and len(self.aliases) == len(other.aliases)

pytensor/link/utils.py

Lines changed: 1 addition & 1 deletion
@@ -818,7 +818,7 @@ def get_destroy_dependencies(fgraph: FunctionGraph) -> dict[Apply, list[Variable
         in destroy_dependencies.
     """
     order = fgraph.orderings()
-    destroy_dependencies = defaultdict(lambda: [])
+    destroy_dependencies = defaultdict(list)
     for node in fgraph.apply_nodes:
         for prereq in order.get(node, []):
             destroy_dependencies[node].extend(prereq.outputs)

pytensor/raise_op.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@

 class ExceptionType(Generic):
     def __eq__(self, other):
-        return type(self) == type(other)
+        return type(self) is type(other)

     def __hash__(self):
         return hash(type(self))
@@ -51,7 +51,7 @@ def __str__(self):
         return f"CheckAndRaise{{{self.exc_type}({self.msg})}}"

     def __eq__(self, other):
-        if type(self) != type(other):
+        if type(self) is not type(other):
             return False

         if self.msg == other.msg and self.exc_type == other.exc_type:

pytensor/scalar/basic.py

Lines changed: 3 additions & 3 deletions
@@ -1074,7 +1074,7 @@ def __call__(self, *types):
         return [rval]

     def __eq__(self, other):
-        return type(self) == type(other) and self.tbl == other.tbl
+        return type(self) is type(other) and self.tbl == other.tbl

     def __hash__(self):
         return hash(type(self))  # ignore hash of table
@@ -1160,7 +1160,7 @@ def L_op(self, inputs, outputs, output_gradients):
         return self.grad(inputs, output_gradients)

     def __eq__(self, other):
-        test = type(self) == type(other) and getattr(
+        test = type(self) is type(other) and getattr(
             self, "output_types_preference", None
         ) == getattr(other, "output_types_preference", None)
         return test
@@ -4132,7 +4132,7 @@ def __eq__(self, other):
         if self is other:
             return True
         if (
-            type(self) != type(other)
+            type(self) is not type(other)
             or self.nin != other.nin
             or self.nout != other.nout
         ):

pytensor/scalar/math.py

Lines changed: 5 additions & 5 deletions
@@ -626,7 +626,7 @@ def c_code(self, node, name, inp, out, sub):
         raise NotImplementedError("only floatingpoint is implemented")

     def __eq__(self, other):
-        return type(self) == type(other)
+        return type(self) is type(other)

     def __hash__(self):
         return hash(type(self))
@@ -675,7 +675,7 @@ def c_code(self, node, name, inp, out, sub):
         raise NotImplementedError("only floatingpoint is implemented")

     def __eq__(self, other):
-        return type(self) == type(other)
+        return type(self) is type(other)

     def __hash__(self):
         return hash(type(self))
@@ -724,7 +724,7 @@ def c_code(self, node, name, inp, out, sub):
         raise NotImplementedError("only floatingpoint is implemented")

     def __eq__(self, other):
-        return type(self) == type(other)
+        return type(self) is type(other)

     def __hash__(self):
         return hash(type(self))
@@ -1033,7 +1033,7 @@ def c_code(self, node, name, inp, out, sub):
         raise NotImplementedError("only floatingpoint is implemented")

     def __eq__(self, other):
-        return type(self) == type(other)
+        return type(self) is type(other)

     def __hash__(self):
         return hash(type(self))
@@ -1074,7 +1074,7 @@ def c_code(self, node, name, inp, out, sub):
         raise NotImplementedError("only floatingpoint is implemented")

     def __eq__(self, other):
-        return type(self) == type(other)
+        return type(self) is type(other)

     def __hash__(self):
         return hash(type(self))
