
Commit d97f668

support for hierarchical adjacency partitioner
1 parent 727cbd2 commit d97f668

File tree

3 files changed: +1469 −0 lines changed

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
# from torch_tensorrt.dynamo.partitioning._global_partitioner import partition
import torch
import torch.nn as nn
import torch_tensorrt
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_ATEN_CONVERTERS,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.lowering import (
    get_decompositions,
    post_lowering,
    pre_export_lowering,
)
from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition
from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (
    hierarchical_partition,
)


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        return x


def main():
    # Create model
    model = SimpleModel().cuda()
    # model = models.efficientnet_b0(pretrained=True).cuda()
    model = model.eval()

    # Create example input
    example_input = torch.randn(1, 3, 224, 224).cuda()

    # Export the model and lower it to an ATen-level graph before partitioning
    exported_program = torch.export.export(model, (example_input,))
    exported_program = pre_export_lowering(exported_program)
    exported_program = exported_program.run_decompositions(get_decompositions())

    gm = exported_program.module()

    print(gm.graph)

    # Partition the model using the adjacency partitioner
    # partitioned_model, op_support = partition(
    #     gm,
    #     verbose=True,
    #     min_block_size=1,
    #     torch_executed_ops=[
    #         torch.ops.aten.relu.default,
    #     ],
    # )

    # Partition the model using the hierarchical partitioner, which dispatches
    # ops to multiple backends according to backend_priority and
    # backend_support_map; ops in torch_executed_ops stay in eager PyTorch
    partitioned_model, op_support = hierarchical_partition(
        gm,
        verbose=True,
        min_block_size=1,
        backend_priority=["mlir", "tensorrt"],  # , "inductor"],
        backend_support_map={
            "mlir": {
                # operator.getitem,
                torch.ops.aten.conv2d.default,
                torch.ops.aten.convolution.default,
            },
            "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
            # "inductor": {
            #     torch.ops.aten.relu.default,
            # },
        },
        torch_executed_ops=[
            torch.ops.aten._native_batch_norm_legit_no_training.default
        ],
        require_full_compilation=False,
        skip_fusion=False,
    )

    print("\nPartitioned Model Structure:")
    print(partitioned_model)

    with torch.no_grad():
        output = partitioned_model(example_input)
        print("Partitioned output:", output)
        print(
            "Partitioned output == original output:",
            torch.allclose(model(example_input), output, rtol=1e-2, atol=1e-2),
        )


if __name__ == "__main__":
    main()
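
The example passes backend_priority, backend_support_map, and torch_executed_ops to hierarchical_partition, which suggests a priority-resolution rule: each ATen op is claimed by the earliest backend in backend_priority whose support set contains it, and pinned or unsupported ops fall back to eager PyTorch. The assign_backend helper below is a hypothetical sketch of that rule, written only to illustrate the semantics implied by the argument names; it is not the torch_tensorrt implementation.

    # Illustrative only: a hypothetical resolution rule matching the arguments
    # used above. Not the actual torch_tensorrt partitioner logic.
    from typing import Any, Dict, List, Set


    def assign_backend(
        op: Any,
        backend_priority: List[str],
        backend_support_map: Dict[str, Set[Any]],
        torch_executed_ops: Set[Any],
    ) -> str:
        """Return the name of the backend assumed to run `op`."""
        if op in torch_executed_ops:
            return "torch"  # explicitly pinned to eager execution
        for backend in backend_priority:  # earlier entries win
            if op in backend_support_map.get(backend, set()):
                return backend
        return "torch"  # supported by no backend -> eager fallback

Under this rule, the conv ops in the example would land on "mlir" (listed first and present in its support set), batch norm would stay in eager PyTorch via torch_executed_ops, and relu would go to "tensorrt".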
