Skip to content

Commit b39ae34

Browse files
committed
[Android] Add unknown Tensor type instead of crash
1 parent dfd3dbe commit b39ae34

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/DType.java

+9
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,13 @@ public enum DType {
7373
DType(int jniCode) {
7474
this.jniCode = jniCode;
7575
}
76+
77+
public static DType fromJniCode(int jniCode) {
78+
for (DType dtype : values()) {
79+
if (dtype.jniCode == jniCode) {
80+
return dtype;
81+
}
82+
}
83+
throw new IllegalArgumentException("No DType found for jniCode " + jniCode);
84+
}
7685
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

+27-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
package org.pytorch.executorch;
1010

11+
import android.util.Log;
1112
import com.facebook.jni.HybridData;
1213
import com.facebook.jni.annotations.DoNotStrip;
1314
import java.nio.Buffer;
@@ -630,6 +631,31 @@ public String toString() {
630631
}
631632
}
632633

634+
static class Tensor_unsupported extends Tensor {
635+
private final ByteBuffer data;
636+
private final DType myDtype;
637+
638+
private Tensor_unsupported(ByteBuffer data, long[] shape, DType dtype) {
639+
super(shape);
640+
this.data = data;
641+
this.myDtype = dtype;
642+
Log.e(
643+
"ExecuTorch",
644+
toString() + " in Java. Please consider re-export the model with proper return type");
645+
}
646+
647+
@Override
648+
public DType dtype() {
649+
return myDtype;
650+
}
651+
652+
@Override
653+
public String toString() {
654+
return String.format(
655+
"Unsupported tensor(%s, dtype=%d)", Arrays.toString(shape), this.myDtype);
656+
}
657+
}
658+
633659
// region checks
634660
private static void checkArgument(boolean expression, String errorMessage, Object... args) {
635661
if (!expression) {
@@ -675,7 +701,7 @@ private static Tensor nativeNewTensor(
675701
} else if (DType.INT8.jniCode == dtype) {
676702
tensor = new Tensor_int8(data, shape);
677703
} else {
678-
throw new IllegalArgumentException("Unknown Tensor dtype");
704+
tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype));
679705
}
680706
tensor.mHybridData = hybridData;
681707
return tensor;

0 commit comments

Comments
 (0)