File tree 1 file changed +9
-2
lines changed
core/conversion/evaluators
1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -342,15 +342,22 @@ auto aten_registrations TORCHTRT_UNUSED =
342
342
auto a = args.at (n->input (0 )).unwrapToDouble ();
343
343
auto b = args.at (n->input (1 )).unwrapToDouble ();
344
344
return a + b;
345
+ } else if (args.at (n->input (0 )).IValue ()->isString ()) {
346
+ auto a = args.at (n->input (0 )).unwrapToString ();
347
+ auto b = args.at (n->input (1 )).unwrapToString ();
348
+ return a + b;
345
349
} else {
346
350
TORCHTRT_THROW_ERROR (
347
351
" Unimplemented data type for aten::add evaluator: "
348
352
<< args.at (n->input (0 )).IValue ()->type ()->str ());
349
353
return {};
350
354
}
351
355
},
352
- EvalOptions ().validSchemas (
353
- {" aten::add.int(int a, int b) -> (int)" , " aten::add.float(float a, float b) -> (float)" })})
356
+ EvalOptions ().validSchemas ({
357
+ " aten::add.int(int a, int b) -> (int)" ,
358
+ " aten::add.float(float a, float b) -> (float)" ,
359
+ " aten::add.str(str a, str b) -> (str)"
360
+ })})
354
361
.evaluator({c10::Symbol::fromQualString (" aten::add_" ),
355
362
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
356
363
if (args.at (n->input (0 )).IValue ()->isList ()) {
You can’t perform that action at this time.
0 commit comments