@@ -982,6 +982,7 @@ class tinyBLAS_Q0_B16_AVX {
982
982
}
983
983
984
984
#if defined(__AVX512BF16__)
985
+ // Templated functions for gemm of dimensions 4xN
985
986
template <int RN>
986
987
NOINLINE void gemm4xN (int64_t m0, int64_t m, int64_t n0, int64_t n) {
987
988
int64_t ytiles = (m - m0) / 4 ;
@@ -1006,6 +1007,7 @@ class tinyBLAS_Q0_B16_AVX {
1006
1007
__m256i avec3 = load (A + lda * (ii + 3 ) + l);
1007
1008
for (int64_t j = 0 ; j < RN; ++j) {
1008
1009
__m128bh db = m128bh (_mm_set1_epi16 (B[ldb * (jj + j) + l].d ));
1010
+ // Computation of product of delta values for four blocks
1009
1011
__m256 dvec = _mm256_castps128_ps256 (_mm_dpbf16_ps (zerovec, da, db));
1010
1012
dvec = _mm256_permute2f128_ps (dvec ,dvec, 0 );
1011
1013
Cv[j][0 ] = madd (_mm256_shuffle_ps (dvec, dvec, 0 ),
@@ -1057,7 +1059,8 @@ class tinyBLAS_Q0_B16_AVX {
1057
1059
__m256i bvec3 = load (B + ldb * (jj + 3 ) + l);
1058
1060
for (int64_t i = 0 ; i < RM; ++i) {
1059
1061
__m128bh da = m128bh (_mm_set1_epi16 ((A[lda * (ii + i) + l].d )));
1060
- __m256 dvec = _mm256_castps128_ps256 (_mm_dpbf16_ps (zerovec, da, db));
1062
+ // Computation of product of delta values for four blocks
1063
+ __m256 dvec = _mm256_castps128_ps256 (_mm_dpbf16_ps (zerovec, da, db));
1061
1064
dvec = _mm256_permute2f128_ps (dvec ,dvec, 0 );
1062
1065
Cv[0 ][i] = madd (_mm256_shuffle_ps (dvec, dvec, 0 ),
1063
1066
updot (_mm256_sign_epi8 (load (A + lda * (ii + i) + l),
0 commit comments