Skip to content

Commit 745af55

Browse files
authored
Merge pull request #1647 from pytorch/aten_size_fix
feat(//core/conversion): Add support for aten::size with dynamic shaped models for Torchscript backend.
2 parents 5a45f6b + 76dc804 commit 745af55

File tree

14 files changed

+470
-209
lines changed

14 files changed

+470
-209
lines changed

core/conversion/conversion.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
6868
return {};
6969
}
7070
}
71-
auto eval = evaluators::EvalNode(n, eval_args);
71+
auto eval = evaluators::EvalNode(ctx, n, eval_args);
7272
return eval;
7373
}
7474

core/conversion/converters/impl/shuffle.cpp

+24-12
Original file line numberDiff line numberDiff line change
@@ -70,25 +70,37 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
7070
auto in = args[0].ITensorOrFreeze(ctx);
7171
auto in_shape = util::toVec(in->getDimensions());
7272
std::vector<int64_t> new_shape;
73+
nvinfer1::ITensor* shape_tensor;
7374
if (ctx->input_is_dynamic) {
74-
new_shape = util::toVec(args[1].unwrapToIntList().vec());
75-
int nbDynamicDims = 0;
76-
for (size_t i = 0; i < new_shape.size(); i++) {
77-
if (in_shape[i] == -1)
78-
nbDynamicDims++;
79-
}
80-
if (nbDynamicDims > 1) {
81-
TORCHTRT_THROW_ERROR(
82-
"Resize is currently not supported when target shape contains more than one dynamic dimension");
75+
LOG_DEBUG("Using dynamic version of reshape layer");
76+
if (args[1].isITensorList()) {
77+
LOG_DEBUG("Shape tensor is an ITensorList");
78+
auto new_shape = args[1].unwrapToITensorList();
79+
auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size());
80+
TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
81+
concat_layer->setAxis(static_cast<int32_t>(0));
82+
shape_tensor = concat_layer->getOutput(0);
83+
} else if (args[1].isIntList()) {
84+
LOG_DEBUG("Shape tensor is an IntList");
85+
auto shape_vec = args[1].unwrapToIntList().vec();
86+
shape_tensor = tensor_to_const(ctx, torch::tensor(shape_vec).to(torch::kI32));
87+
} else {
88+
LOG_ERROR(
89+
"Invalid IValue type of " << args[1].IValue()->type()
90+
<< " detected for shape tensor from node: " << *n);
8391
}
8492
} else {
8593
new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
8694
}
87-
8895
auto shuffle = ctx->net->addShuffle(*in);
89-
TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
90-
shuffle->setReshapeDimensions(util::toDims(new_shape));
9196
shuffle->setName(util::node_info(n).c_str());
97+
TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
98+
99+
if (ctx->input_is_dynamic) {
100+
shuffle->setInput(1, *shape_tensor);
101+
} else {
102+
shuffle->setReshapeDimensions(util::toDims(new_shape));
103+
}
92104

93105
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
94106
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ std::vector<std::string> getEvaluatorList() {
114114
return get_evaluator_registry().GetRegisteredEvaluatorList();
115115
}
116116

117-
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
117+
c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
118118
auto evaluator = get_evaluator_registry().GetEvaluator(n);
119-
return evaluator(n, args);
119+
return evaluator(ctx, n, args);
120120
}
121121

