Commit 32e3fa4

tarun292 authored and facebook-github-bot committed
Refactoring memory planning to allow running multiple algorithms (#8440)
Summary: This diff introduces `memory_planning_algorithm_suite`, a method that allows us to iterate through multiple memory planning algorithms and pick the one that gives us the best result, i.e. the least memory consumed. The requirement for each of these algorithms is that they generate a `MemoryAlgoResult` containing the results of the memory planning done by that algorithm. As before, these algorithms don't update the `TensorSpec` directly; rather, in `memory_planning_algorithm_suite` we figure out which algorithm gave us the best result and then update the `TensorSpec`s with the values (offsets etc.) returned by that algorithm. Differential Revision: D69515056
1 parent 4b85ee2 commit 32e3fa4
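
A minimal usage sketch (not part of the commit itself; it assumes this commit's executorch Python API, and the naive/greedy pairing is illustrative) of how a caller wires several planners through the suite, mirroring the vulkan_preprocess.py change below:

from functools import partial

from executorch.exir.memory_planning import (
    greedy,
    memory_planning_algorithm_suite,
    naive,
)
from executorch.exir.passes import MemoryPlanningPass

# Each algorithm in algo_list returns a MemoryAlgoResult; the suite keeps the
# one whose sum(bufsizes) is smallest and writes its mem_id / mem_obj_id /
# mem_offset decisions back into the TensorSpecs.
mem_planning_suite = partial(
    memory_planning_algorithm_suite,
    algo_list=[naive, partial(greedy, allow_overlapping_allocations=False)],
)
memory_planning_pass = MemoryPlanningPass(memory_planning_algo=mem_planning_suite)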

File tree

4 files changed, +111 −22 lines changed


backends/vulkan/vulkan_preprocess.py

Lines changed: 3 additions & 2 deletions
@@ -47,7 +47,7 @@
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
 
-from executorch.exir.memory_planning import greedy
+from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
 from executorch.exir.pass_base import ExportPass, PassBase
 
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
@@ -199,11 +199,12 @@ def preprocess( # noqa: C901
         # Finally, apply dynamic shape passes and memory planning pass. These passes
         # must be applied only when the graph structure is finalized.
         greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
+        mem_planning_suite = partial(
+            memory_planning_algorithm_suite, algo_list=[greedy_memory_planning]
+        )
         program = apply_passes(
             program,
             [
                 ConstraintBasedSymShapeEvalPass(),
-                MemoryPlanningPass(memory_planning_algo=greedy_memory_planning),
+                MemoryPlanningPass(memory_planning_algo=mem_planning_suite),
             ],
         )

exir/memory_planning.py

Lines changed: 95 additions & 13 deletions
@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+import collections
+import functools
 import itertools
 import logging
 import operator
@@ -503,6 +505,27 @@ class SharedObject:
     def __repr__(self) -> str:
         return f"SharedObject(idx={self.idx}, offset={self.offset}, size={self.size}, lifetime=[{self.first_used_index, self.last_used_index}])"
 
+@dataclass
+class SpecAllocResult:
+    """These are the values that a memory planning algorithm assigns to a spec.
+    These are not written directly into the spec object; they are used to
+    track the allocation decisions and are assigned back to the spec object
+    at the end, based on which algorithm is picked as the best performing one.
+    """
+    mem_id: int
+    mem_obj_id: int
+    mem_offset: int
+
+@dataclass
+class MemoryAlgoResult:
+    """This is the result returned by a memory planning algorithm that is
+    invoked by memory_planning_algorithm_suite. It contains the allocation
+    decisions of that algorithm for all the specs, and the size of the buffer
+    that was used for different memory hierarchies.
+    """
+    spec_dict: Dict[TensorSpec, SpecAllocResult]
+    bufsizes: List[int]
+
 
 def materialize_buffer(
     shared_objects: List[SharedObject], input_total_size: int = 0
@@ -692,7 +715,7 @@ def greedy(
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     allow_overlapping_allocations: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
     r"""Greedy algorithm to allocate memory for tensors in the graph.
     alloc_graph_input: If set to true, the algorithm will allocate memory for graph input.
     alloc_graph_output: If set to true, the algorithm will allocate memory for graph output.
@@ -701,6 +724,7 @@ def greedy(
     This flag is added to allow for Vulkan to use MemoryPlanningPass with overlapping
     allocations disabled
     """
+    greedy_result = MemoryAlgoResult({}, [])
     # padding allocation with 64 bytes.
     # this requirement is really for XNNPACK backend which can read tensors
     # beyond the end of the tensor. This is done for performance
@@ -735,12 +759,16 @@ def greedy(
     sorted_specs.reverse()
 
     for spec in sorted_specs:
+        # Create an entry for this TensorSpec in the result object that we'll be
+        # returning from this algorithm.
+        spec_alloc_result = greedy_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        greedy_result.spec_dict[spec] = spec_alloc_result
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(
-            shared_objects[spec.mem_id], spec, allow_overlapping_allocations
-        )
+        spec2obj[spec] = pick_shared_obj(
+            shared_objects[spec_alloc_result.mem_id], spec, allow_overlapping_allocations
+        )
 
     if len(shared_objects) == 0:
         # Cannot find any tensor in the graph that needs to be allocated.
@@ -768,24 +796,73 @@ def greedy(
     for sobj in shared_objects[mem_id]:
         for alloc in sobj.allocations:
             spec = alloc.spec
-            alloc.spec.mem_obj_id = sobj.idx
-            alloc.spec.mem_offset = sobj.offset + alloc.offset
+            # Get the spec_alloc_result for this spec and update it with the
+            # mem_obj_id and mem_offset generated by this algorithm.
+            spec_alloc_result = greedy_result.spec_dict.get(spec, None)
+            assert spec_alloc_result is not None, f"Spec {spec} not found."
+            spec_alloc_result.mem_obj_id = sobj.idx
+            spec_alloc_result.mem_offset = sobj.offset + alloc.offset
             num_specs_processed += 1
     assert (
         len(spec2obj) == num_specs_processed
     ), f"All specs should be processed but there were {len(spec2obj)} specs and processed {num_specs_processed} specs"
 
     logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
-    return total_sizes
+    greedy_result.bufsizes = total_sizes
+    return greedy_result
 
+def memory_planning_algorithm_suite(
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = True,
+    alloc_graph_output: bool = True,
+    allow_overlapping_allocations: bool = True,
+    algo_list: List[Callable[..., MemoryAlgoResult]] = [greedy],
+) -> List[int]:
+    r"""
+    Memory planning algorithm suite that runs a list of memory planning algorithms
+    and returns the result of the algorithm that minimizes the total memory usage.
+    """
+    mem_algo_results = {}
+    for algo in algo_list:
+        if isinstance(algo, functools.partial):
+            name = algo.func.__name__
+        else:
+            name = getattr(algo, "__name__", None)
+        # Run this memory planning algorithm and store the result in mem_algo_results
+        # with the name of the algorithm as the key.
+        mem_algo_results[name] = algo(
+            graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
+        )
+
+    # All the algorithms should have the same number of buffers allocated.
+    assert (
+        len({len(mem_algo_result.bufsizes) for mem_algo_result in mem_algo_results.values()}) == 1
+    ), "Different memory planning algorithms should have the same number of buffers allocated."
+
+    # Find the algorithm that minimizes the total memory usage.
+    best_algo = min(mem_algo_results, key=lambda k: sum(mem_algo_results[k].bufsizes))
+    logging.debug(f"Best memory planning algo for this model is {best_algo}")
+    bufsizes = mem_algo_results[best_algo].bufsizes
+
+    # Update the mem_id and mem_offset for each spec in the graph module based on
+    # the values provided by the best memory planning algorithm.
+    for spec in mem_algo_results[best_algo].spec_dict:
+        spec_alloc_result = mem_algo_results[best_algo].spec_dict[spec]
+        spec.mem_id = spec_alloc_result.mem_id
+        spec.mem_offset = spec_alloc_result.mem_offset
+        spec.mem_obj_id = spec_alloc_result.mem_obj_id
+
+    return bufsizes
 
 def naive(
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
-) -> List[int]:
+) -> MemoryAlgoResult:
+
+    naive_result = MemoryAlgoResult({}, [])
 
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
@@ -807,16 +884,22 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
+        spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
         # assume a single memory layer which has mem_id 1
         if spec.mem_id is None:
-            spec.mem_id = 1
+            spec_alloc_result.mem_id = 1
+        else:
+            spec_alloc_result.mem_id = spec.mem_id
+        naive_result.spec_dict[spec] = spec_alloc_result
+
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
-        spec.mem_offset = _allocate_buf(bufsizes, spec.mem_id, spec.allocated_memory)
+        spec_alloc_result.mem_offset = _allocate_buf(
+            bufsizes, spec_alloc_result.mem_id, spec.allocated_memory
+        )
 
     logging.debug(f"naive algorithm returns bufsizes: {bufsizes}")
-    return bufsizes
+    naive_result.bufsizes = bufsizes
+    return naive_result
 
 
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
@@ -961,5 +1044,4 @@ def handle_submodule(
     )
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
-
     return bufsizes
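
To illustrate the contract the suite enforces, here is a hypothetical custom planner (the name padded_naive and its padding policy are invented for illustration, not part of this diff). It matches the positional signature the suite uses to invoke each algorithm, i.e. (graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output), and returns a MemoryAlgoResult; the suite only adopts its plan if its sum(bufsizes) is the minimum across algo_list:

from typing import Optional

import torch
from executorch.exir.memory_planning import MemoryAlgoResult, naive
from torch.export.exported_program import ExportGraphSignature


def padded_naive(
    graph_module: torch.fx.GraphModule,
    alignment: int,
    graph_signature: Optional[ExportGraphSignature] = None,
    alloc_graph_input: bool = True,
    alloc_graph_output: bool = True,
) -> MemoryAlgoResult:
    # Delegate spec collection and offset assignment to the existing naive
    # planner; its MemoryAlgoResult already carries a SpecAllocResult per spec.
    result = naive(
        graph_module, alignment, graph_signature, alloc_graph_input, alloc_graph_output
    )
    # Example policy (illustrative only): round every arena up to a 256-byte
    # boundary. The per-spec entries in result.spec_dict remain valid.
    result.bufsizes = [(size + 255) & ~255 for size in result.bufsizes]
    return result

Registering it is then partial(memory_planning_algorithm_suite, algo_list=[naive, padded_naive]); since padding can only grow each arena, the suite would keep plain naive here, which is exactly the min-by-sum selection rule at work.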

exir/passes/memory_planning_pass.py

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,8 @@
     apply_algo,
     get_node_tensor_specs,
     greedy,
+    memory_planning_algorithm_suite,
+    MemoryAlgoResult,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -40,7 +42,7 @@ def _callable_name(any_callable: Callable[..., Any]) -> str:
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
-        memory_planning_algo: Callable[..., List[int]] = greedy,
+        memory_planning_algo: Callable[..., List[int]] = memory_planning_algorithm_suite,
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
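
One consequence of the hunk above: since individual algorithms now return MemoryAlgoResult rather than List[int], a caller can no longer pass greedy directly as memory_planning_algo; it has to go through the suite, which is also the new default. A short sketch, assuming this commit's API:

from functools import partial

from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
from executorch.exir.passes import MemoryPlanningPass

# Old style, now wrapped: a single algorithm still works, but only via the suite.
single_algo_pass = MemoryPlanningPass(
    memory_planning_algo=partial(memory_planning_algorithm_suite, algo_list=[greedy])
)

# New default: MemoryPlanningPass() routes through the suite with algo_list=[greedy].
default_pass = MemoryPlanningPass()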

exir/tests/test_memory_planning.py

Lines changed: 10 additions & 6 deletions
@@ -8,6 +8,7 @@
 
 import itertools
 import unittest
+from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Type
 
 import executorch.exir as exir
@@ -19,6 +20,8 @@
     filter_nodes,
     get_node_tensor_specs,
     greedy,
+    memory_planning_algorithm_suite,
+    MemoryAlgoResult,
     naive,
     Verifier,
 )
@@ -234,7 +237,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:
 
 def maketest(
     module_cls: Type[torch.nn.Module],
-    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
+    criteria: Optional[List[Tuple[Callable[..., MemoryAlgoResult], bool]]] = None,
     extra_check: Optional[Callable[..., None]] = None,
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
@@ -266,13 +269,13 @@ def wrapper(self: "TestMemoryPlanning") -> None:
             .exported_program()
             .graph_module
         )
-
+        mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
         graph_module = PassManager(
             passes=[
                 SpecPropPass(),
                 ToOutVarPass(),
                 MemoryPlanningPass(
-                    algo,
+                    mem_algo,
                     alloc_graph_input=alloc_graph_input,
                     alloc_graph_output=alloc_graph_output,
                 ),
@@ -519,10 +522,11 @@ def test_multiple_pools(
             export(MultiplePoolsToyModel(), (torch.ones(1),), strict=True)
         )
 
+        mem_algo = partial(memory_planning_algorithm_suite, algo_list=[algo])
         edge_program.to_executorch(
             exir.ExecutorchBackendConfig(
                 memory_planning_pass=CustomPoolMemoryPlanningPass(
-                    memory_planning_algo=algo,
+                    memory_planning_algo=mem_algo,
                     alignment=1,
                 ),
             )
@@ -708,10 +712,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         et_program = et.executorch_program
         inputs = et_program.execution_plan[0].inputs
         self.assertNotEqual(
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[0]]
             .val.allocation_info.memory_offset_low,
-            et_program.execution_plan[0]  # pyre-ignore
+            et_program.execution_plan[0]
             .values[inputs[1]]
             .val.allocation_info.memory_offset_low,
         )
