@@ -776,3 +776,291 @@ func.func @test_const_shape() -> !tosa.shape<4> {
776
776
%cst = tosa.const_shape {value = dense <1 > : tensor <4 xindex >} : () -> !tosa.shape <4 >
777
777
return %cst : !tosa.shape <4 >
778
778
}
779
+
780
+ // F8 support tests
781
+
782
+ // -----
783
+ // CHECK-LABEL: argmax_f8E5M2
784
+ // Applies tosa.argmax with axis = 1 to a 12x8x16 f8E5M2 tensor and returns the 12x16 i32 result.
+ func.func @test_argmax_f8E5M2 (%arg0: tensor <12 x8 x16 xf8 E5 M2 >) -> tensor <12 x16 xi32 > {
785
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor <12 x8 x16 xf8 E5 M2 >) -> tensor <12 x16 xi32 >
786
+ return %0 : tensor <12 x16 xi32 >
787
+ }
788
+
789
+ // -----
790
+ // CHECK-LABEL: avg_pool2d_f8E5M2
791
+ // Feeds two zero-valued 1-element f8E5M2 constants (%input_zp, %output_zp) to tosa.avg_pool2d
+ // with a 2x2 kernel, stride 1x1, pad [0, 1, 0, 1] and acc_type = f16; input and result are 1x7x7x9 f8E5M2.
+ func.func @test_avg_pool2d_f8E5M2 (%arg0: tensor <1 x7 x7 x9 xf8 E5 M2 >) -> tensor <1 x7 x7 x9 xf8 E5 M2 > {
792
+ %input_zp = " tosa.const" () <{value = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
793
+ %output_zp = " tosa.const" () <{value = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
794
+ %0 = tosa.avg_pool2d %arg0 , %input_zp , %output_zp {acc_type = f16 , kernel = array<i64 : 2 , 2 >, pad = array<i64 : 0 , 1 , 0 , 1 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x7 x7 x9 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x7 x7 x9 xf8 E5 M2 >
795
+ return %0 : tensor <1 x7 x7 x9 xf8 E5 M2 >
796
+ }
797
+
798
+ // -----
799
+ // CHECK-LABEL: conv2d_f8E5M2
800
+ // Runs tosa.conv2d on f8E5M2 %arg0/%arg1 with an f16 third operand and an f16 result (acc_type = f16).
+ // %input_zp and %weight_zp are zero-valued 1-element f8E5M2 constants; the op also sets local_bound = true.
+ func.func @test_conv2d_f8E5M2 (%arg0: tensor <1 x4 x4 x4 xf8 E5 M2 >, %arg1: tensor <8 x1 x1 x4 xf8 E5 M2 >, %arg2: tensor <8 xf16 >) -> tensor <1 x4 x4 x8 xf16 > {
801
+ %input_zp = " tosa.const" () <{value = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
802
+ %weight_zp = " tosa.const" () <{value = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
803
+ %0 = tosa.conv2d %arg0 , %arg1 , %arg2 , %input_zp , %weight_zp {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, local_bound = true } : (tensor <1 x4 x4 x4 xf8 E5 M2 >, tensor <8 x1 x1 x4 xf8 E5 M2 >, tensor <8 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 >
804
+ return %0 : tensor <1 x4 x4 x8 xf16 >
805
+ }
806
+
807
+ // -----
808
+ // CHECK-LABEL: conv3d_f8E5M2
809
+ // Runs tosa.conv3d on f8E5M2 %arg0/%arg1 with an f16 third operand and an f16 result (acc_type = f16).
+ // The two trailing 1-element f8E5M2 operands (%arg3, %arg4) are taken from the function arguments.
+ func.func @test_conv3d_f8E5M2 (%arg0: tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, %arg1: tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, %arg2: tensor <34 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 > {
810
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, tensor <34 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 >
811
+ return %0 : tensor <1 x4 x8 x21 x34 xf16 >
812
+ }
813
+
814
+ // -----
815
+ // CHECK-LABEL: depthwise_conv2d_f8E5M2
816
+ // Runs tosa.depthwise_conv2d on f8E5M2 %arg0/%arg1 with an f16 third operand and an f16 result (acc_type = f16).
+ // The two trailing 1-element f8E5M2 operands (%arg3, %arg4) are taken from the function arguments.
+ func.func @test_depthwise_conv2d_f8E5M2 (%arg0: tensor <1 x4 x4 x4 xf8 E5 M2 >, %arg1: tensor <1 x1 x4 x2 xf8 E5 M2 >, %arg2: tensor <8 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 > {
817
+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x4 x4 x4 xf8 E5 M2 >, tensor <1 x1 x4 x2 xf8 E5 M2 >, tensor <8 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 >
818
+ return %0 : tensor <1 x4 x4 x8 xf16 >
819
+ }
820
+
821
+ // -----
822
+ // CHECK-LABEL: matmul_f8E5M2
823
+ // Applies tosa.matmul to 1x14x19 and 1x19x28 f8E5M2 tensors, producing a 1x14x28 f16 result.
+ func.func @test_matmul_f8E5M2 (%arg0: tensor <1 x14 x19 xf8 E5 M2 >, %arg1: tensor <1 x19 x28 xf8 E5 M2 >) -> tensor <1 x14 x28 xf16 > {
824
+ %0 = tosa.matmul %arg0 , %arg1 : (tensor <1 x14 x19 xf8 E5 M2 >, tensor <1 x19 x28 xf8 E5 M2 >) -> tensor <1 x14 x28 xf16 >
825
+ return %0 : tensor <1 x14 x28 xf16 >
826
+ }
827
+
828
+ // -----
829
+ // CHECK-LABEL: max_pool2d_f8E5M2
830
+ // Applies tosa.max_pool2d (1x1 kernel, stride 1x1, zero padding) to a 1x32x32x8 f8E5M2 tensor; the result type is unchanged.
+ func.func @test_max_pool2d_f8E5M2 (%arg0: tensor <1 x32 x32 x8 xf8 E5 M2 >) -> tensor <1 x32 x32 x8 xf8 E5 M2 > {
831
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E5 M2 >) -> tensor <1 x32 x32 x8 xf8 E5 M2 >
832
+ return %0 : tensor <1 x32 x32 x8 xf8 E5 M2 >
833
+ }
834
+
835
+ // -----
836
+
837
+ // CHECK-LABEL: transpose_conv2d_f8E5M2
838
+ // Runs tosa.transpose_conv2d on f8E5M2 %arg0/%arg1 with an f16 third operand and an f16 result
+ // (acc_type = f16, zero out_pad, stride 1x1). The trailing 1-element f8E5M2 operands come from %arg3/%arg4.
+ func.func @test_transpose_conv2d_f8E5M2 (%arg0: tensor <1 x32 x32 x8 xf8 E5 M2 >, %arg1: tensor <16 x1 x1 x8 xf8 E5 M2 >, %arg2: tensor <16 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x32 x32 x16 xf16 > {
839
+ %0 = tosa.transpose_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , out_pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E5 M2 >, tensor <16 x1 x1 x8 xf8 E5 M2 >, tensor <16 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x32 x32 x16 xf16 >
840
+ return %0 : tensor <1 x32 x32 x16 xf16 >
841
+ }
842
+
843
+ // -----
844
+ // CHECK-LABEL: const_f8E5M2
845
+ // Materializes a 4-element f8E5M2 tosa.const holding [3.0, -0.0, -1.0, 2.0] and returns it.
+ // %arg0 (index) is declared but not referenced in the body.
+ func.func @test_const_f8E5M2 (%arg0 : index ) -> tensor <4 xf8 E5 M2 > {
846
+ %0 = " tosa.const" () {value = dense <[3.0 , -0.0 , -1.0 , 2.0 ]> : tensor <4 xf8 E5 M2 >} : () -> tensor <4 xf8 E5 M2 >
847
+ return %0 : tensor <4 xf8 E5 M2 >
848
+ }
849
+
850
+ // -----
851
+ // CHECK-LABEL: cast_f8E5M2
852
+ // Converts a 13x21x3 f8E5M2 tensor to f16 with tosa.cast.
+ func.func @test_cast_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf16 > {
853
+ %0 = tosa.cast %arg0 : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf16 >
854
+ return %0 : tensor <13 x21 x3 xf16 >
855
+ }
856
+
857
+ // -----
858
+ // CHECK-LABEL: concat_f8E5M2
859
+ // Concatenates two 13x21x3 f8E5M2 tensors with tosa.concat on axis 0, giving a 26x21x3 result.
+ func.func @test_concat_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <26 x21 x3 xf8 E5 M2 > {
860
+ %0 = tosa.concat %arg0 , %arg1 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <26 x21 x3 xf8 E5 M2 >
861
+ return %0 : tensor <26 x21 x3 xf8 E5 M2 >
862
+ }
863
+
864
+ // -----
865
+ // CHECK-LABEL: pad_f8E5M2
866
+ // Calls tosa.pad with an all-zero !tosa.shape<6> padding and a 1-element f8E5M2 constant of -0.0;
+ // the result type keeps the 13x21x3 input shape.
+ func.func @test_pad_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
867
+ %padding = tosa.const_shape {value = dense <0 > : tensor <6 xindex >} : () -> !tosa.shape <6 >
868
+ %cst = " tosa.const" () { value = dense <-0.0 > : tensor <1 xf8 E5 M2 > } : () -> tensor <1 xf8 E5 M2 >
869
+ %0 = tosa.pad %arg0 , %padding , %cst : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <6 >, tensor <1 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
870
+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
871
+ }
872
+
873
+ // -----
874
+ // CHECK-LABEL: reshape_f8E5M2
875
+ // Reshapes a 13x21x3 f8E5M2 tensor to 1x819 with tosa.reshape, using a constant !tosa.shape<2> of [1, 819].
+ func.func @test_reshape_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <1 x819 xf8 E5 M2 > {
876
+ %1 = tosa.const_shape {value = dense <[1 , 819 ]> : tensor <2 xindex >} : () -> !tosa.shape <2 >
877
+ %0 = tosa.reshape %arg0 , %1 : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <2 >) -> tensor <1 x819 xf8 E5 M2 >
878
+ return %0 : tensor <1 x819 xf8 E5 M2 >
879
+ }
880
+
881
+ // -----
882
+ // CHECK-LABEL: reverse_f8E5M2
883
+ // Applies tosa.reverse on axis 0 to a 13x21x3 f8E5M2 tensor; the result type is unchanged.
+ func.func @test_reverse_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
884
+ %0 = tosa.reverse %arg0 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
885
+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
886
+ }
887
+
888
+ // -----
889
+ // CHECK-LABEL: slice_f8E5M2
890
+ // Calls tosa.slice on a 13x21x3 f8E5M2 tensor with two constant !tosa.shape<3> operands,
+ // [4, 11, 1] then [6, 8, 0], and a 4x11x1 result type.
+ // NOTE(review): the result shape equals the first shape operand; confirm the operand order
+ // (start vs. size) against the tosa.slice definition.
+ func.func @test_slice_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <4 x11 x1 xf8 E5 M2 > {
891
+ %0 = tosa.const_shape {value = dense <[4 , 11 , 1 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
892
+ %1 = tosa.const_shape {value = dense <[6 , 8 , 0 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
893
+ %2 = tosa.slice %arg0 , %0 , %1 : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <3 >, !tosa.shape <3 >) -> tensor <4 x11 x1 xf8 E5 M2 >
894
+ return %2 : tensor <4 x11 x1 xf8 E5 M2 >
895
+ }
896
+
897
+ // -----
898
+ // CHECK-LABEL: tile_f8E5M2
899
+ // Applies tosa.tile with a constant !tosa.shape<3> of [3, 1, 2] to a 13x21x3 f8E5M2 tensor, giving 39x21x6.
+ func.func @test_tile_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <39 x21 x6 xf8 E5 M2 > {
900
+ %cst = tosa.const_shape { value = dense <[3 , 1 , 2 ]> : tensor <3 xindex > } : () -> !tosa.shape <3 >
901
+ %0 = tosa.tile %arg0 , %cst: (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <3 >) -> tensor <39 x21 x6 xf8 E5 M2 >
902
+ return %0 : tensor <39 x21 x6 xf8 E5 M2 >
903
+ }
904
+
905
+ // -----
906
+ // CHECK-LABEL: transpose_f8E5M2
+ // Applies tosa.transpose with perms = [2, 0, 1] to a 13x21x3 f8E5M2 tensor, giving a 3x13x21 result.
+ // The label directive above was missing; every sibling test (including the f8E4M3FN transpose)
+ // carries one so FileCheck anchors on the function.
+ func.func @test_transpose_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <3 x13 x21 xf8 E5 M2 > {
907
+ %1 = tosa.transpose %arg0 {perms = array<i32 : 2 , 0 , 1 >} : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <3 x13 x21 xf8 E5 M2 >
908
+ return %1 : tensor <3 x13 x21 xf8 E5 M2 >
909
+ }
910
+
911
+ // -----
912
+ // CHECK-LABEL: gather_f8E5M2
913
+ // Applies tosa.gather to a 13x21x3 f8E5M2 tensor and a 13x26 i32 tensor, giving a 13x26x3 f8E5M2 result.
+ func.func @test_gather_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E5 M2 > {
914
+ %0 = tosa.gather %arg0 , %arg1 : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E5 M2 >
915
+ return %0 : tensor <13 x26 x3 xf8 E5 M2 >
916
+ }
917
+
918
+ // -----
919
+ // CHECK-LABEL: scatter_f8E5M2
920
+ // Applies tosa.scatter to a 13x21x3 f8E5M2 tensor, a 13x26 i32 tensor and a 13x26x3 f8E5M2 tensor;
+ // the result has the same 13x21x3 f8E5M2 type as the first operand.
+ func.func @test_scatter_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x26 xi32 >, %arg2: tensor <13 x26 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
921
+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x26 xi32 >, tensor <13 x26 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
922
+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
923
+ }
924
+
925
+ // -----
926
+ // CHECK-LABEL: argmax_f8E4M3FN
927
+ // Applies tosa.argmax with axis = 1 to a 12x8x16 f8E4M3FN tensor and returns the 12x16 i32 result.
+ func.func @test_argmax_f8E4M3FN (%arg0: tensor <12 x8 x16 xf8 E4 M3 FN>) -> tensor <12 x16 xi32 > {
928
+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor <12 x8 x16 xf8 E4 M3 FN>) -> tensor <12 x16 xi32 >
929
+ return %0 : tensor <12 x16 xi32 >
930
+ }
931
+
932
+ // -----
933
+ // CHECK-LABEL: avg_pool2d_f8E4M3FN
934
+ // Feeds two zero-valued 1-element f8E4M3FN constants (%input_zp, %output_zp) to tosa.avg_pool2d
+ // with a 2x2 kernel, stride 1x1, pad [0, 1, 0, 1] and acc_type = f16; input and result are 1x7x7x9 f8E4M3FN.
+ func.func @test_avg_pool2d_f8E4M3FN (%arg0: tensor <1 x7 x7 x9 xf8 E4 M3 FN>) -> tensor <1 x7 x7 x9 xf8 E4 M3 FN> {
935
+ %input_zp = " tosa.const" () <{value = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
936
+ %output_zp = " tosa.const" () <{value = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
937
+ %0 = tosa.avg_pool2d %arg0 , %input_zp , %output_zp {acc_type = f16 , kernel = array<i64 : 2 , 2 >, pad = array<i64 : 0 , 1 , 0 , 1 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x7 x7 x9 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x7 x7 x9 xf8 E4 M3 FN>
938
+ return %0 : tensor <1 x7 x7 x9 xf8 E4 M3 FN>
939
+ }
940
+
941
+ // -----
942
+ // CHECK-LABEL: conv2d_f8E4M3FN
943
+ // Runs tosa.conv2d on f8E4M3FN %arg0/%arg1 with an f16 third operand and an f16 result (acc_type = f16).
+ // %input_zp and %weight_zp are zero-valued 1-element f8E4M3FN constants; the op also sets local_bound = true.
+ func.func @test_conv2d_f8E4M3FN (%arg0: tensor <1 x4 x4 x4 xf8 E4 M3 FN>, %arg1: tensor <8 x1 x1 x4 xf8 E4 M3 FN>, %arg2: tensor <8 xf16 >) -> tensor <1 x4 x4 x8 xf16 > {
944
+ %input_zp = " tosa.const" () <{value = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
945
+ %weight_zp = " tosa.const" () <{value = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
946
+ %0 = tosa.conv2d %arg0 , %arg1 , %arg2 , %input_zp , %weight_zp {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, local_bound = true } : (tensor <1 x4 x4 x4 xf8 E4 M3 FN>, tensor <8 x1 x1 x4 xf8 E4 M3 FN>, tensor <8 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 >
947
+ return %0 : tensor <1 x4 x4 x8 xf16 >
948
+ }
949
+
950
+ // -----
951
+ // CHECK-LABEL: conv3d_f8E4M3FN
952
+ // Runs tosa.conv3d on f8E4M3FN %arg0/%arg1 with an f16 third operand and an f16 result (acc_type = f16).
+ // The two trailing 1-element f8E4M3FN operands (%arg3, %arg4) are taken from the function arguments.
+ func.func @test_conv3d_f8E4M3FN (%arg0: tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, %arg1: tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, %arg2: tensor <34 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 > {
953
+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, tensor <34 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 >
954
+ return %0 : tensor <1 x4 x8 x21 x34 xf16 >
955
+ }
956
+
957
+ // -----
958
+ // CHECK-LABEL: depthwise_conv2d_f8E4M3FN
959
+ // Runs tosa.depthwise_conv2d on f8E4M3FN %arg0/%arg1 with an f16 third operand and an f16 result (acc_type = f16).
+ // The two trailing 1-element f8E4M3FN operands (%arg3, %arg4) are taken from the function arguments.
+ func.func @test_depthwise_conv2d_f8E4M3FN (%arg0: tensor <1 x4 x4 x4 xf8 E4 M3 FN>, %arg1: tensor <1 x1 x4 x2 xf8 E4 M3 FN>, %arg2: tensor <8 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 > {
960
+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x4 x4 x4 xf8 E4 M3 FN>, tensor <1 x1 x4 x2 xf8 E4 M3 FN>, tensor <8 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 >
961
+ return %0 : tensor <1 x4 x4 x8 xf16 >
962
+ }
963
+
964
+ // -----
965
+ // CHECK-LABEL: matmul_f8E4M3FN
966
+ // Applies tosa.matmul to 1x14x19 and 1x19x28 f8E4M3FN tensors, producing a 1x14x28 f16 result.
+ func.func @test_matmul_f8E4M3FN (%arg0: tensor <1 x14 x19 xf8 E4 M3 FN>, %arg1: tensor <1 x19 x28 xf8 E4 M3 FN>) -> tensor <1 x14 x28 xf16 > {
967
+ %0 = tosa.matmul %arg0 , %arg1 : (tensor <1 x14 x19 xf8 E4 M3 FN>, tensor <1 x19 x28 xf8 E4 M3 FN>) -> tensor <1 x14 x28 xf16 >
968
+ return %0 : tensor <1 x14 x28 xf16 >
969
+ }
970
+
971
+ // -----
972
+ // CHECK-LABEL: max_pool2d_f8E4M3FN
973
+ // Applies tosa.max_pool2d (1x1 kernel, stride 1x1, zero padding) to a 1x32x32x8 f8E4M3FN tensor; the result type is unchanged.
+ func.func @test_max_pool2d_f8E4M3FN (%arg0: tensor <1 x32 x32 x8 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x8 xf8 E4 M3 FN> {
974
+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x8 xf8 E4 M3 FN>
975
+ return %0 : tensor <1 x32 x32 x8 xf8 E4 M3 FN>
976
+ }
977
+
978
+ // -----
979
+ // CHECK-LABEL: transpose_conv2d_f8E4M3FN
980
+ // Runs tosa.transpose_conv2d on f8E4M3FN %arg0/%arg1 with an f16 third operand and an f16 result
+ // (acc_type = f16, zero out_pad, stride 1x1). The trailing 1-element f8E4M3FN operands come from %arg3/%arg4.
+ func.func @test_transpose_conv2d_f8E4M3FN (%arg0: tensor <1 x32 x32 x8 xf8 E4 M3 FN>, %arg1: tensor <16 x1 x1 x8 xf8 E4 M3 FN>, %arg2: tensor <16 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x16 xf16 > {
981
+ %0 = tosa.transpose_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , out_pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E4 M3 FN>, tensor <16 x1 x1 x8 xf8 E4 M3 FN>, tensor <16 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x16 xf16 >
982
+ return %0 : tensor <1 x32 x32 x16 xf16 >
983
+ }
984
+
985
+ // -----
986
+ // CHECK-LABEL: const_f8E4M3FN
987
+ // Materializes a 4-element f8E4M3FN tosa.const holding [3.0, -0.0, -1.0, 2.0] and returns it.
+ // %arg0 (index) is declared but not referenced in the body.
+ func.func @test_const_f8E4M3FN (%arg0 : index ) -> tensor <4 xf8 E4 M3 FN> {
988
+ %0 = " tosa.const" () {value = dense <[3.0 , -0.0 , -1.0 , 2.0 ]> : tensor <4 xf8 E4 M3 FN>} : () -> tensor <4 xf8 E4 M3 FN>
989
+ return %0 : tensor <4 xf8 E4 M3 FN>
990
+ }
991
+
992
+ // -----
993
+ // CHECK-LABEL: cast_f8E4M3FN
994
+ // Converts a 13x21x3 f8E4M3FN tensor to f16 with tosa.cast.
+ func.func @test_cast_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf16 > {
995
+ %0 = tosa.cast %arg0 : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf16 >
996
+ return %0 : tensor <13 x21 x3 xf16 >
997
+ }
998
+
999
+ // -----
1000
+ // CHECK-LABEL: concat_f8E4M3FN
1001
+ // Concatenates two 13x21x3 f8E4M3FN tensors with tosa.concat on axis 0, giving a 26x21x3 result.
+ func.func @test_concat_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <26 x21 x3 xf8 E4 M3 FN> {
1002
+ %0 = tosa.concat %arg0 , %arg1 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <26 x21 x3 xf8 E4 M3 FN>
1003
+ return %0 : tensor <26 x21 x3 xf8 E4 M3 FN>
1004
+ }
1005
+
1006
+ // -----
1007
+ // CHECK-LABEL: pad_f8E4M3FN
1008
+ // Calls tosa.pad with an all-zero !tosa.shape<6> padding and a 1-element f8E4M3FN constant of -0.0;
+ // the result type keeps the 13x21x3 input shape.
+ func.func @test_pad_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1009
+ %padding = tosa.const_shape {value = dense <0 > : tensor <6 xindex >} : () -> !tosa.shape <6 >
1010
+ %cst = " tosa.const" () { value = dense <-0.0 > : tensor <1 xf8 E4 M3 FN> } : () -> tensor <1 xf8 E4 M3 FN>
1011
+ %0 = tosa.pad %arg0 , %padding , %cst : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <6 >, tensor <1 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1012
+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1013
+ }
1014
+
1015
+ // -----
1016
+ // CHECK-LABEL: reshape_f8E4M3FN
1017
+ // Reshapes a 13x21x3 f8E4M3FN tensor to 1x819 with tosa.reshape, using a constant !tosa.shape<2> of [1, 819].
+ func.func @test_reshape_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <1 x819 xf8 E4 M3 FN> {
1018
+ %1 = tosa.const_shape {value = dense <[1 , 819 ]> : tensor <2 xindex >} : () -> !tosa.shape <2 >
1019
+ %0 = tosa.reshape %arg0 , %1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <2 >) -> tensor <1 x819 xf8 E4 M3 FN>
1020
+ return %0 : tensor <1 x819 xf8 E4 M3 FN>
1021
+ }
1022
+
1023
+ // -----
1024
+ // CHECK-LABEL: reverse_f8E4M3FN
1025
+ // Applies tosa.reverse on axis 0 to a 13x21x3 f8E4M3FN tensor; the result type is unchanged.
+ func.func @test_reverse_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1026
+ %0 = tosa.reverse %arg0 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1027
+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1028
+ }
1029
+
1030
+ // -----
1031
+ // CHECK-LABEL: slice_f8E4M3FN
1032
+ // Calls tosa.slice on a 13x21x3 f8E4M3FN tensor with two constant !tosa.shape<3> operands,
+ // [4, 11, 1] then [6, 8, 0], and a 4x11x1 result type.
+ // NOTE(review): the result shape equals the first shape operand; confirm the operand order
+ // (start vs. size) against the tosa.slice definition.
+ func.func @test_slice_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <4 x11 x1 xf8 E4 M3 FN> {
1033
+ %0 = tosa.const_shape {value = dense <[4 , 11 , 1 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
1034
+ %1 = tosa.const_shape {value = dense <[6 , 8 , 0 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
1035
+ %2 = tosa.slice %arg0 , %0 , %1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <3 >, !tosa.shape <3 >) -> tensor <4 x11 x1 xf8 E4 M3 FN>
1036
+ return %2 : tensor <4 x11 x1 xf8 E4 M3 FN>
1037
+ }
1038
+
1039
+ // -----
1040
+ // CHECK-LABEL: tile_f8E4M3FN
1041
+ // Applies tosa.tile with a constant !tosa.shape<3> of [3, 1, 2] to a 13x21x3 f8E4M3FN tensor, giving 39x21x6.
+ func.func @test_tile_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <39 x21 x6 xf8 E4 M3 FN> {
1042
+ %cst = tosa.const_shape { value = dense <[3 , 1 , 2 ]> : tensor <3 xindex > } : () -> !tosa.shape <3 >
1043
+ %0 = tosa.tile %arg0 , %cst: (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <3 >) -> tensor <39 x21 x6 xf8 E4 M3 FN>
1044
+ return %0 : tensor <39 x21 x6 xf8 E4 M3 FN>
1045
+ }
1046
+
1047
+ // -----
1048
+ // CHECK-LABEL: transpose_f8E4M3FN
1049
+ // Applies tosa.transpose with perms = [2, 0, 1] to a 13x21x3 f8E4M3FN tensor, giving a 3x13x21 result.
+ func.func @test_transpose_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <3 x13 x21 xf8 E4 M3 FN> {
1050
+ %1 = tosa.transpose %arg0 {perms = array<i32 : 2 , 0 , 1 >} : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <3 x13 x21 xf8 E4 M3 FN>
1051
+ return %1 : tensor <3 x13 x21 xf8 E4 M3 FN>
1052
+ }
1053
+
1054
+ // -----
1055
+ // CHECK-LABEL: gather_f8E4M3FN
1056
+ // Applies tosa.gather to a 13x21x3 f8E4M3FN tensor and a 13x26 i32 tensor, giving a 13x26x3 f8E4M3FN result.
+ func.func @test_gather_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E4 M3 FN> {
1057
+ %0 = tosa.gather %arg0 , %arg1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E4 M3 FN>
1058
+ return %0 : tensor <13 x26 x3 xf8 E4 M3 FN>
1059
+ }
1060
+
1061
+ // -----
1062
+ // CHECK-LABEL: scatter_f8E4M3FN
1063
+ // Applies tosa.scatter to a 13x21x3 f8E4M3FN tensor, a 13x26 i32 tensor and a 13x26x3 f8E4M3FN tensor;
+ // the result has the same 13x21x3 f8E4M3FN type as the first operand.
+ func.func @test_scatter_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x26 xi32 >, %arg2: tensor <13 x26 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1064
+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x26 xi32 >, tensor <13 x26 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1065
+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1066
+ }
0 commit comments