@@ -855,35 +855,44 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
855
855
856
856
# Root case, RNG is not used elsewhere
857
857
if not rng_clients :
858
- return rng
858
+ return None
859
859
860
860
if len (rng_clients ) > 1 :
861
861
# Multiple clients are techincally fine if they are used in identical operations
862
862
# We check if the default_update of each client would be the same
863
- update , * other_updates = (
863
+ all_updates = [
864
864
find_default_update (
865
865
# Pass version of clients that includes only one the RNG clients at a time
866
866
clients | {rng : [rng_client ]},
867
867
rng ,
868
868
)
869
869
for rng_client in rng_clients
870
- )
871
- if all (equal_computations ([update ], [other_update ]) for other_update in other_updates ):
872
- return update
873
-
874
- warnings .warn (
875
- f"RNG Variable { rng } has multiple distinct clients { rng_clients } , "
876
- f"likely due to an inconsistent random graph. "
877
- f"No default update will be returned." ,
878
- UserWarning ,
879
- )
880
- return None
870
+ ]
871
+ updates = [update for update in all_updates if update is not None ]
872
+ if not updates :
873
+ return None
874
+ if len (updates ) == 1 :
875
+ return updates [0 ]
876
+ else :
877
+ update , * other_updates = updates
878
+ if all (
879
+ equal_computations ([update ], [other_update ]) for other_update in other_updates
880
+ ):
881
+ return update
882
+
883
+ warnings .warn (
884
+ f"RNG Variable { rng } has multiple distinct clients { rng_clients } , "
885
+ f"likely due to an inconsistent random graph. "
886
+ f"No default update will be returned." ,
887
+ UserWarning ,
888
+ )
889
+ return None
881
890
882
891
[client , _ ] = rng_clients [0 ]
883
892
884
893
# RNG is an output of the function, this is not a problem
885
894
if isinstance (client .op , Output ):
886
- return rng
895
+ return None
887
896
888
897
# RNG is used by another operator, which should output an update for the RNG
889
898
if isinstance (client .op , RandomVariable ):
@@ -912,18 +921,26 @@ def find_default_update(clients, rng: Variable) -> None | Variable:
912
921
)
913
922
elif isinstance (client .op , OpFromGraph ):
914
923
try :
915
- next_rng = collect_default_updates_inner_fgraph (client )[rng ]
916
- except (ValueError , KeyError ):
924
+ next_rng = collect_default_updates_inner_fgraph (client ).get (rng )
925
+ if next_rng is None :
926
+ # OFG either does not make use of this RNG or inconsistent use that will have emitted a warning
927
+ return None
928
+ except ValueError as exc :
917
929
raise ValueError (
918
930
f"No update found for at least one RNG used in OpFromGraph Op { client .op } .\n "
919
931
"You can use `pytensorf.collect_default_updates` and include those updates as outputs."
920
- )
932
+ ) from exc
921
933
else :
922
934
# We don't know how this RNG should be updated. The user should provide an update manually
923
935
return None
924
936
925
937
# Recurse until we find final update for RNG
926
- return find_default_update (clients , next_rng )
938
+ nested_next_rng = find_default_update (clients , next_rng )
939
+ if nested_next_rng is None :
940
+ # There were no more uses of this next_rng
941
+ return next_rng
942
+ else :
943
+ return nested_next_rng
927
944
928
945
if inputs is None :
929
946
inputs = []
0 commit comments