
Commit c63a5a5

Merge pull request #1270 from Njuapp/dynamic-transformer
Support swin/bert with dynamic batch

2 parents e8c971a + 31797c1

3 files changed, +70 -32 lines

core/conversion/converters/impl/element_wise.cpp (-2)

@@ -25,8 +25,6 @@ nvinfer1::ITensor* clamp_util(
   return clamp_layer_out;
 }
 
-
-
 auto element_wise_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
         .pattern(

core/conversion/converters/impl/select.cpp (+28 -23)

@@ -136,7 +136,8 @@ auto select_registrations TORCHTRT_UNUSED =
       // IShuffleLayer removes redundant dimensions
       auto shuffle_layer = ctx->net->addShuffle(*out);
       TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
-      shuffle_layer->setReshapeDimensions(util::squeezeDims(out->getDimensions(), dim));
+      shuffle_layer->setReshapeDimensions(
+          util::squeezeDims(out->getDimensions(), dim, !ctx->input_is_dynamic));
       shuffle_layer->setName(util::node_info(n).c_str());
       out = shuffle_layer->getOutput(0);
     }
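
Side note on the new argument: util::squeezeDims drops the extent at the given position, and the hunk above now also hands it !ctx->input_is_dynamic so the helper can treat dynamic (-1) extents differently when building the reshape spec. A minimal standalone sketch of the index-dropping part only (illustrative stand-in, not the real helper; the dynamic-extent policy is intentionally not reproduced here):

```cpp
#include <NvInfer.h>

// Illustrative stand-in for the squeeze above: remove the extent at `pos`.
// The real util::squeezeDims additionally takes a flag (passed as
// !input_is_dynamic in the hunk) governing how -1 extents are encoded in
// the resulting reshape spec; that policy is left out of this sketch.
nvinfer1::Dims dropDim(const nvinfer1::Dims& d, int pos) {
  nvinfer1::Dims out{};
  out.nbDims = 0;
  for (int i = 0; i < d.nbDims; i++) {
    if (i != pos) {
      out.d[out.nbDims++] = d.d[i]; // keep every other extent as-is
    }
  }
  return out;
}
```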
@@ -249,21 +250,19 @@ auto select_registrations TORCHTRT_UNUSED =
       auto dims = args[2].unwrapToIntList().vec();
 
       TORCHTRT_CHECK(dims.size() == shifts.size(), "dims.size() should be equal to shifts.size()");
-      if (ctx->input_is_dynamic) {
-        TORCHTRT_THROW_ERROR("aten::roll is currently not support in dynamic input shape compilation");
-      } else {
-        auto in_shape = util::toVec(in->getDimensions());
-        for (size_t i = 0; i < dims.size(); i++) {
-          auto dim = dims[i] < 0 ? (in_shape.size() + dims[i]) : dims[i];
-          TORCHTRT_CHECK(dim < in_shape.size(), "Dimension out of range");
-          in = roll(ctx, in, shifts[i], dim, in_shape);
-        }
-        auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
+      auto in_shape = util::toVec(in->getDimensions());
+      for (size_t i = 0; i < dims.size(); i++) {
+        auto dim = dims[i] < 0 ? (in_shape.size() + dims[i]) : dims[i];
+        TORCHTRT_CHECK(dim < in_shape.size(), "Dimension out of range");
+        TORCHTRT_CHECK(
+            in_shape[dim] != -1, "aten::roll is not supported when the targeted dimension is dynamic");
+        in = roll(ctx, in, shifts[i], dim, in_shape);
+      }
+      auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
 
-        LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+      LOG_DEBUG("Output tensor shape: " << out->getDimensions());
 
-        return true;
-      }
+      return true;
     }})
     .pattern(
         {"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
@@ -360,9 +359,15 @@ auto select_registrations TORCHTRT_UNUSED =
           stride_.d[i] = 1;
         }
       }
-      auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
-
-      if (dynamic_shape) { // dynamic shape
+      if (!dynamic_shape) {
+        auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
+        LOG_DEBUG("start_:" << start_);
+        LOG_DEBUG("size_:" << size_);
+        LOG_DEBUG("stride_:" << stride_);
+        auto slice_out = slice_layer->getOutput(0);
+        auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_out);
+        LOG_DEBUG("Slice layer output shape: " << out->getDimensions());
+      } else { // dynamic shape
         LOG_DEBUG("Using dynamic version of slice");
         // start tensor
         at::Tensor start_tensor = torch::zeros({nbdims}).to(torch::kI32);
@@ -398,13 +403,13 @@ auto select_registrations TORCHTRT_UNUSED =
         auto size_itensor = get_slice_size(ctx, out_start, out_end, stride_itensor, nbdims, node_name);
 
         // update slice layer
+        auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_);
         slice_layer->setInput(1, *out_start); // start
         slice_layer->setInput(2, *size_itensor); // size, must be set if input is dynamic
+        auto slice_out = slice_layer->getOutput(0);
+        auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_out);
+        LOG_DEBUG("Slice layer output shape: " << out->getDimensions());
       }
-      auto slice_out = slice_layer->getOutput(0);
-
-      auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_out);
-      LOG_DEBUG("Slice layer output shape: " << out->getDimensions());
 
       return true;
     }})
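
The two slice hunks move layer creation into each branch: in the dynamic branch the start and size cannot be baked into the layer as Dims, so they are supplied as shape-tensor inputs via ISliceLayer::setInput (input 1 = start, input 2 = size), just as the diff does. A minimal sketch of that TensorRT pattern, with hypothetical names (not Torch-TensorRT's):

```cpp
#include <NvInfer.h>

// Sketch (illustrative names): build a slice whose start and size are only
// known at runtime. The Dims passed to addSlice are placeholders; the
// shape-tensor inputs wired in afterwards take precedence over them.
nvinfer1::ITensor* dynamicSlice(nvinfer1::INetworkDefinition* net,
                                nvinfer1::ITensor* in,
                                nvinfer1::ITensor* start, // Int32 shape tensor
                                nvinfer1::ITensor* size,  // Int32 shape tensor
                                const nvinfer1::Dims& stride) {
  nvinfer1::Dims placeholder{};
  placeholder.nbDims = in->getDimensions().nbDims;
  auto* slice = net->addSlice(*in, placeholder, placeholder, stride);
  slice->setInput(1, *start); // input 1 overrides the start Dims
  slice->setInput(2, *size);  // input 2 overrides size; required for dynamic input
  return slice->getOutput(0);
}
```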
@@ -484,7 +489,7 @@ auto select_registrations TORCHTRT_UNUSED =
 
       auto layer = ctx->net->addScatter(*self, *index, *value_tensor, nvinfer1::ScatterMode::kELEMENT);
       layer->setAxis(dim);
-
+
       TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.value");
 
       layer->setName(util::node_info(n).c_str());
@@ -503,7 +508,7 @@ auto select_registrations TORCHTRT_UNUSED =
 
       auto layer = ctx->net->addScatter(*self, *index, *src, nvinfer1::ScatterMode::kELEMENT);
       layer->setAxis(dim);
-
+
       TORCHTRT_CHECK(layer, "Unable to create layer for aten::scatter.src");
 
       layer->setName(util::node_info(n).c_str());

core/conversion/converters/impl/shuffle.cpp (+42 -7)

@@ -19,12 +19,38 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
       auto end_dim = args[2].unwrapToInt();
       auto in_shape = util::toVec(in->getDimensions());
       std::vector<int64_t> out_shape;
-      if (ctx->input_is_dynamic && in_shape[0] != -1) {
-        out_shape = std::vector<int64_t>({in_shape[0], -1});
-      } else if (ctx->input_is_dynamic && in_shape[0] == -1) {
-        out_shape = std::vector<int64_t>(
-            {-1,
-             -1 * std::accumulate(std::begin(in_shape), std::end(in_shape), 1, std::multiplies<int64_t>())});
+      if (ctx->input_is_dynamic) {
+        end_dim = (end_dim == -1) ? in_shape.size() - 1 : end_dim;
+        int nbDynamicFlattenedDims = 0;
+        int nbDynamicUnflattenedDims = 0;
+        for (int i = 0; i < (int)in_shape.size(); i++) {
+          if (in_shape[i] == -1) {
+            if (i >= start_dim && i <= end_dim)
+              nbDynamicFlattenedDims++;
+            else
+              nbDynamicUnflattenedDims++;
+          }
+        }
+        if (nbDynamicFlattenedDims > 0 && nbDynamicUnflattenedDims > 0) {
+          TORCHTRT_THROW_ERROR(
+              "Flatten is currently not supported when target shape contains more than one dynamic dimension");
+        }
+        if (nbDynamicUnflattenedDims > 1) {
+          TORCHTRT_THROW_ERROR(
+              "Flatten is currently not supported when target shape contains more than one dynamic dimension");
+        }
+        out_shape = in_shape;
+        out_shape.erase(std::begin(out_shape) + start_dim, std::begin(out_shape) + end_dim + 1);
+        if (nbDynamicFlattenedDims == 0) {
+          auto flattened_dim = std::accumulate(
+              std::begin(in_shape) + start_dim,
+              std::begin(in_shape) + end_dim + 1,
+              1,
+              std::multiplies<int64_t>());
+          out_shape.insert(std::begin(out_shape) + start_dim, flattened_dim);
+        } else {
+          out_shape.insert(std::begin(out_shape) + start_dim, -1);
+        }
       } else {
         out_shape = torch::flatten(torch::rand(in_shape), start_dim, end_dim).sizes().vec();
       }
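
For reference, the new flatten shape computation reads as a standalone function: count the dynamic extents inside and outside the flattened span, reject combinations that would leave more than one unknown in the target shape, then splice in either the concrete product or a single -1. A runnable restatement (assumes -1 marks a dynamic extent, mirroring TensorRT's convention):

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <vector>

// Standalone restatement of the converter's flatten shape logic above.
std::vector<int64_t> flatten_shape(std::vector<int64_t> in_shape, int start_dim, int end_dim) {
  end_dim = (end_dim == -1) ? static_cast<int>(in_shape.size()) - 1 : end_dim;
  int dynamic_flattened = 0, dynamic_unflattened = 0;
  for (int i = 0; i < static_cast<int>(in_shape.size()); i++) {
    if (in_shape[i] == -1) {
      if (i >= start_dim && i <= end_dim)
        dynamic_flattened++;   // dynamic extent inside the flattened span
      else
        dynamic_unflattened++; // dynamic extent that survives the flatten
    }
  }
  // Same guards as the converter: more than one dynamic extent in the
  // target shape would be ambiguous for TensorRT's reshape.
  if ((dynamic_flattened > 0 && dynamic_unflattened > 0) || dynamic_unflattened > 1)
    throw std::runtime_error("target shape has more than one dynamic dimension");
  std::vector<int64_t> out_shape(in_shape);
  out_shape.erase(out_shape.begin() + start_dim, out_shape.begin() + end_dim + 1);
  int64_t flattened_dim = -1; // the flattened span is itself dynamic ...
  if (dynamic_flattened == 0) {
    flattened_dim = std::accumulate( // ... unless all its extents are static
        in_shape.begin() + start_dim, in_shape.begin() + end_dim + 1,
        static_cast<int64_t>(1), std::multiplies<int64_t>());
  }
  out_shape.insert(out_shape.begin() + start_dim, flattened_dim);
  return out_shape;
}
```

For example, flatten_shape({-1, 56, 56, 96}, 1, 2) returns {-1, 3136, 96}: the batch stays dynamic while the two static spatial extents collapse, which is the swin-style case this commit targets.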
@@ -45,7 +71,16 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
       auto in_shape = util::toVec(in->getDimensions());
       std::vector<int64_t> new_shape;
       if (ctx->input_is_dynamic) {
-        TORCHTRT_THROW_ERROR("Resize is currently not support in dynamic input shape compilation");
+        new_shape = util::toVec(args[1].unwrapToIntList().vec());
+        int nbDynamicDims = 0;
+        for (size_t i = 0; i < new_shape.size(); i++) {
+          if (in_shape[i] == -1)
+            nbDynamicDims++;
+        }
+        if (nbDynamicDims > 1) {
+          TORCHTRT_THROW_ERROR(
+              "Resize is currently not supported when target shape contains more than one dynamic dimension");
+        }
       } else {
         new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
       }
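
The resize (reshape) path now accepts dynamic inputs as long as at most one dimension is dynamic. That lines up with TensorRT's reshape rule: a spec may contain at most one -1, which is inferred from the element count at runtime. A small illustration with made-up transformer-style shapes, e.g. {-1, 197, 768} -> {-1, 197, 12, 64} for an attention head split (not the converter code):

```cpp
#include <NvInfer.h>

// Illustration (made-up shapes): TensorRT accepts at most one -1 in a
// reshape spec and infers it from the element count, so a single dynamic
// dimension such as the batch can pass through a reshape unchanged.
nvinfer1::ITensor* reshape_with_dynamic_batch(nvinfer1::INetworkDefinition* net,
                                              nvinfer1::ITensor* in) {
  auto* shuffle = net->addShuffle(*in);
  nvinfer1::Dims spec{};
  spec.nbDims = 4;
  spec.d[0] = -1; // the one inferred (dynamic) extent, e.g. the batch
  spec.d[1] = 197;
  spec.d[2] = 12;
  spec.d[3] = 64;
  shuffle->setReshapeDimensions(spec);
  return shuffle->getOutput(0);
}
```

Two or more -1 entries would be ambiguous, which is exactly what the nbDynamicDims guard above rejects.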
