 
 # pyre-strict
 
+import functools
 import itertools
 import logging
 import operator
@@ -523,6 +524,31 @@ def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
 
+@dataclass
+class SpecAllocResult:
+    """These are the values that a memory planning algorithm assigns to a spec.
+    These are not directly written back into the spec object, but are used to
+    track the allocation decisions and are assigned back to the spec object at
+    the end, based on which algorithm is picked as the best performing one.
+    """
+
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+
+@dataclass
+class MemoryAlgoResult:
+    """This is the result returned by a memory planning algorithm that is
+    invoked by memory_planning_algorithm_suite. It contains the allocation
+    decisions of that algorithm for all the specs, and the buffer sizes used
+    for the different memory hierarchies.
+    """
+
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+
+
 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
 ) -> int:
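
For orientation, the two new dataclasses are meant to be used together: an algorithm records one `SpecAllocResult` per `TensorSpec` inside a `MemoryAlgoResult` rather than mutating the specs directly. A minimal sketch of how a result might be populated (the spec objects and numbers here are purely illustrative, not taken from the patch):

```python
# Hypothetical example: spec_a and spec_b stand in for real TensorSpec objects
# taken from the graph; the offsets and sizes are made up for illustration.
result = MemoryAlgoResult(spec_dict={}, bufsizes=[])
result.spec_dict[spec_a] = SpecAllocResult(mem_id=1, mem_obj_id=0, mem_offset=0)
result.spec_dict[spec_b] = SpecAllocResult(mem_id=1, mem_obj_id=1, mem_offset=64)
result.bufsizes = [0, 128]  # one size per memory hierarchy, indexed by mem_id
```
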
@@ -711,7 +737,7 @@ def greedy(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -720,6 +746,7 @@ def greedy(
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -754,11 +781,19 @@ def greedy(
     sorted_specs.reverse()
 
     for spec in sorted_specs:
+        # Create an entry for this TensorSpec in the result object that we'll be
+        # returning from this algorithm.
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
         spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
+            shared_objects[spec_alloc_result.mem_id],
+            spec,
+            allow_overlapping_allocations,
         )
 
     if len(shared_objects) == 0:
@@ -787,24 +822,89 @@ def greedy(
         for sobj in shared_objects[mem_id]:
             for alloc in sobj.allocations:
                 spec = alloc.spec
-                alloc.spec.mem_obj_id = sobj.idx
-                alloc.spec.mem_offset = sobj.offset + alloc.offset
+                # Get the spec_alloc_result for this spec and update it with the
+                # mem_obj_id and mem_offset generated by this algorithm.
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_obj_id = sobj.idx
+                spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                 num_specs_processed += 1
     assert (
         len(spec2obj) == num_specs_processed
     ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result
 
 
-def naive(
+def memory_planning_algorithm_suite(
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+    algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
 ) -> List[int]:
+    r"""
+    Memory planning algorithm suite that runs a list of memory planning algorithms
+    and returns the result of the algorithm that minimizes the total memory usage.
+    """
+    if algo_list is None:
+        algo_list = [greedy]
+    mem_algo_results = {}
+    for algo in algo_list:
+        if isinstance(algo, functools.partial):
+            name = algo.func.__name__
+        else:
+            name = getattr(algo, "__name__", None)
+        # Run this memory planning algorithm and store the result in mem_algo_results
+        # with the name of the algorithm as the key.
+        mem_algo_results[name] = algo(
+            graph_module,
+            alignment,
+            graph_signature,
+            alloc_graph_input,
+            alloc_graph_output,
+        )
+
+    # All the algorithms should have the same number of buffers allocated.
+    assert (
+        len(
+            {
+                len(mem_algo_result.bufsizes)
+                for mem_algo_result in mem_algo_results.values()
+            }
+        )
+        == 1
+    ), "Different memory planning algorithms should have the same number of buffers allocated."
+
+    # Find the algorithm that minimizes the total memory usage.
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    # Update the mem_id and mem_offset for each spec in the graph module based on the
+    # values provided by the best memory planning algorithm.
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+    return bufsizes
+
+
+def naive(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
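
The functools.partial handling in the suite suggests the intended call pattern: callers pass pre-configured planning algorithms through `algo_list`, and the suite keeps whichever one yields the smallest total buffer size. A hedged usage sketch, assuming `greedy`, `naive`, and `memory_planning_algorithm_suite` are imported from this module and that `graph_module` / `graph_signature` come from an already-exported program; the alignment value is illustrative:

```python
import functools

# Sketch only: graph_module and graph_signature are assumed to exist already.
bufsizes = memory_planning_algorithm_suite(
    graph_module,
    alignment=16,
    graph_signature=graph_signature,
    alloc_graph_input=True,
    alloc_graph_output=True,
    algo_list=[
        naive,
        # partial() lets per-algorithm options ride along, since the suite only
        # forwards the five common arguments to each algorithm.
        functools.partial(greedy, allow_overlapping_allocations=False),
    ],
)
```
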
@@ -826,16 +926,24 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(
+            bufsizes, spec_alloc_result.mem_id, spec.allocated_memory
+        )
 
     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result
 
 
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -980,5 +1088,4 @@ def handle_submodule(
     )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes