
Commit d3863a8

Refactoring memory planning to allow running multiple algorithms

Differential Revision: D69515056
Pull Request resolved: #8440
1 parent: 6ca39f8

File tree

4 files changed: +135 -19 lines
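At a glance, the refactor routes every planner through memory_planning_algorithm_suite: the suite runs each algorithm in its algo_list, keeps each algorithm's MemoryAlgoResult, and commits the result with the smallest total buffer footprint back onto the tensor specs. A minimal sketch of the intended usage, mirroring what the updated tests and the Vulkan backend below do (import paths are taken from the diffs in this commit):

    from functools import partial

    from executorch.exir.memory_planning import (
        greedy,
        memory_planning_algorithm_suite,
        naive,
    )
    from executorch.exir.passes import MemoryPlanningPass

    # Try both planners; the suite picks whichever yields the smaller
    # sum of buffer sizes and writes its allocations back to the specs.
    mem_planning_suite = partial(
        memory_planning_algorithm_suite, algo_list=[greedy, naive]
    )
    mem_planning_pass = MemoryPlanningPass(memory_planning_algo=mem_planning_suite)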

backends/vulkan/vulkan_preprocess.py

Lines changed: 5 additions & 2 deletions

@@ -47,7 +47,7 @@
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
 
-from executorch.exir.memory_planning import greedy
+from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
 from executorch.exir.pass_base import ExportPass, PassBase
 
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
@@ -199,11 +199,14 @@ def preprocess( # noqa: C901
         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
         greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
+        mem_planning_suite = partial(
+            memory_planning_algorithm_suite, algo_list=[greedy_memory_planning]
+        )
         program = apply_passes(
             program,
             [
                 ConstraintBasedSymShapeEvalPass(),
-                MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
+                MemoryPlanningPass(memory_planning_algo=mem_planning_suite),
             ],
         )
 
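Note that the Vulkan backend still plans with exactly one algorithm, greedy with overlapping allocations disabled; it is simply wrapped in a one-element suite so that every caller now goes through the same memory_planning_algorithm_suite entry point.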

exir/memory_planning.py

Lines changed: 118 additions & 11 deletions
@@ -6,6 +6,7 @@
 
 # pyre-strict
 
+import functools
 import itertools
 import logging
 import operator
@@ -523,6 +524,31 @@ def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
 
+@dataclass
+class SpecAllocResult:
+    """These are the values that a memory planning algorithm assigns to a spec.
+    They are not written directly back into the spec object, but are used to
+    track the allocation decisions and are assigned back to the spec object at
+    the end, based on which algorithm is picked as the best performing one.
+    """
+
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+
+@dataclass
+class MemoryAlgoResult:
+    """This is the result returned by a memory planning algorithm that is
+    invoked by memory_planning_algorithm_suite. It contains the allocation
+    decisions of that algorithm for all the specs, and the sizes of the
+    buffers used for the different memory hierarchies.
+    """
+
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+
+
 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
 ) -> int:
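To see the result objects in isolation, here is a standalone sketch that re-declares the two dataclasses from the hunk above and records one hypothetical allocation decision; FakeSpec and the numbers are illustrative stand-ins, not part of the commit.

    from dataclasses import dataclass
    from typing import Dict, List

    @dataclass(frozen=True)  # frozen so instances are hashable dict keys
    class FakeSpec:
        # Illustrative stand-in for TensorSpec; not part of the commit.
        name: str

    @dataclass
    class SpecAllocResult:
        mem_id: int
        mem_obj_id: int
        mem_offset: int

    @dataclass
    class MemoryAlgoResult:
        spec_dict: Dict[FakeSpec, SpecAllocResult]
        bufsizes: List[int]

    # An algorithm records its decisions in the result object rather than
    # mutating the spec; the suite writes the winner's values back at the end.
    result = MemoryAlgoResult(spec_dict={}, bufsizes=[])
    result.spec_dict[FakeSpec("conv1_out")] = SpecAllocResult(
        mem_id=1, mem_obj_id=0, mem_offset=64
    )
    result.bufsizes = [0, 128]  # hypothetical sizes per memory hierarchy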
@@ -711,7 +737,7 @@ def greedy(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -720,6 +746,7 @@
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -754,11 +781,19 @@
     sorted_specs.reverse()
 
     for spec in sorted_specs:
+        # Create an entry for this TensorSpec in the result object that we'll be
+        # returning from this algorithm.
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
         spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
+            shared_objects[spec_alloc_result.mem_id],
+            spec,
+            allow_overlapping_allocations,
         )
 
     if len(shared_objects) == 0:
@@ -787,24 +822,89 @@
         for sobj in shared_objects[mem_id]:
             for alloc in sobj.allocations:
                 spec = alloc.spec
-                alloc.spec.mem_obj_id = sobj.idx
-                alloc.spec.mem_offset = sobj.offset + alloc.offset
+                # Get the spec_alloc_result for this spec and update it with the
+                # mem_obj_id and mem_offset generated by this algorithm.
+                spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+                assert spec_alloc_result is not None, f"Spec {spec} not found."
+                spec_alloc_result.mem_obj_id = sobj.idx
+                spec_alloc_result.mem_offset = sobj.offset + alloc.offset
                 num_specs_processed += 1
     assert (
         len(spec2obj) == num_specs_processed
     ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result
 
 
-def naive(
+def memory_planning_algorithm_suite(
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+    algo_list: Optional[List[Callable[..., MemoryAlgoResult]]] = None,
 ) -> List[int]:
+    r"""
+    Memory planning algorithm suite that runs a list of memory planning algorithms
+    and returns the result of the algorithm that minimizes the total memory usage.
+    """
+    if algo_list is None:
+        algo_list = [greedy]
+    mem_algo_results = {}
+    for algo in algo_list:
+        if isinstance(algo, functools.partial):
+            name = algo.func.__name__
+        else:
+            name = getattr(algo, "__name__", None)
+        # Run this memory planning algorithm and store the result in mem_algo_results
+        # with the name of the algorithm as the key.
+        mem_algo_results[name] = algo(
+            graph_module,
+            alignment,
+            graph_signature,
+            alloc_graph_input,
+            alloc_graph_output,
+        )
+
+    # All the algorithms should have the same number of buffers allocated.
+    assert (
+        len(
+            {
+                len(mem_algo_result.bufsizes)
+                for mem_algo_result in mem_algo_results.values()
+            }
+        )
+        == 1
+    ), "Different memory planning algorithms should have the same number of buffers allocated."
+
+    # Find the algorithm that minimizes the total memory usage.
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    # Update the mem_id and mem_offset for each spec in the graph module based on
+    # the values provided by the best memory planning algorithm.
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+    return bufsizes
+
+
+def naive(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
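Two small mechanics in the suite above are easy to check in isolation: resolving a printable name for an algorithm that may be wrapped in functools.partial, and picking the winner by the smallest summed bufsizes. A self-contained sketch with hypothetical numbers (the stub planner and the sizes are illustrative, not from the commit):

    import functools

    def greedy(*args, **kwargs):
        # Stub standing in for the real greedy planner.
        pass

    algo = functools.partial(greedy, allow_overlapping_allocations=False)
    # Unwrap the partial to recover the underlying function's name, as the suite does.
    name = (
        algo.func.__name__
        if isinstance(algo, functools.partial)
        else getattr(algo, "__name__", None)
    )
    assert name == "greedy"

    # Pick the best algorithm by total planned memory across hierarchies.
    bufsizes_by_algo = {"greedy": [0, 352], "naive": [0, 400]}
    best = min(bufsizes_by_algo, key=lambda k: sum(bufsizes_by_algo[k]))
    assert best == "greedy"  # 352 < 400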
@@ -826,16 +926,24 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(
+            bufsizes, spec_alloc_result.mem_id, spec.allocated_memory
+        )
 
     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result
 
 
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -980,5 +1088,4 @@ def handle_submodule(
     )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes

exir/passes/memory_planning_pass.py

Lines changed: 4 additions & 2 deletions
@@ -17,7 +17,7 @@
     _is_out_var_node,
     apply_algo,
     get_node_tensor_specs,
-    greedy,
+    memory_planning_algorithm_suite,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -40,7 +40,9 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
-        memory_planning_algo: Callable[..., List[int]] = greedy,
+        memory_planning_algo: Callable[
+            ..., List[int]
+        ] = memory_planning_algorithm_suite,
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
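Because the suite falls back to algo_list=[greedy] when no list is given, changing the default memory_planning_algo from greedy to memory_planning_algorithm_suite should leave a bare MemoryPlanningPass() planning the same way as before. A hedged sketch of that equivalence, reusing the imports shown in the diffs:

    from functools import partial

    from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
    from executorch.exir.passes import MemoryPlanningPass

    # These two constructions should be interchangeable after this change:
    default_pass = MemoryPlanningPass()  # suite with its implicit [greedy] fallback
    explicit_pass = MemoryPlanningPass(
        memory_planning_algo=partial(memory_planning_algorithm_suite, algo_list=[greedy])
    )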

exir/tests/test_memory_planning.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import itertools
1010
import unittest
11+
from functools import partial
1112
from typing import Any, Callable, List, Optional, Tuple, Type
1213

1314
import executorch.exir as exir
@@ -19,6 +20,8 @@
1920
filter_nodes,
2021
get_node_tensor_specs,
2122
greedy,
23+
memory_planning_algorithm_suite,
24+
MemoryAlgoResult,
2225
naive,
2326
Verifier,
2427
)
@@ -234,7 +237,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
234237

235238
def maketest(
236239
module_cls: Type[torch.nn.Module],
237-
criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
240+
criteria: Optional[List[Tuple[Callable[..., MemoryAlgoResult], bool]]] = None,
238241
extra_check: Optional[Callable[..., None]] = None,
239242
use_functionalization: bool = True,
240243
alloc_graph_input: bool = True,
@@ -266,13 +269,13 @@ def wrapper(self: "TestMemoryPlanning") -> None:
266269
.exported_program()
267270
.graph_module
268271
)
269-
272+
mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
270273
graph_module = PassManager(
271274
passes=[
272275
SpecPropPass(),
273276
ToOutVarPass(),
274277
MemoryPlanningPass(
275-
algo,
278+
mem_algo,
276279
alloc_graph_input=alloc_graph_input,
277280
alloc_graph_output=alloc_graph_output,
278281
),
@@ -519,10 +522,11 @@ def test_multiple_pools(
519522
export(MultiplePoolsToyModel(), (torch.ones(1),), strict=True)
520523
)
521524

525+
mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
522526
edge_program.to_executorch(
523527
exir.ExecutorchBackendConfig(
524528
memory_planning_pass=CustomPoolMemoryPlanningPass(
525-
memory_planning_algo=algo,
529+
memory_planning_algo=mem_algo,
526530
alignment=1,
527531
),
528532
)
