Skip to content

Commit 0ec2eb3

Browse files
authored
Merge pull request #616 from NVIDIA/example_tensors
Example tensors
2 parents e95aa99 + 122429f commit 0ec2eb3

File tree

7 files changed

+111
-6
lines changed

7 files changed

+111
-6
lines changed

cpp/include/trtorch/trtorch.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ struct TRTORCH_API CompileSpec {
427427
Input(c10::ArrayRef<int64_t> shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous);
428428

429429
/**
430-
* @brief Construct a new Input Range object dynamic input size from
430+
* @brief Construct a new Input spec object dynamic input size from
431431
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
432432
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
433433
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
@@ -462,7 +462,7 @@ struct TRTORCH_API CompileSpec {
462462
TensorFormat format = TensorFormat::kContiguous);
463463

464464
/**
465-
* @brief Construct a new Input Range object dynamic input size from
465+
* @brief Construct a new Input spec object dynamic input size from
466466
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
467467
* supported sizes. dtype (Expected data type for the input) defaults to PyTorch
468468
* / traditional TRT convention (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8)
@@ -479,7 +479,7 @@ struct TRTORCH_API CompileSpec {
479479
TensorFormat format = TensorFormat::kContiguous);
480480

481481
/**
482-
* @brief Construct a new Input Range object dynamic input size from
482+
* @brief Construct a new Input spec object dynamic input size from
483483
* c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max
484484
* supported sizes
485485
*
@@ -496,6 +496,16 @@ struct TRTORCH_API CompileSpec {
496496
DataType dtype,
497497
TensorFormat format = TensorFormat::kContiguous);
498498

499+
/**
500+
* @brief Construct a new Input spec object using a torch tensor as an example
501+
* The tensor's shape, type and layout inform the spec's values
502+
*
503+
* Note: You cannot set dynamic shape through this method, you must use an alternative constructor
504+
*
505+
* @param tensor Reference tensor to set shape, type and layout
506+
*/
507+
Input(at::Tensor tensor);
508+
499509
bool get_explicit_set_dtype() {
500510
return explicit_set_dtype;
501511
}

cpp/src/compile_spec.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,26 @@ CompileSpec::Input::Input(
287287
this->input_is_dynamic = true;
288288
}
289289

290+
/**
 * @brief Construct a new Input spec object using a torch tensor as an example.
 *
 * The tensor's shape, scalar type and memory layout populate the spec. The
 * resulting spec is always static (min == opt == max == shape,
 * input_is_dynamic == false); use a shape-range constructor for dynamic input.
 *
 * @param tensor Reference tensor used to derive shape, type and layout
 */
CompileSpec::Input::Input(at::Tensor tensor) {
  // Validate the layout before mutating any members, so a failed construction
  // does not leave a partially-initialized spec behind.
  TRTORCH_ASSERT(
      tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous),
      "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last");

  // Compute the size vector once instead of materializing it four times.
  const auto dims = tensor.sizes().vec();
  this->opt_shape = dims;
  this->min_shape = dims;
  this->max_shape = dims;
  this->shape = dims;
  // The dtype comes from a concrete tensor, so treat it as explicitly set.
  this->dtype = tensor.scalar_type();
  this->explicit_set_dtype = true;
  // Contiguous wins when the tensor satisfies both layouts (e.g. size-1 dims).
  this->format = tensor.is_contiguous(at::MemoryFormat::Contiguous) ? at::MemoryFormat::Contiguous
                                                                    : at::MemoryFormat::ChannelsLast;
  this->input_is_dynamic = false;
}
309+
290310
/* ==========================================*/
291311

