Skip to content

Commit f737996

Browse files
maresbricardoV94
authored andcommitted
Fix redefinitions
1 parent 5c4c6b5 commit f737996

File tree

3 files changed

+49
-15
lines changed

3 files changed

+49
-15
lines changed

pytensor/graph/basic.py

+43-11
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
TypeVar,
2323
Union,
2424
cast,
25+
overload,
2526
)
2627

2728
import numpy as np
@@ -1301,9 +1302,31 @@ def clone_get_equiv(
13011302
return memo
13021303

13031304

1305+
@overload
1306+
def general_toposort(
1307+
outputs: Iterable[T],
1308+
deps: None,
1309+
compute_deps_cache: Callable[[T], Optional[Union[OrderedSet, list[T]]]],
1310+
deps_cache: Optional[dict[T, list[T]]],
1311+
clients: Optional[dict[T, list[T]]],
1312+
) -> list[T]:
1313+
...
1314+
1315+
1316+
@overload
13041317
def general_toposort(
13051318
outputs: Iterable[T],
13061319
deps: Callable[[T], Union[OrderedSet, list[T]]],
1320+
compute_deps_cache: None,
1321+
deps_cache: None,
1322+
clients: Optional[dict[T, list[T]]],
1323+
) -> list[T]:
1324+
...
1325+
1326+
1327+
def general_toposort(
1328+
outputs: Iterable[T],
1329+
deps: Optional[Callable[[T], Union[OrderedSet, list[T]]]],
13071330
compute_deps_cache: Optional[
13081331
Callable[[T], Optional[Union[OrderedSet, list[T]]]]
13091332
] = None,
@@ -1345,7 +1368,7 @@ def general_toposort(
13451368
if deps_cache is None:
13461369
deps_cache = {}
13471370

1348-
def _compute_deps_cache(io):
1371+
def _compute_deps_cache_(io):
13491372
if io not in deps_cache:
13501373
d = deps(io)
13511374

@@ -1363,6 +1386,8 @@ def _compute_deps_cache(io):
13631386
else:
13641387
return deps_cache[io]
13651388

1389+
_compute_deps_cache = _compute_deps_cache_
1390+
13661391
else:
13671392
_compute_deps_cache = compute_deps_cache
13681393

@@ -1451,15 +1476,14 @@ def io_toposort(
14511476
)
14521477
return order
14531478

1454-
compute_deps = None
1455-
compute_deps_cache = None
14561479
iset = set(inputs)
1457-
deps_cache: dict = {}
14581480

14591481
if not orderings: # ordering can be None or empty dict
14601482
# Specialized function that is faster when no ordering.
14611483
# Also include the cache in the function itself for speed up.
14621484

1485+
deps_cache: dict = {}
1486+
14631487
def compute_deps_cache(obj):
14641488
if obj in deps_cache:
14651489
return deps_cache[obj]
@@ -1478,6 +1502,14 @@ def compute_deps_cache(obj):
14781502
deps_cache[obj] = rval
14791503
return rval
14801504

1505+
topo = general_toposort(
1506+
outputs,
1507+
deps=None,
1508+
compute_deps_cache=compute_deps_cache,
1509+
deps_cache=deps_cache,
1510+
clients=clients,
1511+
)
1512+
14811513
else:
14821514
# the inputs are used only here in the function that decides what
14831515
# 'predecessors' to explore
@@ -1494,13 +1526,13 @@ def compute_deps(obj):
14941526
assert not orderings.get(obj, None)
14951527
return rval
14961528

1497-
topo = general_toposort(
1498-
outputs,
1499-
deps=compute_deps,
1500-
compute_deps_cache=compute_deps_cache,
1501-
deps_cache=deps_cache,
1502-
clients=clients,
1503-
)
1529+
topo = general_toposort(
1530+
outputs,
1531+
deps=compute_deps,
1532+
compute_deps_cache=None,
1533+
deps_cache=None,
1534+
clients=clients,
1535+
)
15041536
return [o for o in topo if isinstance(o, Apply)]
15051537

15061538

pytensor/graph/rewriting/basic.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2405,13 +2405,15 @@ def importer(node):
24052405
if node is not current_node:
24062406
q.append(node)
24072407

2408-
chin = None
2408+
chin: Optional[Callable] = None
24092409
if self.tracks_on_change_inputs:
24102410

2411-
def chin(node, i, r, new_r, reason):
2411+
def chin_(node, i, r, new_r, reason):
24122412
if node is not current_node and not isinstance(node, str):
24132413
q.append(node)
24142414

2415+
chin = chin_
2416+
24152417
u = self.attach_updater(
24162418
fgraph, importer, None, chin=chin, name=getattr(self, "name", None)
24172419
)

tests/tensor/test_math.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1403,7 +1403,7 @@ def test_bool(self):
14031403

14041404

14051405
rng = np.random.default_rng(seed=utt.fetch_seed())
1406-
TestClip = makeTester(
1406+
TestClip1 = makeTester(
14071407
name="ClipTester",
14081408
op=clip,
14091409
expected=lambda x, y, z: np.clip(x, y, z),
@@ -1470,7 +1470,7 @@ def test_bool(self):
14701470
)
14711471

14721472

1473-
class TestClip:
1473+
class TestClip2:
14741474
def test_complex_value(self):
14751475
for dtype in ["complex64", "complex128"]:
14761476
a = vector(dtype=dtype)

0 commit comments

Comments
 (0)