Skip to content

feat: refactoring segmentation in partitioning #1067

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
16 changes: 10 additions & 6 deletions core/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ void AddIfBlockToGraph(

auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, new_g, block_graph_to_new_g); };
new_if_block->cloneFrom(cur_block_graph->block(), env);
if (cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
if (cur_block_graph->inputs().size() &&
cur_block_graph->inputs()[0]->type()->str().find("__torch__") != std::string::npos) {
if (new_g->inputs()[0]->type()->str().find("__torch__") == std::string::npos) {
auto self = new_g->insertInput(0, "self_1");
self->setType(cur_block_graph->inputs()[0]->type());
Expand All @@ -223,13 +224,14 @@ GraphAndMapping ConstructFallbackGraph(
torch::jit::Block* block,
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> example_tensor_map,
CompileSpec cfg,
ir::StaticParams static_params) {
ir::StaticParams static_params,
std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
auto convert_cfg = cfg.convert_info;
auto partition_info = cfg.partition_info;

auto new_g = std::make_shared<torch::jit::Graph>();

auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info);
auto segmented_blocks = partitioning::Partition(block, example_tensor_map, partition_info, fallback_nodes);

// the mapping from lowering graph => fallback global graph
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
Expand Down Expand Up @@ -270,7 +272,7 @@ GraphAndMapping ConstructFallbackGraph(
std::vector<GraphAndMapping> graph_and_mappings;
for (auto cur_block : if_node->blocks()) {
graph_and_mappings.push_back(
ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params));
ConstructFallbackGraph(new_mod, cur_block, example_tensor_map, cfg, static_params, fallback_nodes));
}
AddIfBlockToGraph(new_g, if_node, graph_and_mappings, old_to_new_g);

Expand All @@ -293,7 +295,7 @@ GraphAndMapping ConstructFallbackGraph(
// Set the output as the produced tuple
new_g->registerOutput(return_tuple_node->outputs()[0]);
} else {
if (old_to_new_g.count(block->outputs()[0])) {
if (block->outputs().size() && old_to_new_g.count(block->outputs()[0])) {
new_g->registerOutput(old_to_new_g[block->outputs()[0]]);
}
}
Expand Down Expand Up @@ -430,7 +432,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
std::unordered_map<torch::jit::Node*, int> fallback_nodes;
auto graph_and_mapping =
ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
new_g = graph_and_mapping.first;
LOG_INFO("Segmented Graph: " << *new_g);

Expand Down
Loading