122122
void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {

core/conversion/evaluators/aten.cpp

+57-44
Large diffs are not rendered by default.

core/conversion/evaluators/eval_macros.h

+126-126
Large diffs are not rendered by default.

core/conversion/evaluators/eval_util.cpp

+44-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "core/conversion/evaluators/eval_util.h"
12
#include <ATen/ATen.h>
23
#include "ATen/InitialTensorOptions.h"
34
#include "ATen/core/List.h"
@@ -6,12 +7,54 @@
67
#include "ATen/core/jit_type.h"
78
#include "c10/util/irange.h"
89
#include "core/util/prelude.h"
10+
#include "torch/torch.h"
911

1012
namespace torch_tensorrt {
1113
namespace core {
1214
namespace conversion {
1315
namespace evaluators {
1416

17+
// Picks one entry out of `input_tensor` along axis 0 by adding a gather layer
// to the network owned by `ctx`.
//
// ctx          - conversion context whose network receives the new layer
// n            - TorchScript node being evaluated; only used in the error message
// input_tensor - tensor to index into
// index        - position to select along axis 0
//
// Returns output 0 of the gather layer. Fails via TORCHTRT_CHECK if the gather
// layer could not be created.
nvinfer1::ITensor* index_layer(
    ConversionCtx* ctx,
    const torch::jit::Node* n,
    nvinfer1::ITensor* input_tensor,
    int64_t index) {
  // The gather layer takes its indices as a tensor, so the scalar index is
  // wrapped in a one-element Int32 at::Tensor and handed to tensor_to_const.
  at::Tensor index_as_tensor = torch::tensor({index}).to(torch::kI32);
  auto index_itensor = converters::tensor_to_const(ctx, index_as_tensor);

  auto gather = ctx->net->addGather(*input_tensor, *index_itensor, 0);
  TORCHTRT_CHECK(gather, "Unable to create gather layer from node: " << *n);
  return gather->getOutput(0);
}
31+
32+
// Evaluates aten::size for an input whose shape is only known at runtime by
// adding a shape layer to the network instead of reading static dimensions.
//
// ctx  - conversion context whose network receives the new layers
// n    - the node being evaluated; input(0) is the tensor and, when present,
//        input(1) is the dimension to query
// args - evaluated arguments keyed by the node's input values
//
// Returns an IValue wrapping a TensorContainer that holds either the full shape
// tensor (single-input form) or the entry selected by `dim` (two-input form).
// Fails via TORCHTRT_CHECK if a layer cannot be created or if `dim` is outside
// [-rank, rank - 1].
c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
  LOG_DEBUG("Using dynamic version of aten::size evaluator");
  auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
  LOG_DEBUG("Input dimensions: " << in->getDimensions());

  auto shape_layer = ctx->net->addShape(*in);
  TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
  auto shape_1d_tensor = shape_layer->getOutput(0);

  if (n->inputs().size() != 1) {
    const auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
    const auto requested_dim = args.at(n->input(1)).unwrapToInt();
    // Handle negative axis by referring to nbDims of input Tensor
    const auto dim = requested_dim < 0 ? requested_dim + maxDim : requested_dim;
    // Reject an axis that still falls outside the tensor's rank after
    // normalisation rather than handing an invalid index to the gather layer.
    TORCHTRT_CHECK(
        dim >= 0 && dim < maxDim,
        "Dimension out of range (expected to be in range of [" << -maxDim << ", " << maxDim - 1 << "], but got "
                                                               << requested_dim << ") from node: " << *n);
    LOG_DEBUG("Dimension to select: " << dim);
    shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
  }

  LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());

  // ITensors travel through the evaluator machinery wrapped in a
  // TensorContainer custom class stored inside an IValue.
  auto tensor_holder = TensorContainer();
  tensor_holder.hold_tensor(shape_1d_tensor);
  return c10::IValue(c10::make_intrusive<TensorContainer>(tensor_holder));
}
57+
1558
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
1659
if (idx < 0) {
1760
// Handle negative indexing
@@ -128,7 +171,7 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
128171
}
129172

130173
// TODO: Conditionally enable truncation based on user setting
131-
at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU) {
174+
at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device) {
132175
// This function is basically same with the one in
133176
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h, what different here is that Int and Float
134177
// won't be upgraded to kDouble or kLong since we don't support these 2 types in conversion

core/conversion/evaluators/eval_util.h

+9
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
#pragma once
22

3+
#include "core/conversion/evaluators/evaluators.h"
34
#include "torch/csrc/jit/ir/ir.h"
45

56
namespace torch_tensorrt {
67
namespace core {
78
namespace conversion {
89
namespace evaluators {
910

11+
nvinfer1::ITensor* index_layer(
12+
ConversionCtx* ctx,
13+
const torch::jit::Node* n,
14+
nvinfer1::ITensor* input_tensor,
15+
int64_t index);
16+
17+
c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
18+
1019
c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
1120
at::Tensor createTensorFromList(
1221
const torch::jit::IValue& data,

core/conversion/evaluators/evaluators.h

+5-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
#include "torch/csrc/jit/ir/ir.h"
88

9+
#include "core/conversion/conversionctx/ConversionCtx.h"
10+
#include "core/conversion/converters/converter_util.h"
911
#include "core/conversion/tensorcontainer/TensorContainer.h"
1012
#include "core/conversion/var/Var.h"
1113

@@ -33,7 +35,8 @@ inline bool constTypesOnly(kwargs& args) {
3335
// to use the node itself to pull out arguments.
3436
// This means that you should iterate over node inputs vs. the args
3537
// when writing evaluators
36-
typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, kwargs&)> NodeEvaluator;
38+
typedef std::function<c10::optional<torch::jit::IValue>(ConversionCtx*, const torch::jit::Node*, kwargs&)>
39+
NodeEvaluator;
3740

3841
struct EvalOptions {
3942
std::set<c10::TypePtr> blacklisted_output_types;
@@ -72,7 +75,7 @@ struct EvalRegistration {
7275
: kind(_kind), evaluator(_evaluator), options(_options){};
7376
};
7477

75-
c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
78+
c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
7679
bool shouldEvalAtConversionTime(const torch::jit::Node* n);
7780
std::vector<std::string> getEvaluatorList();
7881
void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);

core/conversion/evaluators/prim.cpp

+38-21
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
#include <limits>
22

3-
#include "torch/csrc/jit/ir/ir.h"
4-
//#include "torch/csrc/jit/ir/constants.h"
53
#include "ATen/core/List.h"
64
#include "ATen/core/functional.h"
75
#include "ATen/core/ivalue.h"
86
#include "ATen/core/stack.h"
97
#include "c10/util/intrusive_ptr.h"
8+
#include "torch/csrc/jit/ir/ir.h"
109
#include "torch/torch.h"
1110

1211
#include "core/conversion/evaluators/eval_macros.h"
@@ -24,28 +23,28 @@ auto prim_registrations =
2423
RegisterNodeEvaluators()
2524
.evaluator(
2625
{torch::jit::prim::Constant,
27-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
26+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
2827
if (n->output()->type()->kind() == at::FunctionType::Kind) {
2928
return {};
3029
}
3130
return evaluators::toIValue(n->output());
3231
}})
3332
.evaluator(
3433
{torch::jit::prim::NumToTensor,
35-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
34+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
3635
return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
3736
}})
3837
.evaluator(
3938
{torch::jit::prim::ListUnpack,
40-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
39+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
4140
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
4241
const torch::jit::IValue* outputs = args.at(n->input()).IValue();
4342
auto outputVec = outputs->toList().vec();
4443
return std::move(c10::ivalue::Tuple::create(outputVec));
4544
}})
4645
.evaluator(
4746
{torch::jit::prim::ListConstruct,
48-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
47+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
4948
const auto num_inputs = n->inputs().size();
5049
if (constTypesOnly(args)) {
5150
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
@@ -89,9 +88,8 @@ auto prim_registrations =
8988
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
9089
}
9190
} else {
92-
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
93-
c10::TypePtr elementType = lt->getElementType();
94-
auto list = c10::impl::GenericList(elementType);
91+
// List would be of IValues (with ITensors embedded in them)
92+
auto list = c10::impl::GenericList(c10::AnyType::get());
9593
list.reserve(num_inputs);
9694
for (auto in : n->inputs()) {
9795
if (args.at(in).isITensor()) {
@@ -103,8 +101,27 @@ auto prim_registrations =
103101
if (args.at(in).IValue()->isNone()) {
104102
auto ival = torch::jit::IValue();
105103
list.emplace_back(std::move(ival));
104+
} else if (args.at(in).IValue()->isInt()) {
105+
auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
106+
ctx, torch::tensor({args.at(in).unwrapToInt()}).to(torch::kI32));
107+
auto tensor_holder = TensorContainer();
108+
tensor_holder.hold_tensor(itensor);
109+
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
110+
list.emplace_back(std::move(ival));
111+
} else if (args.at(in).IValue()->isDouble()) {
112+
auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
113+
ctx, torch::tensor({args.at(in).unwrapToDouble()}).to(torch::kFloat));
114+
auto tensor_holder = TensorContainer();
115+
tensor_holder.hold_tensor(itensor);
116+
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
117+
list.emplace_back(std::move(ival));
106118
} else {
107-
list.emplace_back(std::move(args.at(in).unwrapToTensor()));
119+
auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(
120+
ctx, std::move(args.at(in).unwrapToTensor()));
121+
auto tensor_holder = TensorContainer();
122+
tensor_holder.hold_tensor(itensor);
123+
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
124+
list.emplace_back(std::move(ival));
108125
}
109126
}
110127
}
@@ -113,7 +130,7 @@ auto prim_registrations =
113130
}})
114131
.evaluator(
115132
{c10::Symbol::fromQualString("prim::dtype"),
116-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
133+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
117134
auto input = args.at(n->input(0));
118135
if (input.isITensor()) {
119136
auto trt_dtype = input.ITensor()->getType();
@@ -136,7 +153,7 @@ auto prim_registrations =
136153
})})
137154
.evaluator(
138155
{c10::Symbol::fromQualString("prim::min"),
139-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
156+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
140157
if (n->inputs().size() == 1) {
141158
auto a = args.at(n->input(0)).unwrapToIntList();
142159
int64_t min = std::numeric_limits<int64_t>::max();
@@ -198,7 +215,7 @@ auto prim_registrations =
198215
})})
199216
.evaluator(
200217
{c10::Symbol::fromQualString("prim::max"),
201-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
218+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
202219
if (n->inputs().size() == 1) {
203220
auto a = args.at(n->input(0)).unwrapToIntList();
204221
int64_t max = std::numeric_limits<int64_t>::min();
@@ -260,7 +277,7 @@ auto prim_registrations =
260277
})})
261278
.evaluator(
262279
{c10::Symbol::fromQualString("prim::shape"),
263-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
280+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
264281
LOG_WARNING("There may be undefined behavior using dynamic shape and prim::shape");
265282
auto tensor_var = args.at(n->input(0));
266283
if (tensor_var.isITensor()) {
@@ -274,7 +291,7 @@ auto prim_registrations =
274291
EvalOptions().validSchemas({"prim::shape(Tensor a) -> (int[])"})})
275292
.evaluator(
276293
{torch::jit::prim::TupleConstruct,
277-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
294+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
278295
c10::IValue tuple = c10::ivalue::Tuple::create();
279296
std::vector<c10::IValue> elems;
280297
for (auto in : n->inputs()) {
@@ -292,7 +309,7 @@ auto prim_registrations =
292309
}})
293310
.evaluator(
294311
{torch::jit::prim::TupleIndex,
295-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
312+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
296313
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
297314
auto tuple = args.at(n->input(0)).IValue()->toTuple();
298315
int64_t idx = args.at(n->input(1)).IValue()->toInt();
@@ -302,24 +319,24 @@ auto prim_registrations =
302319
EvalOptions().validSchemas({"prim::TupleIndex(Any tup, int i) -> (Any)"})})
303320
.evaluator(
304321
{torch::jit::prim::TupleUnpack,
305-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
322+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
306323
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
307324
auto output = args.at(n->input()).IValue()->toTuple();
308325
return c10::optional<torch::jit::IValue>(std::move(output));
309326
}})
310327
.evaluator(
311328
{c10::Symbol::fromQualString("prim::unchecked_cast"),
312-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
329+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
313330
return *(args.at(n->input(0)).IValue());
314331
}})
315332
.evaluator(
316333
{c10::Symbol::fromQualString("prim::Uninitialized"),
317-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
334+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
318335
return c10::IValue::uninitialized();
319336
}})
320337
.evaluator(
321338
{c10::Symbol::fromQualString("prim::RaiseException"),
322-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
339+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
323340
auto exception = args.at(n->input(0)).IValue();
324341
TORCHTRT_THROW_ERROR("Error from TorchScript: " << *exception);
325342
return {};
@@ -328,4 +345,4 @@ auto prim_registrations =
328345
} // namespace evaluators
329346
} // namespace conversion
330347
} // namespace core
331-
} // namespace torch_tensorrt
348+
} // namespace torch_tensorrt

core/conversion/var/Var.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,31 @@ bool Var::isITensor() const {
146146
}
147147
}
148148

149+
bool Var::isITensorList() {
150+
// Unpack the Var as a List and check if each entry is a custom class since
151+
// ITensors are stored in CustomClassHolder
152+
auto ival_list = ptr_.ivalue->toList();
153+
for (int i = 0; i < ival_list.size(); i++) {
154+
if (!ival_list.get(i).isCustomClass()) {
155+
return false;
156+
}
157+
}
158+
return true;
159+
}
160+
161+
std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
162+
TORCHTRT_CHECK(
163+
isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
164+
TORCHTRT_CHECK(isITensorList(), "Expected IValue to be an ITensorList");
165+
auto ivalue_list = ptr_.ivalue->toList();
166+
std::vector<nvinfer1::ITensor*> outputs;
167+
for (int i = 0; i < ivalue_list.size(); i++) {
168+
auto element = ivalue_list.get(i).toCustomClass<TensorContainer>()->tensor();
169+
outputs.push_back(std::move(element));
170+
}
171+
return outputs;
172+
}
173+
149174
bool Var::isIValue() const {
150175
if (type_ == Type::kIValue) {
151176
return true;

0 commit comments

Comments
 (0)