 # pyre-strict
 
+import collections
+import functools
 import itertools
 import logging
 import operator
@@ -503,6 +505,27 @@ class SharedObject:
     def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
+@dataclass
+class SpecAllocResult:
+    """These are the values that a memory planning algorithm assigns to a spec.
+    They are not written directly back into the spec object; instead they are
+    used to track the allocation decisions and are assigned back to the spec
+    object at the end, based on which algorithm is picked as the best performing one.
+    """
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+@dataclass
+class MemoryAlgoResult:
+    """This is the result returned by a memory planning algorithm that is
+    invoked by memory_planning_algorithm_suite. It contains the allocation
+    decisions of that algorithm for all the specs, and the size of the buffer
+    that was used for different memory hierarchies.
+    """
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+
 
 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
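
For orientation, here is a minimal, self-contained sketch (not part of this diff) of how a planning algorithm is expected to fill these two containers. `FakeSpec` and `toy_algo` are hypothetical stand-ins used only so the snippet runs on its own; real algorithms work on `TensorSpec`s collected from the graph, and the two dataclasses are re-declared here purely for self-containment.

from dataclasses import dataclass
from typing import Dict, List


# Re-declared from the diff above only so this snippet runs standalone.
@dataclass
class SpecAllocResult:
    mem_id: int
    mem_obj_id: int
    mem_offset: int


@dataclass
class MemoryAlgoResult:
    spec_dict: Dict["FakeSpec", SpecAllocResult]
    bufsizes: List[int]


@dataclass(frozen=True)
class FakeSpec:
    """Hypothetical stand-in for TensorSpec, just enough for the demo."""
    name: str
    allocated_memory: int


def toy_algo(specs: List[FakeSpec]) -> MemoryAlgoResult:
    """Bump-allocate every spec into a single arena with mem_id 1."""
    result = MemoryAlgoResult({}, [])
    offset = 0
    for i, spec in enumerate(specs):
        # Record the decision in the result instead of mutating the spec; the
        # suite writes the winning algorithm's values back onto the specs later.
        result.spec_dict[spec] = SpecAllocResult(mem_id=1, mem_obj_id=i, mem_offset=offset)
        offset += spec.allocated_memory
    result.bufsizes = [0, offset]  # index 0 unused; arena 1 needs `offset` bytes
    return result


print(toy_algo([FakeSpec("a", 64), FakeSpec("b", 128)]).bufsizes)  # [0, 192]

The point of the indirection is that nothing is written onto a spec until all candidate plans have been produced and compared.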
@@ -692,7 +715,7 @@ def greedy(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -701,6 +724,7 @@ def greedy(
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -735,12 +759,16 @@ def greedy(
     sorted_specs.reverse()
 
     for spec in sorted_specs:
+        # Create an entry for this TensorSpec in the result object that we'll be
+        # returning from this algorithm.
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
-        )
+        spec2obj[spec] = pick_shared_obj(shared_objects[spec_alloc_result.mem_id], spec, allow_overlapping_allocations)
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
@@ -768,24 +796,73 @@ def greedy(
         for sobj in shared_objects[mem_id]:
             for alloc in sobj.allocations:
                 spec = alloc.spec
-                alloc.spec.mem_obj_id = sobj.idx
-                alloc.spec.mem_offset = sobj.offset + alloc.offset
+                # Get the spec_alloc_result for this spec and update it with the
+                # mem_obj_id and mem_offset generated by this algorithm.
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_obj_id = sobj.idx
+                spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                 num_specs_processed += 1
     assert (
         len(spec2obj) == num_specs_processed
     ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result
 
+def memory_planning_algorithm_suite(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+    algo_list: List[Callable[..., MemoryAlgoResult]] = [greedy],
+) -> List[int]:
+    r"""
+    Memory planning algorithm suite that runs a list of memory planning algorithms
+    and returns the result of the algorithm that minimizes the total memory usage.
+    """
+    mem_algo_results = {}
+    for algo in algo_list:
+        if isinstance(algo, functools.partial):
+            name = algo.func.__name__
+        else:
+            name = getattr(algo, "__name__", None)
+        # Run this memory planning algorithm and store the result in mem_algo_results
+        # with the name of the algorithm as the key.
+        mem_algo_results[name] = algo(
+            graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        )
+
+    # All the algorithms should have the same number of buffers allocated.
+    assert len({len(mem_algo_result.bufsizes) for mem_algo_result in mem_algo_results.values()}) == 1, "Different memory planning algorithms should have the same number of buffers allocated."
+
+    # Find the algorithm that minimizes the total memory usage.
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    # Update the mem_id, mem_obj_id, and mem_offset for each spec in the graph
+    # module based on the values provided by the best memory planning algorithm.
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+    return bufsizes
 
 def naive(
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
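
The selection criterion introduced in the hunk above reduces to "smallest sum of per-arena buffer sizes wins". Below is a standalone sketch of just that step; the bufsizes values are made up for illustration and are not taken from this diff.

from typing import Dict, List

# Hypothetical per-arena buffer sizes reported by two algorithms for the same
# model; index i is the planned size of the arena with mem_id i.
mem_algo_results: Dict[str, List[int]] = {
    "naive": [0, 2048],
    "greedy": [0, 1536],
}

# Same criterion as memory_planning_algorithm_suite: minimize total memory.
best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k]))
print(best_algo, mem_algo_results[best_algo])  # greedy [0, 1536]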
@@ -807,16 +884,22 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(bufsizes, spec_alloc_result.mem_id, spec.allocated_memory)
 
     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result
 
 
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -961,5 +1044,4 @@ def handle_submodule(
     )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes
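
Taken together, greedy and naive now report their decisions through MemoryAlgoResult, and memory_planning_algorithm_suite picks the cheaper plan and writes it back onto the specs. A rough usage sketch follows; the import path is assumed from the ExecuTorch source layout (exir/memory_planning.py), and whether the partial-bound suite is accepted wherever a planning callable is expected depends on code not shown in this diff.

import functools

# Assumed import path; adjust to the actual module layout.
from executorch.exir.memory_planning import (
    greedy,
    memory_planning_algorithm_suite,
    naive,
)

# Try both built-in algorithms; as an example, run greedy with overlapping
# allocations disabled via functools.partial (the suite unwraps partials when
# deriving the algorithm's name).
algo_list = [
    naive,
    functools.partial(greedy, allow_overlapping_allocations=False),
]

# Bind the list so the suite matches the usual planning-algorithm call shape:
# (graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output).
mps_algo = functools.partial(memory_planning_algorithm_suite, algo_list=algo_list)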