Skip to content

Commit 8d26d62

Browse files
committed
Ignore inner unused RNG inputs in collect_default_updates
1 parent fa43eba commit 8d26d62

File tree

2 files changed

+49
-18
lines changed

2 files changed

+49
-18
lines changed

pymc/pytensorf.py

+35-18
Original file line numberDiff line numberDiff line change
@@ -855,35 +855,44 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
855855

856856
# Root case, RNG is not used elsewhere
857857
if not rng_clients:
858-
return rng
858+
return None
859859

860860
if len(rng_clients) > 1:
861861
# Multiple clients are techincally fine if they are used in identical operations
862862
# We check if the default_update of each client would be the same
863-
update, *other_updates = (
863+
all_updates = [
864864
find_default_update(
865865
# Pass version of clients that includes only one the RNG clients at a time
866866
clients | {rng: [rng_client]},
867867
rng,
868868
)
869869
for rng_client in rng_clients
870-
)
871-
if all(equal_computations([update], [other_update]) for other_update in other_updates):
872-
return update
873-
874-
warnings.warn(
875-
f"RNG Variable {rng} has multiple distinct clients {rng_clients}, "
876-
f"likely due to an inconsistent random graph. "
877-
f"No default update will be returned.",
878-
UserWarning,
879-
)
880-
return None
870+
]
871+
updates = [update for update in all_updates if update is not None]
872+
if not updates:
873+
return None
874+
if len(updates) == 1:
875+
return updates[0]
876+
else:
877+
update, *other_updates = updates
878+
if all(
879+
equal_computations([update], [other_update]) for other_update in other_updates
880+
):
881+
return update
882+
883+
warnings.warn(
884+
f"RNG Variable {rng} has multiple distinct clients {rng_clients}, "
885+
f"likely due to an inconsistent random graph. "
886+
f"No default update will be returned.",
887+
UserWarning,
888+
)
889+
return None
881890

882891
[client, _] = rng_clients[0]
883892

884893
# RNG is an output of the function, this is not a problem
885894
if isinstance(client.op, Output):
886-
return rng
895+
return None
887896

888897
# RNG is used by another operator, which should output an update for the RNG
889898
if isinstance(client.op, RandomVariable):
@@ -912,18 +921,26 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
912921
)
913922
elif isinstance(client.op, OpFromGraph):
914923
try:
915-
next_rng = collect_default_updates_inner_fgraph(client)[rng]
916-
except (ValueError, KeyError):
924+
next_rng = collect_default_updates_inner_fgraph(client).get(rng)
925+
if next_rng is None:
926+
# OFG either does not make use of this RNG or inconsistent use that will have emitted a warning
927+
return None
928+
except ValueError as exc:
917929
raise ValueError(
918930
f"No update found for at least one RNG used in OpFromGraph Op {client.op}.\n"
919931
"You can use `pytensorf.collect_default_updates` and include those updates as outputs."
920-
)
932+
) from exc
921933
else:
922934
# We don't know how this RNG should be updated. The user should provide an update manually
923935
return None
924936

925937
# Recurse until we find final update for RNG
926-
return find_default_update(clients, next_rng)
938+
nested_next_rng = find_default_update(clients, next_rng)
939+
if nested_next_rng is None:
940+
# There were no more uses of this next_rng
941+
return next_rng
942+
else:
943+
return nested_next_rng
927944

928945
if inputs is None:
929946
inputs = []

tests/test_pytensorf.py

+14
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,20 @@ def test_op_from_graph_updates(self):
619619
fn = compile([], x, random_seed=1)
620620
assert not (set(fn()) & set(fn()))
621621

622+
def test_unused_ofg_rng(self):
623+
rng = pytensor.shared(np.random.default_rng())
624+
next_rng, x = pt.random.normal(rng=rng).owner.outputs
625+
ofg1 = OpFromGraph([rng], [next_rng, x])
626+
ofg2 = OpFromGraph([rng, x], [x + 1])
627+
628+
next_rng, x = ofg1(rng)
629+
y = ofg2(rng, x)
630+
631+
# In all these cases the update should be the same
632+
assert collect_default_updates([x]) == {rng: next_rng}
633+
assert collect_default_updates([y]) == {rng: next_rng}
634+
assert collect_default_updates([x, y]) == {rng: next_rng}
635+
622636

623637
def test_replace_rng_nodes():
624638
rng = pytensor.shared(np.random.default_rng())

0 commit comments

Comments
 (0)