Skip to content

Commit 65dbf90

Browse files
committed
feat(aten::add): adding string concat evaluator
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 828d120 commit 65dbf90

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

core/conversion/evaluators/aten.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -342,15 +342,22 @@ auto aten_registrations TORCHTRT_UNUSED =
342342
auto a = args.at(n->input(0)).unwrapToDouble();
343343
auto b = args.at(n->input(1)).unwrapToDouble();
344344
return a + b;
345+
} else if (args.at(n->input(0)).IValue()->isString()) {
346+
auto a = args.at(n->input(0)).unwrapToString();
347+
auto b = args.at(n->input(1)).unwrapToString();
348+
return a + b;
345349
} else {
346350
TORCHTRT_THROW_ERROR(
347351
"Unimplemented data type for aten::add evaluator: "
348352
<< args.at(n->input(0)).IValue()->type()->str());
349353
return {};
350354
}
351355
},
352-
EvalOptions().validSchemas(
353-
{"aten::add.int(int a, int b) -> (int)", "aten::add.float(float a, float b) -> (float)"})})
356+
EvalOptions().validSchemas({
357+
"aten::add.int(int a, int b) -> (int)",
358+
"aten::add.float(float a, float b) -> (float)",
359+
"aten::add.str(str a, str b) -> (str)"
360+
})})
354361
.evaluator({c10::Symbol::fromQualString("aten::add_"),
355362
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
356363
if (args.at(n->input(0)).IValue()->isList()) {

0 commit comments

Comments
 (0)