1
1
#include < limits>
2
2
3
- #include " torch/csrc/jit/ir/ir.h"
4
- // #include "torch/csrc/jit/ir/constants.h"
5
3
#include " ATen/core/List.h"
6
4
#include " ATen/core/functional.h"
7
5
#include " ATen/core/ivalue.h"
8
6
#include " ATen/core/stack.h"
9
7
#include " c10/util/intrusive_ptr.h"
8
+ #include " torch/csrc/jit/ir/ir.h"
10
9
#include " torch/torch.h"
11
10
12
11
#include " core/conversion/evaluators/eval_macros.h"
@@ -24,28 +23,28 @@ auto prim_registrations =
24
23
RegisterNodeEvaluators ()
25
24
.evaluator(
26
25
{torch::jit::prim::Constant,
27
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
26
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
28
27
if (n->output ()->type ()->kind () == at::FunctionType::Kind) {
29
28
return {};
30
29
}
31
30
return evaluators::toIValue (n->output ());
32
31
}})
33
32
.evaluator(
34
33
{torch::jit::prim::NumToTensor,
35
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
34
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
36
35
return evaluators::scalar_to_tensor (args.at (n->input (0 )).IValue ()->toScalar ());
37
36
}})
38
37
.evaluator(
39
38
{torch::jit::prim::ListUnpack,
40
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
39
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
41
40
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
42
41
const torch::jit::IValue* outputs = args.at (n->input ()).IValue ();
43
42
auto outputVec = outputs->toList ().vec ();
44
43
return std::move (c10::ivalue::Tuple::create (outputVec));
45
44
}})
46
45
.evaluator(
47
46
{torch::jit::prim::ListConstruct,
48
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
47
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
49
48
const auto num_inputs = n->inputs ().size ();
50
49
if (constTypesOnly (args)) {
51
50
c10::ListTypePtr lt = n->output ()->type ()->expect <c10::ListType>();
@@ -89,9 +88,8 @@ auto prim_registrations =
89
88
return c10::optional<torch::jit::IValue>(std::move (torch::jit::IValue (list)));
90
89
}
91
90
} else {
92
- c10::ListTypePtr lt = n->output ()->type ()->expect <c10::ListType>();
93
- c10::TypePtr elementType = lt->getElementType ();
94
- auto list = c10::impl::GenericList (elementType);
91
+ // List would be of IValues (with ITensors embedded in them)
92
+ auto list = c10::impl::GenericList (c10::AnyType::get ());
95
93
list.reserve (num_inputs);
96
94
for (auto in : n->inputs ()) {
97
95
if (args.at (in).isITensor ()) {
@@ -103,8 +101,27 @@ auto prim_registrations =
103
101
if (args.at (in).IValue ()->isNone ()) {
104
102
auto ival = torch::jit::IValue ();
105
103
list.emplace_back (std::move (ival));
104
+ } else if (args.at (in).IValue ()->isInt ()) {
105
+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const (
106
+ ctx, torch::tensor ({args.at (in).unwrapToInt ()}).to (torch::kI32 ));
107
+ auto tensor_holder = TensorContainer ();
108
+ tensor_holder.hold_tensor (itensor);
109
+ auto ival = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
110
+ list.emplace_back (std::move (ival));
111
+ } else if (args.at (in).IValue ()->isDouble ()) {
112
+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const (
113
+ ctx, torch::tensor ({args.at (in).unwrapToDouble ()}).to (torch::kFloat ));
114
+ auto tensor_holder = TensorContainer ();
115
+ tensor_holder.hold_tensor (itensor);
116
+ auto ival = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
117
+ list.emplace_back (std::move (ival));
106
118
} else {
107
- list.emplace_back (std::move (args.at (in).unwrapToTensor ()));
119
+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const (
120
+ ctx, std::move (args.at (in).unwrapToTensor ()));
121
+ auto tensor_holder = TensorContainer ();
122
+ tensor_holder.hold_tensor (itensor);
123
+ auto ival = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
124
+ list.emplace_back (std::move (ival));
108
125
}
109
126
}
110
127
}
@@ -113,7 +130,7 @@ auto prim_registrations =
113
130
}})
114
131
.evaluator(
115
132
{c10::Symbol::fromQualString (" prim::dtype" ),
116
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
133
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
117
134
auto input = args.at (n->input (0 ));
118
135
if (input.isITensor ()) {
119
136
auto trt_dtype = input.ITensor ()->getType ();
@@ -136,7 +153,7 @@ auto prim_registrations =
136
153
})})
137
154
.evaluator(
138
155
{c10::Symbol::fromQualString (" prim::min" ),
139
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
156
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
140
157
if (n->inputs ().size () == 1 ) {
141
158
auto a = args.at (n->input (0 )).unwrapToIntList ();
142
159
int64_t min = std::numeric_limits<int64_t >::max ();
@@ -198,7 +215,7 @@ auto prim_registrations =
198
215
})})
199
216
.evaluator(
200
217
{c10::Symbol::fromQualString (" prim::max" ),
201
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
218
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
202
219
if (n->inputs ().size () == 1 ) {
203
220
auto a = args.at (n->input (0 )).unwrapToIntList ();
204
221
int64_t max = std::numeric_limits<int64_t >::min ();
@@ -260,7 +277,7 @@ auto prim_registrations =
260
277
})})
261
278
.evaluator(
262
279
{c10::Symbol::fromQualString (" prim::shape" ),
263
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
280
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
264
281
LOG_WARNING (" There may be undefined behavior using dynamic shape and prim::shape" );
265
282
auto tensor_var = args.at (n->input (0 ));
266
283
if (tensor_var.isITensor ()) {
@@ -274,7 +291,7 @@ auto prim_registrations =
274
291
EvalOptions ().validSchemas ({" prim::shape(Tensor a) -> (int[])" })})
275
292
.evaluator(
276
293
{torch::jit::prim::TupleConstruct,
277
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
294
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
278
295
c10::IValue tuple = c10::ivalue::Tuple::create ();
279
296
std::vector<c10::IValue> elems;
280
297
for (auto in : n->inputs ()) {
@@ -292,7 +309,7 @@ auto prim_registrations =
292
309
}})
293
310
.evaluator(
294
311
{torch::jit::prim::TupleIndex,
295
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
312
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
296
313
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
297
314
auto tuple = args.at (n->input (0 )).IValue ()->toTuple ();
298
315
int64_t idx = args.at (n->input (1 )).IValue ()->toInt ();
@@ -302,24 +319,24 @@ auto prim_registrations =
302
319
EvalOptions ().validSchemas ({" prim::TupleIndex(Any tup, int i) -> (Any)" })})
303
320
.evaluator(
304
321
{torch::jit::prim::TupleUnpack,
305
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
322
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
306
323
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
307
324
auto output = args.at (n->input ()).IValue ()->toTuple ();
308
325
return c10::optional<torch::jit::IValue>(std::move (output));
309
326
}})
310
327
.evaluator(
311
328
{c10::Symbol::fromQualString (" prim::unchecked_cast" ),
312
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
329
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
313
330
return *(args.at (n->input (0 )).IValue ());
314
331
}})
315
332
.evaluator(
316
333
{c10::Symbol::fromQualString (" prim::Uninitialized" ),
317
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
334
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
318
335
return c10::IValue::uninitialized ();
319
336
}})
320
337
.evaluator(
321
338
{c10::Symbol::fromQualString (" prim::RaiseException" ),
322
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
339
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
323
340
auto exception = args.at (n->input (0 )).IValue ();
324
341
TORCHTRT_THROW_ERROR (" Error from TorchScript: " << *exception );
325
342
return {};
@@ -328,4 +345,4 @@ auto prim_registrations =
328
345
} // namespace evaluators
329
346
} // namespace conversion
330
347
} // namespace core
331
- } // namespace torch_tensorrt
348
+ } // namespace torch_tensorrt
0 commit comments