Skip to content

Commit 042fa2e

Browse files
authored
Merge pull request #992 from pytorch/anuragd/fix_aten_split
(//core): Added a variant for aten::split
2 parents 2c787ad + 99d1f32 commit 042fa2e

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,12 @@ auto select_registrations TORCHTRT_UNUSED =
350350
LOG_DEBUG("Converted split op into a list of IValues");
351351
return true;
352352
}})
353+
.pattern({"aten::split.sizes(Tensor(a -> *) self, int[] split_size, int dim=0) -> (Tensor[])",
354+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
355+
add_split(ctx, n, args, true, false);
356+
LOG_DEBUG("Converted split op into a list of IValues");
357+
return true;
358+
}})
353359
.pattern({"aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])",
354360
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
355361
add_split(ctx, n, args, false, false);

0 commit comments

Comments
 (0)