Skip to content

Commit 013934a

Browse files
peri044bowang007
authored andcommitted
fix: Fix how ITensorList is detected
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent d7cb415 commit 013934a

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

core/conversion/var/Var.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,31 @@ bool Var::isITensor() const {
146146
}
147147
}
148148

149+
bool Var::isITensorList() {
150+
// Unpack the Var as a List and check if each entry is a custom class since
151+
// ITensors are stored in CustomClassHolder
152+
auto ival_list = ptr_.ivalue->toList();
153+
for (int i = 0; i < ival_list.size(); i++) {
154+
if (!ival_list.get(i).isCustomClass()) {
155+
return false;
156+
}
157+
}
158+
return true;
159+
}
160+
161+
std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
162+
TORCHTRT_CHECK(
163+
isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
164+
TORCHTRT_CHECK(isITensorList(), "Expected IValue to be an ITensorList");
165+
auto ivalue_list = ptr_.ivalue->toList();
166+
std::vector<nvinfer1::ITensor*> outputs;
167+
for (int i = 0; i < ivalue_list.size(); i++) {
168+
auto element = ivalue_list.get(i).toCustomClass<TensorContainer>()->tensor();
169+
outputs.push_back(std::move(element));
170+
}
171+
return outputs;
172+
}
173+
149174
bool Var::isIValue() const {
150175
if (type_ == Type::kIValue) {
151176
return true;

0 commit comments

Comments
 (0)