Skip to content

Commit 66d4aa9

Browse files
Commit 66d4aa9 (merge of pull request #1584 from gs-olive/full_like_evaluator; 2 parents: 0d32562 + 0aaeecb)
fix: Add `aten::full_like` evaluator

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,53 @@ auto aten_registrations TORCHTRT_UNUSED =
172172
auto out_tensor = torch::full(args.at(n->input(0)).unwrapToIntList().vec(), scalar_value, options);
173173
return out_tensor;
174174
}})
175+
.evaluator(
    {c10::Symbol::fromQualString("aten::full_like"),
     // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None,
     // Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> (Tensor)
     //
     // Evaluates aten::full_like at conversion time: produces a constant tensor with the same
     // shape as the input (and its dtype, unless the optional dtype argument overrides it).
     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
       // Override options related to layout and device for TensorRT
       // (TensorRT engines operate on strided, CUDA-resident tensors).
       auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
       auto input_tensor_var = args.at(n->input(0));

       std::vector<int64_t> input_shape;
       c10::ScalarType input_dtype;

       // Extract data type and shape of the input tensor, which may arrive in one of three
       // representations: a TensorRT ITensor, a torch::Tensor IValue, or a TensorContainer
       // custom class wrapping an ITensor.
       if (input_tensor_var.isITensor()) {
         auto tensor = input_tensor_var.ITensor();
         input_shape = util::toVec(tensor->getDimensions());
         input_dtype = util::TRTDataTypeToScalarType(tensor->getType());
       } else if (input_tensor_var.IValue()->isTensor()) {
         auto tensor = input_tensor_var.unwrapToTensor();
         input_shape = tensor.sizes().vec();
         input_dtype = tensor.scalar_type();
       } else if (input_tensor_var.IValue()->isCustomClass()) {
         auto tensor = input_tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
         input_shape = util::toVec(tensor->getDimensions());
         input_dtype = util::TRTDataTypeToScalarType(tensor->getType());
       } else {
         TORCHTRT_THROW_ERROR(
             "Invalid IValue type. IValue is not some class of torch::Tensor or nvinfer1::ITensor. Found: "
             << input_tensor_var.IValue()->type());
       }

       // If specified, use third input arg to determine data type, otherwise default to input
       // tensor data type. The lookup is hoisted into one reference instead of repeating
       // args.at(n->input(2)) three times.
       auto& dtype_var = args.at(n->input(2));
       if (!dtype_var.isNone() && !dtype_var.IValue()->isNone()) {
         options = options.dtype(c10::ScalarType(dtype_var.unwrapToInt()));
       } else {
         options = options.dtype(input_dtype);
       }

       // Generate full tensor with specified input options
       auto scalar_value = args.at(n->input(1)).unwrapToScalar();
       auto out_tensor = torch::full(input_shape, scalar_value, options);
       return out_tensor;
     },
     EvalOptions().validSchemas(
         {R"SIG(aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None,
Layout? layout=None, Device? device=None, bool? pin_memory=None,
MemoryFormat? memory_format=None) -> (Tensor))SIG"})})
175222
.evaluator(
176223
{c10::Symbol::fromQualString("aten::slice"),
177224
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
@@ -821,4 +868,4 @@ auto aten_registrations TORCHTRT_UNUSED =
821868
} // namespace evaluators
822869
} // namespace conversion
823870
} // namespace core
824-
} // namespace torch_tensorrt
871+
} // namespace torch_tensorrt

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,71 @@ TEST(Evaluators, FullEvaluatesCorrectly) {
8383
ASSERT_TRUE(at::equal(jit_results[0].toTensor().to(at::kCUDA), trt_results[0].toTensor()));
8484
}
8585

86+
TEST(Evaluators, FullLikeEvaluatesCorrectly) {
  // aten::full_like with an explicit dtype constant (4 == int64): the TRT
  // evaluator must agree with JIT on both element values and result dtype.
  const auto source_ir = R"IR(
graph(%x.1 : Tensor):
%9 : None = prim::Constant()
%13 : float = prim::Constant[value=1.3]()
%14 : int = prim::Constant[value=4]()
%35 : Device = prim::Constant[value="cuda:0"]()
%19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9)
return (%19))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  auto input = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA});

  auto jit_out = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {input});
  auto trt_out = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {input});

  // Check dtype agreement first, then elementwise equality on the CUDA device.
  ASSERT_TRUE(jit_out[0].toTensor().dtype() == trt_out[0].toTensor().dtype());
  ASSERT_TRUE(at::equal(jit_out[0].toTensor().to(at::kCUDA), trt_out[0].toTensor()));
}
107+
108+
TEST(Evaluators, FullLikeNewDtypeEvaluatesCorrectly) {
  // aten::full_like where the dtype argument (11) differs from the input's
  // dtype: the evaluator must honor the override exactly as JIT does.
  const auto source_ir = R"IR(
graph(%x.1 : Tensor):
%9 : None = prim::Constant()
%13 : Scalar = prim::Constant[value=1]()
%14 : int = prim::Constant[value=11]()
%35 : Device = prim::Constant[value="cuda:0"]()
%19 : Tensor = aten::full_like(%x.1, %13, %14, %9, %35, %9, %9)
return (%19))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  auto input = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA});

  auto jit_out = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {input});
  auto trt_out = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {input});

  // Check dtype agreement first, then elementwise equality on the CUDA device.
  ASSERT_TRUE(jit_out[0].toTensor().dtype() == trt_out[0].toTensor().dtype());
  ASSERT_TRUE(at::equal(jit_out[0].toTensor().to(at::kCUDA), trt_out[0].toTensor()));
}
129+
130+
TEST(Evaluators, FullLikeOldDtypeEvaluatesCorrectly) {
  // aten::full_like with dtype=None on an int32 input: the evaluator must
  // fall back to the input tensor's dtype, matching JIT behavior.
  const auto source_ir = R"IR(
graph(%x.1 : Tensor):
%9 : None = prim::Constant()
%13 : Scalar = prim::Constant[value=1.5]()
%35 : Device = prim::Constant[value="cuda:0"]()
%19 : Tensor = aten::full_like(%x.1, %13, %9, %9, %35, %9, %9)
return (%19))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_ir, parsed_graph.get());

  auto input = at::randint(1, 10, {1, 2, 3, 5}, {at::kCUDA}).to(torch::kInt32);

  auto jit_out = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed_graph, {input});
  auto trt_out = torch_tensorrt::tests::util::EvaluateGraph(parsed_graph->block(), {input});

  // Check dtype agreement first, then elementwise equality on the CUDA device.
  ASSERT_TRUE(jit_out[0].toTensor().dtype() == trt_out[0].toTensor().dtype());
  ASSERT_TRUE(at::equal(jit_out[0].toTensor().to(at::kCUDA), trt_out[0].toTensor()));
}
150+
86151
TEST(Evaluators, OnesDataTypeEvaluatesCorrectly) {
87152
const auto graph = R"IR(
88153
graph(%x.1 : Tensor):

0 commit comments

Comments
 (0)