
Collection Support [Inprogress] #802


Closed
wants to merge 22 commits
Changes from all commits
Commits (22)
4434553
feat: [collection] make torch_tensorrt::core::ir::Input and torch_ten…
inocsin Jan 13, 2022
2fc1363
feat: [collection] try to defer determing the data type of tuple/list…
inocsin Feb 17, 2022
0072e37
feat: [collection] limited support for tuple input
inocsin Mar 10, 2022
b1d66cb
fix: [collection] test normal input, fix bug
inocsin Mar 10, 2022
d4e54f1
feat: [collection] support list input type
inocsin Mar 10, 2022
a9aa2e7
feat: [collection] support user defined input data type
inocsin Mar 16, 2022
5830cbe
feat: [collection] support output type of list and tuple
inocsin Mar 17, 2022
6733cfb
feat: [collection] add unit test for complex collection model
inocsin Mar 17, 2022
d21b0ab
chore: [collection] delete comments
inocsin Mar 17, 2022
eada66d
chore: [collection] update code and comments
inocsin Mar 31, 2022
633c00f
chore: [collection] rename ConversionInfo.collection_inputs to Conver…
inocsin Mar 31, 2022
89665c8
refactor: [collection] fuse Input with GraphInputs
inocsin Mar 31, 2022
205452e
feat: [collection] move collection test model to hub.py
inocsin Mar 31, 2022
a4d4131
test: [collection] update model path in test_collection.cpp
inocsin Mar 31, 2022
2d585e5
fix: [collection] solve confict in ir.cpp
inocsin Apr 5, 2022
5f36810
feat: [collection] update python api, refactor code
inocsin Apr 6, 2022
d9d8665
fix: [collection] remove aten::__getitem__ and prim::ListConstruct
inocsin Apr 8, 2022
991f023
[collection] rebase to master, update some api
inocsin Apr 12, 2022
016c991
feat: [collection] handle prim::ListConstruct without fallback it man…
inocsin Apr 14, 2022
2e7cd58
chore: [collection] update test_resolve_nontensor_inputs.cpp
inocsin Apr 14, 2022
b35cdd0
fix: [collection] handle the case that only the output is collection …
inocsin Apr 14, 2022
fa6c10e
fix: [collection] update tests/cpp/test_example_tensors.cpp
inocsin Apr 19, 2022
119 changes: 70 additions & 49 deletions core/compiler.cpp
@@ -254,6 +254,7 @@ GraphAndMapping ConstructFallbackGraph(
// update the input ranges for each segments
convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);

// TODO mapping Inputs Ivalue to flatten one here
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
auto temp_g = std::make_shared<torch::jit::Graph>();
auto device_spec = convert_cfg.engine_settings.device;
@@ -304,57 +305,72 @@ void MapInputsAndDetermineDTypes(
CompileSpec& cfg,
std::shared_ptr<torch::jit::Graph>& g,
ir::StaticParams& static_params,
ir::TypeMap& first_use_type_map) {
// Associate input specs with inputs
cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));

for (auto& in : g->inputs()) {
if (static_params.find(in) == static_params.end()) {
ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
auto est_type_opt = first_use_type_map.find(in)->second;
if (est_type_opt && !spec.dtype_is_user_defined) {
// If we can calculate the type from the graph and the type was not defined by the user then use the calculated
// type
LOG_INFO(
"Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
<< in->debugName() << " has type " << est_type_opt.value()
<< ". If this is incorrect explicitly set dtype for input and file a bug");
spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
} else if (!est_type_opt && !spec.dtype_is_user_defined) {
// If we cannot calculate the type and the user did not define the type, then default to FP32
LOG_WARNING(
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec.dtype = nvinfer1::DataType::kFLOAT;
} else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
if (!est_type_opt) {
LOG_INFO("Cannot infer input tensor dtype in graph. Using user provided input dtype settings");
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
} else {
if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
ir::CollectionTypeMap& first_use_type_map) {
cfg.convert_info.collection_input_spec_map = std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));

auto collection_inputs = ir::get_collection_inputs(g, static_params);
LOG_DEBUG("In MapInputsAndDetermineDTypes, the g->inputs() size is " << g->inputs().size() << ", CollectionInputSpecMap size is" << collection_inputs.size());

for (auto in : collection_inputs) {
std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
std::vector<c10::optional<at::ScalarType>> est_type_opt;

auto est_it = first_use_type_map.find(in);
if (est_it != first_use_type_map.end()) {
est_type_opt = first_use_type_map.find(in)->second;
}
// traverse elements in est_type_out and spec
for (int i = 0; i < est_type_opt.size(); i++) {
if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
// If we can calculate the type from the graph and the type was not defined by the user then use the calculated
// type
LOG_INFO(
"Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
<< in->debugName() << " has type " << est_type_opt[i].value());
spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
} else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
// If we cannot calculate the type and the user did not define the type, then default to FP32
LOG_WARNING(
"Cannot infer input type from calcuations in graph for input "
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
spec[i].dtype = nvinfer1::DataType::kFLOAT;
} else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
if (!est_type_opt[i]) {
LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.inputs.find(in)->second.dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt.value() << std::endl;
ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ". The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};

} else {
if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) != est_type_opt[i].value()) {
std::stringstream ss;
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
ss << est_type_opt[i].value() << std::endl;
ss << "The compiler is going to use the user setting " << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
ss << "compatibility with PyTorch's data type convention is required.\n";
ss << "If you do indeed see errors at runtime either:\n";
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
ss << "- Disable partial compilation by setting require_full_compilation to True";
auto warn_str = ss.str();
LOG_WARNING(warn_str);
// Overwrite type map with user settings
first_use_type_map[in][i] = {util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
}
}
// Overwrite type map with user settings
// We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
} else {
// The user defined the type so no changes are necessary
}
} else {
// The user defined the type so no changes are necessary
}
}
}
// }
}

uint64_t GetRecommendedWorkspaceSize(const runtime::CudaDevice& device) {
@@ -376,7 +392,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
// auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

// GPU default WS size : 1 GB
// Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
@@ -416,21 +433,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
auto params = graph_and_parameters.second;
auto static_params = ir::get_static_params(g->inputs(), params);
// Infer the type of an input from the weights of the calculation
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());

MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
auto outputIsCollection = conversion::OutputIsCollection(g->block());
if (cfg.partition_info.enabled &&
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
LOG_INFO("Skipping partitioning since model is fully supported");
}

if (cfg.partition_info.enabled &&
!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
(!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)
|| outputIsCollection)) {

auto collection_input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), collection_input_ivalues_map, cfg, static_params);
new_g = graph_and_mapping.first;
LOG_INFO("Segmented Graph: " << *new_g);

@@ -444,6 +464,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
TORCHTRT_CHECK(
conversion::VerifyConverterSupportForBlock(g->block()),
"Not all operations in graph are supported by the compiler");
// TODO find the right
auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
AddEngineToGraph(new_mod, new_g, engine, cuda_device);
}
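
Note on the dtype handling above: with collection support, each graph input now carries a vector of specs and a vector of optional inferred types, and the per-element decision is the same three-way rule as before. The following is a self-contained sketch of that rule using simplified stand-in types (DType, Spec, resolve are illustrative names, not the torch_tensorrt or nvinfer1 API).

#include <iostream>
#include <optional>
#include <vector>

// Stand-ins for nvinfer1::DataType and ir::Input, for illustration only.
enum class DType { kFloat, kHalf, kInt };

struct Spec {
  DType dtype = DType::kFloat;
  bool dtype_is_user_defined = false;
};

// Resolve the dtype for one flattened element of a collection input.
DType resolve(const Spec& spec, std::optional<DType> inferred, bool partitioning_enabled) {
  if (inferred && !spec.dtype_is_user_defined) {
    return *inferred;  // type calculated from the graph wins when the user gave none
  }
  if (!inferred && !spec.dtype_is_user_defined) {
    return DType::kFloat;  // nothing known: default to FP32
  }
  // User-defined dtype wins. With partitioning enabled, MapInputsAndDetermineDTypes
  // also overwrites first_use_type_map so the Torch segments see the same dtype.
  (void)partitioning_enabled;
  return spec.dtype;
}

int main() {
  // Two elements of one collection input: the first has an inferred type and no
  // user setting, the second has a user setting and no inferred type.
  std::vector<std::optional<DType>> inferred = {DType::kHalf, std::nullopt};
  std::vector<Spec> specs = {{DType::kFloat, false}, {DType::kInt, true}};
  for (size_t i = 0; i < specs.size(); i++) {
    std::cout << "element " << i << " -> dtype "
              << static_cast<int>(resolve(specs[i], inferred[i], true)) << "\n";
  }
  return 0;
}
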
6 changes: 4 additions & 2 deletions core/compiler.h
@@ -8,13 +8,15 @@
#include "core/partitioning/partitioning.h"
#include "core/runtime/runtime.h"
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/ir/ir.h"

namespace torch_tensorrt {
namespace core {

struct CompileSpec {
CompileSpec(std::vector<ir::Input> inputs) : inputs(inputs) {}
std::vector<ir::Input> inputs;
CompileSpec(std::vector<ir::Input> inputs) : graph_inputs(inputs) {}
CompileSpec(torch::jit::IValue& input_signature) : graph_inputs(input_signature) {}
ir::GraphInputs graph_inputs;
conversion::ConversionInfo convert_info;
lowering::LowerInfo lower_info;
partitioning::PartitionInfo partition_info;
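
A minimal usage sketch for the two CompileSpec constructors declared above. The shapes and helper functions are illustrative only; the static-shape Input(std::vector<int64_t>) constructor from core/ir is assumed, and the IValue signature is assumed to be built by the front-end as a nested tuple/list of ir::Input custom-class objects, which GraphInputs then flattens.

#include <vector>
#include "core/compiler.h"

// Legacy path: a flat vector of Input specs, one per tensor argument.
torch_tensorrt::core::CompileSpec flat_spec_example() {
  std::vector<torch_tensorrt::core::ir::Input> inputs = {
      torch_tensorrt::core::ir::Input({1, 3, 224, 224}),
      torch_tensorrt::core::ir::Input({1, 3, 224, 224})};
  return torch_tensorrt::core::CompileSpec(inputs);
}

// Collection path: an IValue mirroring the module's nested input signature,
// e.g. ((Input, Input), Input). GraphInputs flattens it internally.
torch_tensorrt::core::CompileSpec nested_spec_example(torch::jit::IValue& input_signature) {
  return torch_tensorrt::core::CompileSpec(input_signature);
}
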
27 changes: 23 additions & 4 deletions core/conversion/conversion.cpp
@@ -134,7 +134,10 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
void AddInputs(
ConversionCtx* ctx,
c10::ArrayRef<const torch::jit::Value*> inputs,
std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs) {
ConversionInfo& conversion_info) {
std::unordered_map<const torch::jit::Value*, ir::Input>& input_specs = conversion_info.inputs;
std::unordered_map<const torch::jit::Value*, std::vector<ir::Input>> collection_input_spec = conversion_info.collection_input_spec_map;

std::vector<const torch::jit::Value*> input_tensors;
for (auto in : inputs) {
// Disregarding inputs that are not tensors
@@ -162,9 +165,15 @@ void AddInputs(
for (auto input : input_tensors) {
const torch::jit::Value* in = input;
TORCHTRT_CHECK(
input_specs.find(in) != input_specs.end(),
input_specs.find(in) != input_specs.end() || collection_input_spec.find(in) != collection_input_spec.end(),
"Cannot find an input spec associated with input: " << in->debugName());
ir::Input& spec = input_specs.find(in)->second;
ir::Input spec;
if (input_specs.find(in) != input_specs.end()) {
spec = input_specs.find(in)->second;
} else {
spec = collection_input_spec.find(in)->second[0]; // assume input is tensor
}
// ir::Input& spec = input_specs.find(in)->second;

std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
LOG_INFO(
@@ -184,6 +193,7 @@ void AddInputs(
ctx->input_is_dynamic = true;
}

// mapping torch Value to tensorrt iTensor
ctx->value_tensor_map[in] = trt_in;
ctx->num_inputs += 1;
}
@@ -404,7 +414,7 @@ void ConvertBlockToNetDef(

auto inputs = b->inputs();
AddParamsToCtxValueMap(ctx, static_params);
AddInputs(ctx, inputs, build_info.inputs);
AddInputs(ctx, inputs, build_info);

auto nodes = b->nodes();

@@ -545,6 +555,15 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
return convertable_ops;
}

bool OutputIsCollection(const torch::jit::Block* b) {
for (auto out: b->outputs()) {
if(out->type()->kind() == torch::jit::TypeKind::TupleType || out->type()->kind() == torch::jit::TypeKind::ListType) {
return true;
}
}
return false;
}

bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
auto unsupported_ops = GetUnsupportedOpsInBlock(b);
if (unsupported_ops.size() != 0) {
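
For reference, a small standalone check (a hypothetical test snippet using the standard torch::jit IR parser, not code from this PR) of the condition OutputIsCollection looks for: a block whose output is a TupleType or ListType, which now routes the model through partitioning even when every op is convertible.

#include <iostream>
#include <memory>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x : Tensor, %y : Tensor):
      %out : (Tensor, Tensor) = prim::TupleConstruct(%x, %y)
      return (%out))IR",
      g.get());

  // Same check as OutputIsCollection, applied to the graph's top-level block.
  bool is_collection = false;
  for (auto out : g->block()->outputs()) {
    auto kind = out->type()->kind();
    if (kind == torch::jit::TypeKind::TupleType || kind == torch::jit::TypeKind::ListType) {
      is_collection = true;
    }
  }
  std::cout << std::boolalpha << is_collection << "\n";  // true
  return 0;
}
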
3 changes: 3 additions & 0 deletions core/conversion/conversion.h
@@ -13,6 +13,7 @@ namespace conversion {

struct ConversionInfo {
ir::InputSpecMap inputs;
ir::CollectionInputSpecMap collection_input_spec_map;
BuilderSettings engine_settings;
};

@@ -25,6 +26,8 @@ std::string ConvertBlockToEngine(

bool OpSupported(const torch::jit::Node* n);

bool OutputIsCollection(const torch::jit::Block* b);

bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);

c10::optional<torch::jit::IValue> EvaluateNode(
15 changes: 0 additions & 15 deletions core/conversion/evaluators/aten.cpp
@@ -264,21 +264,6 @@ auto aten_registrations TORCHTRT_UNUSED =
},
EvalOptions().validSchemas(
{"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})})
.evaluator({c10::Symbol::fromQualString("aten::__getitem__"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
auto idx = args.at(n->input(1)).unwrapToInt();

const int64_t list_size = list.size();
const int64_t normalized_idx = normalizeIndex(idx, list_size);
TORCHTRT_CHECK(
normalized_idx >= 0 || normalized_idx < list_size,
"List index out of range (aten::__getitem__)");
return list.get(normalized_idx);
},
EvalOptions().validSchemas({
"aten::__getitem__.t(t[](a) list, int idx) -> (t(*))",
})})
.evaluator({c10::Symbol::fromQualString("aten::append"),
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
3 changes: 2 additions & 1 deletion core/ir/BUILD
@@ -15,7 +15,8 @@ cc_library(
srcs = [
"ir.cpp",
"Input.cpp",
"StaticParams.cpp"
"StaticParams.cpp",
"GraphInputs.cpp"
],
deps = [
"@tensorrt//:nvinfer",
75 changes: 75 additions & 0 deletions core/ir/GraphInputs.cpp
@@ -0,0 +1,75 @@
#include "core/ir/ir.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace ir {

void flatten_dfs(std::vector<torch_tensorrt::core::ir::Input>& flattened_inputs, std::vector<std::vector<torch_tensorrt::core::ir::Input>>& collection_inputs,
torch::jit::IValue input_ivalue, int level, int index) {
if (input_ivalue.isTuple()) {
auto input_tuple = input_ivalue.toTuple();
int idx = 0;
if (level == 0) {
collection_inputs.resize(input_tuple->elements().size());
}
for (auto item: input_tuple->elements()) {
torch::jit::IValue converted_item;
int cur_idx = level < 1 ? idx: index;
flatten_dfs(flattened_inputs, collection_inputs, item, level+1, cur_idx);
idx++;
}
} else if(input_ivalue.isList()) {
auto input_list = input_ivalue.toList().vec();
if (level == 0) {
collection_inputs.resize(input_list.size());
}
c10::TypePtr type = input_list[0].type();
auto converted_elements = c10::impl::GenericList(type);
int idx = 0;
for (auto item: input_list) {
int cur_idx = level < 1 ? idx: index;
flatten_dfs(flattened_inputs, collection_inputs, item, level+1, cur_idx);
idx++;
}
} else if(input_ivalue.isCustomClass()) {
torch_tensorrt::core::ir::Input cur_input = *(input_ivalue.toCustomClass<torch_tensorrt::core::ir::Input>());
flattened_inputs.push_back(cur_input);
if (level == 0) { // a single value like A
collection_inputs.resize(1);
collection_inputs[0].push_back(cur_input);
} else if (level == 1) { // like A in [A, A] or [(B, B), A]
collection_inputs[index].push_back(cur_input);
} else if (level == 2) { // like A in [(A, A), C]
collection_inputs[index].push_back(cur_input);
} else {// only support 2 level
LOG_ERROR("Input nesting depth exceeds currently supported depth (3), use 1 level: [A, B], or 2 level: [A, (B, C)]");
}
}
}


GraphInputs::GraphInputs(std::vector<ir::Input> inputs_) {
LOG_DEBUG("Construct GraphInput with ir::Input");
inputs = inputs_;
collection_inputs.resize(inputs_.size());
for (int i = 0; i < inputs_.size(); i++) {
collection_inputs[i].push_back(inputs_[i]);
}
}

GraphInputs::GraphInputs(torch::jit::IValue& input_signature_) {
LOG_DEBUG("Construct GraphInput with IValue");

std::vector<torch_tensorrt::core::ir::Input> flattened_inputs;
std::vector<std::vector<torch_tensorrt::core::ir::Input>> collection_inputs_;

flatten_dfs(flattened_inputs, collection_inputs_, input_signature_, 0, 0);
inputs = flattened_inputs;
input_signature = input_signature_;
collection_inputs = collection_inputs_;
}

} // namespace ir
} // namespace core
} // namespace torch_tensorrt
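
A self-contained illustration of the flattening performed by flatten_dfs above, with plain strings standing in for ir::Input / IValue (Node, leaf, group are hypothetical names). For a signature like ((A, B), C) it produces flattened = [A, B, C] and collection = [[A, B], [C]], one inner vector per top-level argument; nesting deeper than two levels is rejected by the real implementation.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Node;
using NodePtr = std::shared_ptr<Node>;

// Either a leaf "input" (name set, no children) or a group (tuple/list).
struct Node {
  std::string name;
  std::vector<NodePtr> children;
  bool is_leaf() const { return children.empty(); }
};

void flatten_dfs(std::vector<std::string>& flattened,
                 std::vector<std::vector<std::string>>& collection,
                 const NodePtr& n, int level, int index) {
  if (!n->is_leaf()) {
    if (level == 0) {
      collection.resize(n->children.size());  // one slot per top-level element
    }
    int idx = 0;
    for (const auto& child : n->children) {
      // below the top level, keep writing into the same top-level slot
      flatten_dfs(flattened, collection, child, level + 1, level < 1 ? idx : index);
      idx++;
    }
  } else {
    flattened.push_back(n->name);
    if (level == 0) {
      collection.resize(1);  // a single, non-nested input
    }
    collection[level == 0 ? 0 : index].push_back(n->name);
  }
}

int main() {
  auto leaf = [](std::string s) { return std::make_shared<Node>(Node{std::move(s), {}}); };
  auto group = [](std::vector<NodePtr> c) { return std::make_shared<Node>(Node{"", std::move(c)}); };

  // signature ((A, B), C)
  auto sig = group({group({leaf("A"), leaf("B")}), leaf("C")});

  std::vector<std::string> flattened;
  std::vector<std::vector<std::string>> collection;
  flatten_dfs(flattened, collection, sig, 0, 0);

  for (const auto& s : flattened) std::cout << s << " ";  // A B C
  std::cout << "\n";
  for (const auto& grp : collection) {                    // { A B } { C }
    std::cout << "{ ";
    for (const auto& s : grp) std::cout << s << " ";
    std::cout << "} ";
  }
  std::cout << "\n";
  return 0;
}
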
5 changes: 4 additions & 1 deletion core/ir/StaticParams.cpp
@@ -11,7 +11,10 @@ StaticParams get_static_params(c10::ArrayRef<torch::jit::Value*> inputs, std::ve
StaticParams static_params;
auto param_it = params.begin();
for (auto in : inputs) {
if (in->type() != c10::TensorType::get() && param_it != params.end()) {
// handle TensorType, TupleType and ListType
if (in->type() != c10::TensorType::get() &&
in->type()->kind() != torch::jit::TypeKind::TupleType &&
in->type()->kind() != torch::jit::TypeKind::ListType && param_it != params.end()) {
static_params[in] = *param_it;
++param_it;
}