File tree 2 files changed +36
-1
lines changed
extension/android/executorch_android/src/main/java/org/pytorch/executorch
2 files changed +36
-1
lines changed Original file line number Diff line number Diff line change @@ -73,4 +73,13 @@ public enum DType {
73
73
DType (int jniCode ) {
74
74
this .jniCode = jniCode ;
75
75
}
76
+
77
+ public static DType fromJniCode (int jniCode ) {
78
+ for (DType dtype : values ()) {
79
+ if (dtype .jniCode == jniCode ) {
80
+ return dtype ;
81
+ }
82
+ }
83
+ throw new IllegalArgumentException ("No DType found for jniCode " + jniCode );
84
+ }
76
85
}
Original file line number Diff line number Diff line change 8
8
9
9
package org .pytorch .executorch ;
10
10
11
+ import android .util .Log ;
11
12
import com .facebook .jni .HybridData ;
12
13
import com .facebook .jni .annotations .DoNotStrip ;
13
14
import java .nio .Buffer ;
@@ -630,6 +631,31 @@ public String toString() {
630
631
}
631
632
}
632
633
634
+ static class Tensor_unsupported extends Tensor {
635
+ private final ByteBuffer data ;
636
+ private final DType myDtype ;
637
+
638
+ private Tensor_unsupported (ByteBuffer data , long [] shape , DType dtype ) {
639
+ super (shape );
640
+ this .data = data ;
641
+ this .myDtype = dtype ;
642
+ Log .e (
643
+ "ExecuTorch" ,
644
+ toString () + " in Java. Please consider re-export the model with proper return type" );
645
+ }
646
+
647
+ @ Override
648
+ public DType dtype () {
649
+ return myDtype ;
650
+ }
651
+
652
+ @ Override
653
+ public String toString () {
654
+ return String .format (
655
+ "Unsupported tensor(%s, dtype=%d)" , Arrays .toString (shape ), this .myDtype );
656
+ }
657
+ }
658
+
633
659
// region checks
634
660
private static void checkArgument (boolean expression , String errorMessage , Object ... args ) {
635
661
if (!expression ) {
@@ -675,7 +701,7 @@ private static Tensor nativeNewTensor(
675
701
} else if (DType .INT8 .jniCode == dtype ) {
676
702
tensor = new Tensor_int8 (data , shape );
677
703
} else {
678
- throw new IllegalArgumentException ( "Unknown Tensor dtype" );
704
+ tensor = new Tensor_unsupported ( data , shape , DType . fromJniCode ( dtype ) );
679
705
}
680
706
tensor .mHybridData = hybridData ;
681
707
return tensor ;
You can’t perform that action at this time.
0 commit comments