292312
core::ir::Input to_internal_input(CompileSpec::InputRange& i) {

py/trtorch/Input.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,16 @@ def _parse_format(format: Any) -> _types.TensorFormat:
196196
else:
197197
raise TypeError(
198198
"Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
199+
200+
@classmethod
def _from_tensor(cls, t: torch.Tensor):
    """Build an Input spec from an example tensor.

    The tensor's shape, dtype and memory format populate the spec. Only
    contiguous and channels-last layouts are supported.

    Raises:
        ValueError: If the tensor is neither contiguous nor channels-last.
    """
    is_contiguous = t.is_contiguous(memory_format=torch.contiguous_format)
    is_channels_last = t.is_contiguous(memory_format=torch.channels_last)
    if not (is_contiguous or is_channels_last):
        raise ValueError(
            "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channel_last"
        )
    frmt = torch.contiguous_format if is_contiguous else torch.channels_last
    return cls(shape=t.shape, dtype=t.dtype, format=frmt)

py/trtorch/_compile_spec.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,12 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
174174
info.inputs = _parse_input_ranges(compile_spec["input_shapes"])
175175

176176
if "inputs" in compile_spec:
    # Accept either pre-built trtorch.Input specs or raw example tensors.
    # Bug fix: the error path used `typeof(i)`, which is not a Python builtin
    # and raised NameError instead of the intended KeyError; use `type(i)`.
    if not all(isinstance(i, (torch.Tensor, trtorch.Input)) for i in compile_spec["inputs"]):
        raise KeyError("Input specs should be either trtorch.Input or torch.Tensor, found types: {}".format(
            [type(i) for i in compile_spec["inputs"]]))

    # Promote raw tensors to Input specs derived from the example tensor.
    inputs = [trtorch.Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
    info.inputs = [i._to_internal() for i in inputs]
178183

179184
if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
180185
raise KeyError(

tests/cpp/BUILD

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ test_suite(
1616
":test_modules_as_engines",
1717
":test_multiple_registered_engines",
1818
":test_serialization",
19-
":test_module_fallback"
19+
":test_module_fallback",
20+
":test_example_tensors"
2021
],
2122
)
2223

@@ -28,7 +29,8 @@ test_suite(
2829
":test_modules_as_engines",
2930
":test_multiple_registered_engines",
3031
":test_serialization",
31-
":test_module_fallback"
32+
":test_module_fallback",
33+
":test_example_tensors"
3234
],
3335
)
3436

@@ -43,6 +45,17 @@ cc_test(
4345
],
4446
)
4547

48+
cc_test(
49+
name = "test_example_tensors",
50+
srcs = ["test_example_tensors.cpp"],
51+
data = [
52+
"//tests/modules:jit_models",
53+
],
54+
deps = [
55+
":cpp_api_test",
56+
],
57+
)
58+
4659
cc_test(
4760
name = "test_serialization",
4861
srcs = ["test_serialization.cpp"],

tests/cpp/test_example_tensors.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include "cpp_api_test.h"

// Verifies that a CompileSpec can be built directly from example torch tensors
// (exercising the Input(at::Tensor) constructor) and that the compiled module
// matches the JIT module's output.
TEST_P(CppAPITests, InputsFromTensors) {
  std::vector<torch::jit::IValue> jit_inputs_ivalues;
  std::vector<torch::jit::IValue> trt_inputs_ivalues;
  for (auto in_shape : input_shapes) {
    auto in = at::randn(in_shape, {at::kCUDA});
    jit_inputs_ivalues.push_back(in.clone());
    trt_inputs_ivalues.push_back(in.clone());
  }

  // Build the spec straight from the example tensor rather than explicit shapes.
  auto spec = trtorch::CompileSpec({trt_inputs_ivalues[0].toTensor()});

  auto trt_mod = trtorch::CompileGraph(mod, spec);
  torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(trt_mod, trt_inputs_ivalues);
  std::vector<at::Tensor> trt_results;
  trt_results.push_back(trt_results_ivalues.toTensor());

  // Bug fix: the original test computed TRT results but asserted nothing, so it
  // could never fail. Run the JIT module on the same inputs and compare.
  // (`threshold` is the tolerance from the PathAndInSize fixture param —
  // confirm the member name against cpp_api_test.h.)
  torch::jit::IValue jit_results_ivalues = trtorch::tests::util::RunModuleForward(mod, jit_inputs_ivalues);
  std::vector<at::Tensor> jit_results;
  jit_results.push_back(jit_results_ivalues.toTensor());
  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], threshold));
}

INSTANTIATE_TEST_SUITE_P(
    CompiledModuleForwardIsCloseSuite,
    CppAPITests,
    testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));

tests/py/test_api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,27 @@ def test_compile_script(self):
7373
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
7474
self.assertTrue(same < 2e-2)
7575

76+
def test_from_torch_tensor(self):
    # Pass a raw torch.Tensor (not a trtorch.Input) as the input spec; the
    # compiler should derive shape, dtype and format from the example tensor.
    spec = {
        "inputs": [self.input],
        "device": {
            "device_type": trtorch.DeviceType.GPU,
            "gpu_id": 0,
        },
        "enabled_precisions": {torch.float}
    }

    compiled = trtorch.compile(self.scripted_model, spec)
    max_diff = (compiled(self.input) - self.scripted_model(self.input)).abs().max()
    self.assertTrue(max_diff < 2e-2)
89+
90+
def test_device(self):
    # Exercise the trtorch.Device shorthand ("gpu:0") form of the device key.
    spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}

    compiled = trtorch.compile(self.scripted_model, spec)
    max_diff = (compiled(self.input) - self.scripted_model(self.input)).abs().max()
    self.assertTrue(max_diff < 2e-2)
96+
7697

7798
class TestCompileHalf(ModelTestCase):
7899

0 commit comments

Comments
 (0)