|
| 1 | +# from torch_tensorrt.dynamo.partitioning._global_partitioner import partition |
| 2 | +import torch |
| 3 | +import torch.nn as nn |
| 4 | +import torch_tensorrt |
| 5 | +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( |
| 6 | + DYNAMO_ATEN_CONVERTERS, |
| 7 | +) |
| 8 | +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( |
| 9 | + DYNAMO_CONVERTERS as CONVERTERS, |
| 10 | +) |
| 11 | +from torch_tensorrt.dynamo.lowering import ( |
| 12 | + get_decompositions, |
| 13 | + post_lowering, |
| 14 | + pre_export_lowering, |
| 15 | +) |
| 16 | +from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition |
| 17 | +from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import ( |
| 18 | + hierarchical_partition, |
| 19 | +) |
| 20 | + |
| 21 | + |
| 22 | +class SimpleModel(nn.Module): |
| 23 | + def __init__(self): |
| 24 | + super().__init__() |
| 25 | + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) |
| 26 | + self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) |
| 27 | + self.bn1 = nn.BatchNorm2d(64) |
| 28 | + self.bn2 = nn.BatchNorm2d(128) |
| 29 | + |
| 30 | + def forward(self, x): |
| 31 | + x = self.conv1(x) |
| 32 | + x = self.bn1(x) |
| 33 | + x = torch.relu(x) |
| 34 | + x = self.conv2(x) |
| 35 | + x = self.bn2(x) |
| 36 | + x = torch.relu(x) |
| 37 | + return x |
| 38 | + |
| 39 | + |
| 40 | +def main(): |
| 41 | + # Create model |
| 42 | + model = SimpleModel().cuda() |
| 43 | + # model = models.efficientnet_b0(pretrained=True).cuda() |
| 44 | + model = model.eval() |
| 45 | + |
| 46 | + # Create example input |
| 47 | + example_input = torch.randn(1, 3, 224, 224).cuda() |
| 48 | + |
| 49 | + exported_program = torch.export.export(model, (example_input,)) |
| 50 | + exported_program = pre_export_lowering(exported_program) |
| 51 | + exported_program = exported_program.run_decompositions(get_decompositions()) |
| 52 | + |
| 53 | + gm = exported_program.module() |
| 54 | + |
| 55 | + print(gm.graph) |
| 56 | + |
| 57 | + # Partition the model using the adjacency partitioner |
| 58 | + # partitioned_model, op_support = partition( |
| 59 | + # gm, |
| 60 | + # verbose=True, |
| 61 | + # min_block_size=1, |
| 62 | + # torch_executed_ops=[ |
| 63 | + # torch.ops.aten.relu.default, |
| 64 | + # ], |
| 65 | + # ) |
| 66 | + |
| 67 | + partitioned_model, op_support = hierarchical_partition( |
| 68 | + gm, |
| 69 | + verbose=True, |
| 70 | + min_block_size=1, |
| 71 | + backend_priority=["mlir", "tensorrt"], # , "inductor"], |
| 72 | + backend_support_map={ |
| 73 | + "mlir": { |
| 74 | + # operator.getitem, |
| 75 | + torch.ops.aten.conv2d.default, |
| 76 | + torch.ops.aten.convolution.default, |
| 77 | + }, |
| 78 | + "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()), |
| 79 | + # "inductor": { |
| 80 | + # torch.ops.aten.relu.default, |
| 81 | + # }, |
| 82 | + }, |
| 83 | + torch_executed_ops=[ |
| 84 | + torch.ops.aten._native_batch_norm_legit_no_training.default |
| 85 | + ], |
| 86 | + require_full_compilation=False, |
| 87 | + skip_fusion=False, |
| 88 | + ) |
| 89 | + |
| 90 | + print("\nPartitioned Model Structure:") |
| 91 | + print(partitioned_model) |
| 92 | + |
| 93 | + with torch.no_grad(): |
| 94 | + output = partitioned_model(example_input) |
| 95 | + print("Partitioned output:", output) |
| 96 | + print( |
| 97 | + "Partitioned output == original output:", |
| 98 | + torch.allclose(model(example_input), output, 1e-2, 1e-2), |
| 99 | + ) |
| 100 | + |
| 101 | + |
| 102 | +if __name__ == "__main__": |
| 103 | + main() |
0 commit comments