
Commit a2cfdc9

tarun292 authored and facebook-github-bot committed
Refactoring memory planning to allow running multiple algorithms (#8440)
Summary: This diff introduces `memory_planning_algorithm_suite`, a method that iterates through multiple memory planning algorithms and picks the one that gives the best result, i.e., the one that consumes the least memory. Each of these algorithms is required to return a `MemoryAlgoResult` containing the results of the memory planning it performed. As before, the algorithms do not update the `TensorSpec`s directly; instead, `memory_planning_algorithm_suite` determines which algorithm produced the best result and then updates the `TensorSpec`s with the values (offsets etc.) returned by that algorithm. Reviewed By: JacobSzwejbka Differential Revision: D69515056
1 parent ce7aedf commit a2cfdc9
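
For orientation, here is a minimal sketch of how a caller opts into the suite, mirroring the Vulkan change below (all names are taken from the diffs in this commit):

    from functools import partial

    from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
    from executorch.exir.passes import MemoryPlanningPass

    # Wrap the suite so it runs the algorithms we care about; each entry in
    # algo_list must be a callable returning a MemoryAlgoResult.
    mem_planning_suite = partial(
        memory_planning_algorithm_suite,
        algo_list=[partial(greedy, allow_overlapping_allocations=False)],
    )
    memory_planning_pass = MemoryPlanningPass(memory_planning_algo=mem_planning_suite)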

File tree

backends/vulkan/vulkan_preprocess.py
exir/memory_planning.py
exir/passes/memory_planning_pass.py
exir/tests/test_memory_planning.py

4 files changed: +137 −21 lines changed

backends/vulkan/vulkan_preprocess.py

Lines changed: 5 additions & 2 deletions
@@ -47,7 +47,7 @@
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
 
-from executorch.exir.memory_planning import greedy
+from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
 from executorch.exir.pass_base import ExportPass, PassBase
 
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
@@ -199,11 +199,14 @@ def preprocess(  # noqa: C901
         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
         greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
+        mem_planning_suite = partial(
+            memory_planning_algorithm_suite, algo_list=[greedy_memory_planning]
+        )
         program = apply_passes(
             program,
             [
                 ConstraintBasedSymShapeEvalPass(),
-                MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
+                MemoryPlanningPass(memory_planning_algo=mem_planning_suite),
             ],
         )

exir/memory_planning.py

Lines changed: 118 additions & 11 deletions
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+import functools
 import itertools
 import logging
 import operator
@@ -523,6 +524,31 @@ def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
 
+@dataclass
+class SpecAllocResult:
+    """These are the values that a memory planning algorithm assigns to a
+    spec. They are not written directly into the spec object, but are used to
+    track the allocation decisions and assigned back to the spec object at
+    the end, based on which algorithm is picked as the best performing one.
+    """
+
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+
+@dataclass
+class MemoryAlgoResult:
+    """This is the result returned by a memory planning algorithm that is
+    invoked by memory_planning_algorithm_suite. It contains the allocation
+    decisions of that algorithm for all the specs, and the size of the buffer
+    that was used for different memory hierarchies.
+    """
+
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+
+
 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
 ) -> int:
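
To make the contract concrete, here is a hedged, self-contained sketch of how an algorithm records its decisions in these dataclasses and how the winner's decisions are committed at the end. It assumes `SpecAllocResult` and `MemoryAlgoResult` (defined above) are in scope; `FakeSpec` is a hypothetical stand-in for `TensorSpec`:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass(eq=False)  # identity hashing, so instances can key spec_dict
    class FakeSpec:  # hypothetical stand-in for TensorSpec
        allocated_memory: int
        mem_id: Optional[int] = None
        mem_obj_id: Optional[int] = None
        mem_offset: Optional[int] = None

    specs = [FakeSpec(64), FakeSpec(128)]

    # An algorithm fills a MemoryAlgoResult instead of mutating the specs.
    result = MemoryAlgoResult({}, [])
    offset = 0
    for spec in specs:
        result.spec_dict[spec] = SpecAllocResult(mem_id=1, mem_obj_id=0, mem_offset=offset)
        offset += spec.allocated_memory
    result.bufsizes = [0, offset]  # indexed by mem_id; the default mem_id is 1

    # Only the winning algorithm's result is written back onto the specs:
    for spec, alloc in result.spec_dict.items():
        spec.mem_id = alloc.mem_id
        spec.mem_obj_id = alloc.mem_obj_id
        spec.mem_offset = alloc.mem_offset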
@@ -711,7 +737,7 @@ def greedy(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -720,6 +746,7 @@ def greedy(
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -754,11 +781,19 @@ def greedy(
     sorted_specs.reverse()
 
     for spec in sorted_specs:
+        # Create an entry for this TensorSpec in the result object that we'll be
+        # returning from this algorithm.
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
         spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
+            shared_objects[spec_alloc_result.mem_id],
+            spec,
+            allow_overlapping_allocations,
         )
 
     if len(shared_objects) == 0:
@@ -787,24 +822,89 @@ def greedy(
         for sobj in shared_objects[mem_id]:
             for alloc in sobj.allocations:
                 spec = alloc.spec
-                alloc.spec.mem_obj_id = sobj.idx
-                alloc.spec.mem_offset = sobj.offset + alloc.offset
+                # Get the spec_alloc_result for this spec and update it with the
+                # mem_obj_id and mem_offset generated by this algorithm.
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_obj_id = sobj.idx
+                spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                 num_specs_processed += 1
     assert (
         len(spec2obj) == num_specs_processed
     ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result
 
 
-def naive(
+def memory_planning_algorithm_suite(
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+    algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
 ) -> List[int]:
+    r"""
+    Memory planning algorithm suite that runs a list of memory planning algorithms
+    and returns the result of the algorithm that minimizes the total memory usage.
+    """
+    if algo_list is None:
+        algo_list = [greedy]
+    mem_algo_results = {}
+    for algo in algo_list:
+        if isinstance(algo, functools.partial):
+            name = algo.func.__name__
+        else:
+            name = getattr(algo, "__name__", None)
+        # Run this memory planning algorithm and store the result in mem_algo_results
+        # with the name of the algorithm as the key.
+        mem_algo_results[name] = algo(
+            graph_module,
+            alignment,
+            graph_signature,
+            alloc_graph_input,
+            alloc_graph_output,
+        )
+
+    # All the algorithms should have the same number of buffers allocated.
+    assert (
+        len(
+            {
+                len(mem_algo_result.bufsizes)
+                for mem_algo_result in mem_algo_results.values()
+            }
+        )
+        == 1
+    ), "Different memory planning algorithms should have the same number of buffers allocated."
+
+    # Find the algorithm that minimizes the total memory usage.
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    # Update the mem_id and mem_offset for each spec in the graph module based on the
+    # values provided by the best memory planning algorithm.
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+    return bufsizes
+
+
+def naive(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
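
For a sense of how the suite is driven directly, here is a hedged call-site sketch that races `greedy` against `naive` and keeps whichever plan sums to less memory; `graph_module`, `alignment`, and `graph_signature` are assumed to be supplied by the surrounding pass infrastructure (e.g. MemoryPlanningPass) and are not constructed here:

    from executorch.exir.memory_planning import (
        greedy,
        memory_planning_algorithm_suite,
        naive,
    )

    # graph_module, alignment, and graph_signature are assumed in scope.
    bufsizes = memory_planning_algorithm_suite(
        graph_module,
        alignment,
        graph_signature=graph_signature,
        algo_list=[greedy, naive],
    )
    # bufsizes is the winning plan's per-mem_id buffer sizes; the winner's
    # mem_id, mem_obj_id, and mem_offset values have already been written back
    # onto each TensorSpec.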
@@ -826,16 +926,24 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(
+            bufsizes, spec_alloc_result.mem_id, spec.allocated_memory
+        )
 
     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result
 
 
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -980,5 +1088,4 @@ def handle_submodule(
     )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes

exir/passes/memory_planning_pass.py

Lines changed: 4 additions & 2 deletions
@@ -17,7 +17,7 @@
     _is_out_var_node,
     apply_algo,
     get_node_tensor_specs,
-    greedy,
+    memory_planning_algorithm_suite,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -40,7 +40,9 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
-        memory_planning_algo: Callable[..., List[int]] = greedy,
+        memory_planning_algo: Callable[
+            ..., List[int]
+        ] = memory_planning_algorithm_suite,
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
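
Note that although the default `memory_planning_algo` changes from `greedy` to `memory_planning_algorithm_suite`, the effective default behavior is preserved: when no `algo_list` is supplied, the suite falls back to `[greedy]`, so existing callers of `MemoryPlanningPass()` still get the greedy plan.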

exir/tests/test_memory_planning.py

Lines changed: 10 additions & 6 deletions
@@ -8,6 +8,7 @@
 
 import itertools
 import unittest
+from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Type
 
 import executorch.exir as exir
@@ -19,6 +20,8 @@
     filter_nodes,
     get_node_tensor_specs,
     greedy,
+    memory_planning_algorithm_suite,
+    MemoryAlgoResult,
     naive,
     Verifier,
 )
@@ -234,7 +237,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
 
 def maketest(
     module_cls: Type[torch.nn.Module],
-    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
+    criteria: Optional[List[Tuple[Callable[..., MemoryAlgoResult], bool]]] = None,
     extra_check: Optional[Callable[..., None]] = None,
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
@@ -266,13 +269,13 @@ def wrapper(self: "TestMemoryPlanning") -> None:
             .exported_program()
             .graph_module
         )
-
+        mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
         graph_module = PassManager(
             passes=[
                 SpecPropPass(),
                 ToOutVarPass(),
                 MemoryPlanningPass(
-                    algo,
+                    mem_algo,
                     alloc_graph_input=alloc_graph_input,
                     alloc_graph_output=alloc_graph_output,
                 ),
@@ -519,10 +522,11 @@ def test_multiple_pools(
             export(MultiplePoolsToyModel(), (torch.ones(1),), strict=True)
         )
 
+        mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
        edge_program.to_executorch(
            exir.ExecutorchBackendConfig(
                memory_planning_pass=CustomPoolMemoryPlanningPass(
-                    memory_planning_algo=algo,
+                    memory_planning_algo=mem_algo,
                    alignment=1,
                ),
            )
@@ -708,10 +712,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         et_program = et.executorch_program
         inputs = et_program.execution_plan[0].inputs
         self.assertNotEqual(
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[0]]
             .val.allocation_info.memory_offset_low,
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[1]]
             .val.allocation_info.memory_offset_low,
         )
