@@ -393,41 +393,6 @@ def __init__(
393
393
assert len (self .input_storage ) == len (self .maker .fgraph .inputs )
394
394
assert len (self .output_storage ) == len (self .maker .fgraph .outputs )
395
395
396
- # Group indexes of inputs that are potentially aliased to each other
397
- # Note: Historically, we only worried about aliasing inputs if they belonged to the same type,
398
- # even though there could be two distinct types that use the same kinds of underlying objects.
399
- potential_aliased_input_groups = []
400
- for inp in maker .inputs :
401
- # If the input is a shared variable, the memory region is under PyTensor control
402
- # and can't be aliased.
403
- if not (
404
- isinstance (inp , In )
405
- and inp .borrow
406
- and not inp .shared
407
- and hasattr (inp .variable .type , "may_share_memory" )
408
- ):
409
- continue
410
-
411
- for group in potential_aliased_input_groups :
412
- # If one is super of the other, that means one could be replaced by the other
413
- if any (
414
- inp .variable .type .is_super (other_inp .variable .type )
415
- or other_inp .variable .type .is_super (inp .variable .type )
416
- for other_inp in group
417
- ):
418
- group .append (inp )
419
- break
420
- else : # no break
421
- # Input makes a new group
422
- potential_aliased_input_groups .append ([inp ])
423
-
424
- # Potential aliased inputs are those that belong to the same group
425
- self ._potential_aliased_input_groups : tuple [tuple [int , ...], ...] = tuple (
426
- tuple (maker .inputs .index (inp ) for inp in group )
427
- for group in potential_aliased_input_groups
428
- if len (group ) > 1
429
- )
430
-
431
396
# We will be popping stuff off this `containers` object. It is a copy.
432
397
containers = list (self .input_storage )
433
398
finder = {}
@@ -844,11 +809,18 @@ def __call__(self, *args, **kwargs):
844
809
if self .output_keys is not None :
845
810
output_subset = [self .output_keys .index (key ) for key in output_subset ]
846
811
847
- # Reinitialize each container's 'provided' counter
848
812
if self .trust_input :
813
+ # Set positional arguments
849
814
for arg_container , arg in zip (input_storage , args , strict = False ):
850
815
arg_container .storage [0 ] = arg
816
+
817
+ # Set keyword arguments
818
+ if kwargs : # for speed, skip the items for empty kwargs
819
+ for k , arg in kwargs .items ():
820
+ self [k ] = arg
821
+
851
822
else :
823
+ # Reinitialize each container's 'provided' counter
852
824
for arg_container in input_storage :
853
825
arg_container .provided = 0
854
826
@@ -899,39 +871,10 @@ def __call__(self, *args, **kwargs):
899
871
raise
900
872
arg_container .provided += 1
901
873
902
- # Set keyword arguments
903
- if kwargs : # for speed, skip the items for empty kwargs
904
- for k , arg in kwargs .items ():
905
- self [k ] = arg
906
-
907
- if not self .trust_input :
908
- # Collect aliased inputs among the storage space
909
- for potential_group in self ._potential_aliased_input_groups :
910
- args_share_memory : list [list [int ]] = []
911
- for i in potential_group :
912
- i_type = self .maker .inputs [i ].variable .type
913
- i_val = input_storage [i ].storage [0 ]
914
-
915
- # Check if value is aliased with any of the values in one of the groups
916
- for j_group in args_share_memory :
917
- if any (
918
- i_type .may_share_memory (input_storage [j ].storage [0 ], i_val )
919
- for j in j_group
920
- ):
921
- j_group .append (i )
922
- break
923
- else : # no break
924
- # Create a new group
925
- args_share_memory .append ([i ])
926
-
927
- # Check for groups of more than one argument that share memory
928
- for group in args_share_memory :
929
- if len (group ) > 1 :
930
- # copy all but the first
931
- for i in group [1 :]:
932
- input_storage [i ].storage [0 ] = copy .copy (
933
- input_storage [i ].storage [0 ]
934
- )
874
+ # Set keyword arguments
875
+ if kwargs : # for speed, skip the items for empty kwargs
876
+ for k , arg in kwargs .items ():
877
+ self [k ] = arg
935
878
936
879
# Check if inputs are missing, or if inputs were set more than once, or
937
880
# if we tried to provide inputs that are supposed to be implicit.
0 commit comments