Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src/cpu/gemm')
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp | 372
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp | 72
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp | 2131
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp | 36
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp | 2705
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp | 37
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp | 346
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp | 36
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp | 280
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp | 58
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp | 86
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp | 206
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp | 28
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp | 1409
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp | 38
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp | 539
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp | 101
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp | 290
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp | 411
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp | 64
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp | 819
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp | 2209
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp | 564
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp | 501
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp | 1283
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp | 3163
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp | 821
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp | 647
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp | 116
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp | 38
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp | 180
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp | 37
32 files changed, 19623 insertions(+), 0 deletions(-)
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp
new file mode 100644
index 0000000000..a9810dec28
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp
@@ -0,0 +1,372 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+#include <cmath>
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+#include "gemm_utils_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+namespace gemm_utils {
+#define BM_NOCOPY_AVX 64
+#define BN_NOCOPY_AVX 48
+#define BK_NOCOPY_AVX 384
+#define BN_LARGE_NOCOPY_AVX 192
+#define BM_SMALL_NOCOPY_AVX 16
+#define BN_SMALL_NOCOPY_AVX 1
+#define BK_SMALL_NOCOPY_AVX 4
+// Determine the number of threads to use along each dimension of a 3-D
+// partitioning algorithm, based on the input parameters.
+// m/n/k - GEMM problem dimensions (C is m x n, reduced over k)
+// nthrs - total number of available threads
+// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
+// BM/BN/BK - resulting blocking values
+void calc_nthr_nocopy_avx(int m, int n, int k,
+ int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
+ int *BK)
+{
+ int nthr, nthr_m, nthr_n, nthr_k;
+ int MB, NB, KB;
+
+ nthr = nthrs;
+ nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX;
+ nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX;
+ nthr_k = 1;
+
+ // Partition along K dimension
+ // - if threading allows having barriers (e.g. OMP)
+ // - if there is not enough parallelism along M or N
+ if (mkldnn_thr_syncable()) {
+ int nthr_other = nthr_k = 1;
+ while ((nthr_m * nthr_n * nthr_other < nthr)
+ && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) {
+ nthr_other++;
+ if ((nthr / nthr_other) * nthr_other > 0.9 * nthr)
+ nthr_k = nthr_other;
+ }
+ }
+ nthr /= nthr_k;
+
+ if (nthr_m == 1)
+ nthr_n = nthr;
+ if (nthr_n == 1)
+ nthr_m = nthr;
+
+ // Simple partition reduction
+ while (nthr_m * nthr_n > nthr)
+ if (nthr_m > nthr_n)
+ nthr_m--;
+ else
+ nthr_n--;
+ while (nthr_m * nthr_n < nthr)
+ if (nthr_m < nthr_n)
+ nthr_m++;
+ else
+ nthr_n++;
+
+ if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) {
+
+ if (nthr_m <= nthr_n) {
+ nthr_m = (int)sqrt((double)nthr);
+ if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX)
+ nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX;
+ nthr_n = nthr / nthr_m;
+
+ while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
+ nthr_m--;
+ nthr_n = nthr / nthr_m;
+ }
+ } else {
+ nthr_n = (int)sqrt((double)nthr);
+ if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX)
+ nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX;
+ nthr_m = nthr / nthr_n;
+
+ while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
+ nthr_n--;
+ nthr_m = nthr / nthr_n;
+ }
+ }
+ }
+
+ MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1;
+ MB -= MB % BM_SMALL_NOCOPY_AVX;
+ NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1;
+ NB -= NB % BN_SMALL_NOCOPY_AVX;
+ KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1;
+ KB -= KB % BK_SMALL_NOCOPY_AVX;
+
+ if (MB * nthr_m > m)
+ nthr_m = (m + MB - 1) / MB;
+ if (NB * nthr_n > n)
+ nthr_n = (n + NB - 1) / NB;
+ if (KB * nthr_k > k)
+ nthr_k = (k + KB - 1) / KB;
+
+ *nthrs_m = nthr_m;
+ *nthrs_n = nthr_n;
+ *nthrs_k = nthr_k;
+
+ *BM = MB;
+ *BN = NB;
+ *BK = KB;
+}
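+// Worked example: for m = 64, n = 48, k = 4000 and nthrs = 8 on a syncable
+// runtime, a single block covers each of M and N (nthr_m = nthr_n = 1), so
+// the loop above assigns all eight threads to the K dimension: nthr_k = 8
+// and KB = 500. Each thread then computes a partial C for its K slice,
+// which the caller reduces afterwards.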
+#undef BM_NOCOPY_AVX
+#undef BN_NOCOPY_AVX
+#undef BK_NOCOPY_AVX
+#undef BN_LARGE_NOCOPY_AVX
+#undef BM_SMALL_NOCOPY_AVX
+#undef BN_SMALL_NOCOPY_AVX
+#undef BK_SMALL_NOCOPY_AVX
+
+#define BM_NOCOPY_AVX512_COMMON 32
+#define BN_NOCOPY_AVX512_COMMON 64
+#define BK_NOCOPY_AVX512_COMMON 192
+#define BN_LARGE_NOCOPY_AVX512_COMMON 192
+#define BM_SMALL_NOCOPY_AVX512_COMMON 16
+#define BN_SMALL_NOCOPY_AVX512_COMMON 1
+#define BK_SMALL_NOCOPY_AVX512_COMMON 4
+// Determine the number of threads to use along each dimension of a 3-D
+// partitioning algorithm, based on the input parameters.
+// m/n/k - GEMM problem dimensions (C is m x n, reduced over k)
+// nthrs - total number of available threads
+// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
+// BM/BN/BK - resulting blocking values
+void calc_nthr_nocopy_avx512_common(int m,
+ int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
+ int *BM, int *BN, int *BK)
+{
+ int nthr, nthr_m, nthr_n, nthr_k = 1;
+ int MB, NB, KB;
+ nthr = nthrs;
+
+ int counter = 0;
+ float ratio_float = 1.;
+ int ratio = 1;
+ int nthr_m_gt_n;
+
+ // Partition along K dimension
+ // - if threading allows having barriers (e.g. OMP)
+ // - if there is not enough parallelism along M or N
+ if (mkldnn_thr_syncable()) {
+ if (n <= 2 * BN_NOCOPY_AVX512_COMMON &&
+ m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) {
+ nthr_k = k / BK_NOCOPY_AVX512_COMMON;
+ if (nthr_k > nthr / 4)
+ nthr_k = nthr / 4;
+ if (nthr_k < 1)
+ nthr_k = 1;
+
+ while ((nthr_k > 1) && (nthr % nthr_k)) {
+ nthr_k--;
+ }
+ nthr /= nthr_k;
+ } else {
+ nthr_k = 1;
+ }
+ }
+ nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON;
+ nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON;
+
+ if (nthr_m < 1)
+ nthr_m = 1;
+ if (nthr_n < 1)
+ nthr_n = 1;
+
+ nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0;
+ ratio_float = (float)nthr_m / nthr_n;
+
+ if (nthr_m_gt_n)
+ ratio = (int)ratio_float;
+ else
+ ratio = (int)(1. / ratio_float);
+
+ // scale down nthr_m and nthr_n if they are too large
+ while (nthr_m * nthr_n > 4 * nthr) {
+ nthr_m /= 2;
+ nthr_n /= 2;
+ }
+
+ if (nthr_m < 1)
+ nthr_m = 1;
+ if (nthr_n < 1)
+ nthr_n = 1;
+
+ // Simple partition reduction
+ counter = 0;
+ while (nthr_m * nthr_n > nthr) {
+ if (nthr_m > nthr_n) {
+ if (counter < ratio)
+ nthr_m--;
+ else {
+ nthr_n--;
+ counter = -1;
+ }
+ } else {
+ if (counter < ratio)
+ nthr_n--;
+ else {
+ nthr_m--;
+ counter = -1;
+ }
+ }
+ counter++;
+ }
+
+ // Simple partition increment
+ counter = 0;
+ while (nthr_m * nthr_n < 0.95 * nthr) {
+ if (nthr_m > nthr_n) {
+ if (counter < ratio)
+ nthr_m++;
+ else {
+ nthr_n++;
+ counter = -1;
+ }
+ } else {
+ if (counter < ratio)
+ nthr_n++;
+ else {
+ nthr_m++;
+ counter = -1;
+ }
+ }
+ counter++;
+ }
+
+ // if nothing works out, then this should work
+ if ((nthr_m * nthr_n > nthr)) {
+
+ if (nthr_m <= nthr_n) {
+ nthr_m = (int)sqrt((double)nthr);
+ if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
+ / BM_SMALL_NOCOPY_AVX512_COMMON)
+ nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
+ / BM_SMALL_NOCOPY_AVX512_COMMON;
+ nthr_n = nthr / nthr_m;
+
+ while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
+ nthr_m--;
+ nthr_n = nthr / nthr_m;
+ }
+ } else {
+ nthr_n = (int)sqrt((double)nthr);
+ if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
+ / BN_SMALL_NOCOPY_AVX512_COMMON)
+ nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
+ / BN_SMALL_NOCOPY_AVX512_COMMON;
+ nthr_m = nthr / nthr_n;
+
+ while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
+ nthr_n--;
+ nthr_m = nthr / nthr_n;
+ }
+ }
+ }
+
+ MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1;
+ MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON;
+ NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1;
+ NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON;
+ KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1;
+ KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON;
+
+ if (MB * nthr_m > m)
+ nthr_m = (m + MB - 1) / MB;
+ if (NB * nthr_n > n)
+ nthr_n = (n + NB - 1) / NB;
+ if (KB * nthr_k > k)
+ nthr_k = (k + KB - 1) / KB;
+
+ *nthrs_m = nthr_m;
+ *nthrs_n = nthr_n;
+ *nthrs_k = nthr_k;
+
+ *BM = MB;
+ *BN = NB;
+ *BK = KB;
+}
+#undef BM_NOCOPY_AVX512_COMMON
+#undef BN_NOCOPY_AVX512_COMMON
+#undef BK_NOCOPY_AVX512_COMMON
+#undef BN_LARGE_NOCOPY_AVX512_COMMON
+#undef BM_SMALL_NOCOPY_AVX512_COMMON
+#undef BN_SMALL_NOCOPY_AVX512_COMMON
+#undef BK_SMALL_NOCOPY_AVX512_COMMON
+
+// Partition n values as equally as possible among nthr threads
+// and set the offset (t_offset) and number of values (t_block) for ithr
+// Assumption: 0 <= ithr < nthr
+void partition_unit_diff(
+ int ithr, int nthr, int n, int *t_offset, int *t_block)
+{
+ int band = n / nthr;
+ if (band == 0)
+ band = 1;
+ int tail = n - band * nthr;
+ if (tail < 0)
+ tail = 0;
+
+ if (ithr < tail) {
+ band++;
+ *t_offset = band * ithr;
+ *t_block = band;
+ } else {
+ *t_offset = band * ithr + tail;
+ *t_block = band;
+ }
+
+ if (*t_offset >= n) {
+ *t_offset = 0;
+ *t_block = 0;
+ }
+
+ if (*t_offset + *t_block > n) {
+ *t_block = n - *t_offset;
+ }
+}
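+// Worked example: n = 10, nthr = 4 gives band = 2 and tail = 2, so threads
+// 0..3 receive t_block = {3, 3, 2, 2} values at t_offset = {0, 3, 6, 8}.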
+
+// Sum the m*n values from p_src into p_dst, assuming the two-dimensional
+// arrays have leading dimensions ld_src and ld_dst, respectively
+template<typename data_t>
+void sum_two_matrices(int m, int n,
+ data_t * __restrict p_src, dim_t ld_src,
+ data_t * __restrict p_dst, dim_t ld_dst)
+{
+ int i, j;
+ for (j = 0; j < n; j++) {
+ for (i = 0; i < m; i++) {
+ p_dst[i + j * ld_dst] += p_src[i + j * ld_src];
+ }
+ }
+}
+
+template
+void sum_two_matrices<float>(int m, int n,
+ float * __restrict p_src, dim_t ld_src,
+ float * __restrict p_dst, dim_t ld_dst);
+
+template
+void sum_two_matrices<double>(int m, int n,
+ double * __restrict p_src, dim_t ld_src,
+ double * __restrict p_dst, dim_t ld_dst);
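+// Usage sketch (variable names are illustrative): with K-dimension
+// threading, each extra thread accumulates into a private MB x NB buffer
+// that is later folded into the real C tile, e.g.
+//   sum_two_matrices(myM, myN, c_partial, (dim_t)MB, myC, ldc);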
+} // namespace gemm_utils
+} // namespace cpu
+} // namespace impl
+} // namespace mkldnn
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp
new file mode 100644
index 0000000000..3352298b4a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp
@@ -0,0 +1,72 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GEMM_UTILS_HPP
+#define GEMM_UTILS_HPP
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace gemm_utils {
+// Alias for any dimension related variable.
+typedef ptrdiff_t dim_t;
+
+template <typename T, bool isTransA, bool isTransB>
+struct gemm_traits {};
+
+template <bool isTransA, bool isTransB>
+struct gemm_traits<double, isTransA, isTransB> {
+ static constexpr int m = 8;
+ static constexpr int n = 6;
+ static constexpr int BM = 4032;
+ static constexpr int BN = isTransA ? 96 : 192;
+ static constexpr int BK = isTransB ? 96 : 512;
+};
+
+template <bool isTransA, bool isTransB>
+struct gemm_traits<float, isTransA, isTransB> {
+ static constexpr int m = 16;
+ static constexpr int n = 6;
+ static constexpr int BM = 4032;
+ static constexpr int BN = isTransA ? 96 : 48;
+ static constexpr int BK = isTransB ? 96 : 256;
+};
+
+template <typename T>
+using unroll_factor = gemm_traits<T, false, false>;
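+// For example, unroll_factor<float>::m == 16 and unroll_factor<float>::n == 6,
+// i.e. the non-transposed f32 register block covers 16 x 6 elements of C.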
+
+template <typename data_t>
+void sum_two_matrices(int m, int n,
+ data_t * __restrict p_src, dim_t ld_src,
+ data_t * __restrict p_dst, dim_t ld_dst);
+
+void calc_nthr_nocopy_avx512_common(int m,
+ int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
+ int *BM, int *BN, int *BK);
+
+void calc_nthr_nocopy_avx(int m, int n, int k,
+ int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
+ int *BK);
+
+void partition_unit_diff(
+ int ithr, int nthr, int n, int *t_offset, int *t_block);
+} // namespace gemm_utils
+
+} // namespace cpu
+} // namespace impl
+} // namespace mkldnn
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp
new file mode 100644
index 0000000000..d7be43e392
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp
@@ -0,0 +1,2131 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <cmath>
+#include <mutex>
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "ref_gemm_f32.hpp"
+#include "gemm_utils_f32.hpp"
+#include "jit_avx512_common_gemm_f32.hpp"
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+#define CACHE_LINE_SIZE 64
+
+#define STACKSIZE get_size_of_abi_save_regs()
+#ifdef _WIN32
+#define STACK_K_CAPACITY 32
+#else
+#define STACK_K_CAPACITY 2048
+#endif
+#define SIZE 4
+#define OFFSET 128
+#define BASE_SHIFT 2
+#define SECOND_FETCH unroll_n
+#define UNROLL_M 48
+#define UNROLL_N 8
+
+namespace avx512_common_gemm_f32 {
+using namespace gemm_utils;
+
+struct xbyak_gemm : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm)
+
+ xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
+ void *code_ptr = nullptr,
+ size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
+ : jit_generator(code_ptr, code_size)
+ {
+ using namespace Xbyak;
+
+ enum { ver_avx512_core, ver_avx512_mic } ver =
+ mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic;
+
+ bool isBeta0 = (beta == 0.0);
+ bool isBetaN = (!isBeta0 && beta != 1.0);
+
+ // various definitions for convenience
+ auto ARG_M = abi_param1;
+ auto ARG_N = abi_param2;
+ auto K = abi_param3;
+ auto ARG_ALPHA = abi_param4;
+#ifdef _WIN32
+ auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
+ auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
+ sizeof(float *) + STACKSIZE];
+ const auto stackOffset = OFFSET_SHADOWSPACE +
+ sizeof(float *) + STACKSIZE;
+ auto A = rsi;
+ auto LDA = rdi;
+#else
+ auto ARG_A = r8;
+ auto ARG_LDA = r9;
+ const auto stackOffset = STACKSIZE;
+ auto A = ARG_A;
+ auto LDA = ARG_LDA;
+#endif
+ auto ARG_B = ptr[rsp + 8 + stackOffset];
+ auto ARG_LDB = ptr[rsp + 16 + stackOffset];
+ auto ARG_BETA = ptr[rsp + 24 + stackOffset];
+ auto ARG_C = ptr[rsp + 32 + stackOffset];
+ auto ARG_LDC = ptr[rsp + 40 + stackOffset];
+ auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
+ auto ARG_WS = ptr[rsp + 56 + stackOffset];
+
+ auto B = r11;
+ auto LDB = rbx;
+ auto LDC = r13;
+ auto LL = rax;
+ auto AO1 = abi_param2;
+ auto BO1 = abi_param4;
+ auto BO2 = rbp;
+ auto CO1 = r14;
+ auto CO2 = r15;
+ auto LDB3 = r10;
+ auto LDA4 = abi_param1;
+ auto AA = r12;
+ auto BIAS1 = abi_param1;
+
+ auto M = qword[rsp + 0];
+ auto N = qword[rsp + 8];
+ auto FLAG = qword[rsp + 16];
+ auto I = qword[rsp + 24];
+ auto C = qword[rsp + 32];
+ auto BIAS = qword[rsp + 40];
+ auto ALPHA = qword[rsp + 48];
+ auto BETA = qword[rsp + 64];
+ auto ORIG_A = qword[rsp + 80];
+ auto ORIG_SP = qword[rsp + 120];
+
+ auto ZSTRIDE = zmm4;
+ auto VALPHA = zmm6;
+ auto VBETA = zmm7;
+ auto VBIAS1 = zmm1;
+ auto VBIAS2 = zmm2;
+ auto VBIAS3 = zmm3;
+
+ auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80;
+ auto PREFETCHSIZEB = 16;
+
+ Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15,
+ zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24,
+ zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 };
+
+        // Pack an unroll_m x K panel of A into the on-stack buffer
+        // (used when A is transposed)
+ auto do_pack = [&](int unroll_m) {
+ Label pack2, pack3, pack4, pack10;
+
+ mov(BO1, A);
+ lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
+ mov(LL, K);
+ sar(LL, 2);
+ jle(pack3, T_NEAR);
+ align(16);
+
+ L(pack2);
+ if (!isTransA) {
+ for (int i = 0; i < 4; i++) {
+ vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
+ if (unroll_m > 16)
+ vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
+ if (unroll_m > 32)
+ vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
+ add(BO1, LDA);
+
+ vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]
+ | k1,
+ zmm0);
+ if (unroll_m > 16)
+ vmovups(ptr[AO1
+ + (unroll_m * i + 1 * 16 - OFFSET)
+ * SIZE]
+ | k2,
+ zmm1);
+ if (unroll_m > 32)
+ vmovups(ptr[AO1
+ + (unroll_m * i + 2 * 16 - OFFSET)
+ * SIZE]
+ | k3,
+ zmm2);
+ }
+ } else {
+ for (int i = 0; i < 4; i++) {
+ kmovw(k4, k1);
+ vgatherqps(ymm5 | k4,
+ ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO1 + LDA * 8]);
+ kshiftrw(k4, k1, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ vshuff64x2(zmm0, zmm5, zmm6, 0x44);
+
+ if (unroll_m > 16) {
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kmovw(k4, k2);
+ vgatherqps(ymm5 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kshiftrw(k4, k2, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ vshuff64x2(zmm1, zmm5, zmm6, 0x44);
+ }
+
+ if (unroll_m > 32) {
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kmovw(k4, k3);
+ vgatherqps(ymm5 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kshiftrw(k4, k3, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ vshuff64x2(zmm2, zmm5, zmm6, 0x44);
+ }
+
+ vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE],
+ zmm0 | k1);
+ if (unroll_m > 16)
+ vmovups(ptr[AO1
+ + (unroll_m * i + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1 | k2);
+ if (unroll_m > 32)
+ vmovups(ptr[AO1
+ + (unroll_m * i + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2 | k3);
+ }
+ add(BO1, 4 * SIZE);
+ }
+ add(AO1, unroll_m * 4 * SIZE);
+
+ sub(LL, 1);
+ jg(pack2, T_NEAR);
+ align(16);
+
+ L(pack3);
+ mov(LL, K);
+ and_(LL, 3);
+ jle(pack10, T_NEAR);
+ align(16);
+
+ L(pack4);
+ if (!isTransA) {
+ vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
+ if (unroll_m > 16)
+ vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
+ if (unroll_m > 32)
+ vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
+ add(BO1, LDA);
+ } else {
+ kmovw(k4, k1);
+ vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO1 + LDA * 8]);
+ kshiftrw(k4, k1, 8);
+ vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ vshuff64x2(zmm0, zmm5, zmm6, 0x44);
+
+ if (unroll_m > 16) {
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kmovw(k4, k2);
+ vgatherqps(ymm5 | k4,
+ ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kshiftrw(k4, k2, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ vshuff64x2(zmm1, zmm5, zmm6, 0x44);
+ }
+
+ if (unroll_m > 32) {
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kmovw(k4, k3);
+ vgatherqps(ymm5 | k4,
+ ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kshiftrw(k4, k3, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ vshuff64x2(zmm2, zmm5, zmm6, 0x44);
+ }
+ add(BO1, SIZE);
+ }
+
+ vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
+ zmm0 | k1);
+ if (unroll_m > 16)
+ vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE],
+ zmm1 | k2);
+ if (unroll_m > 32)
+ vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE],
+ zmm2 | k3);
+
+ add(AO1, unroll_m * SIZE);
+ sub(LL, 1);
+ jg(pack4, T_NEAR);
+ align(16);
+
+ L(pack10);
+ };
+
+        // Write an accumulator register back to C as C := alpha * reg + beta * C;
+        // mask selects k1/k2/k3 for edge tiles (0 = unmasked), and useScale
+        // addresses the next column of C at CO1/CO2 + LDC
+ auto update = [&](Zmm reg, bool useCO1, int offset, int mask,
+ bool useScale = false) {
+ vmulps(reg, reg, VALPHA);
+ if (!isBeta0) {
+ if (!useScale) {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(zmm0, ptr[CO1 + offset * SIZE]);
+ else
+ vmovups(zmm0, ptr[CO2 + offset * SIZE]);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]);
+ else
+ vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]);
+ else
+ vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]);
+ else
+ vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]);
+ break;
+ }
+ } else {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]);
+ else
+ vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(zmm0 | k1 | T_z,
+ ptr[CO1 + LDC + offset * SIZE]);
+ else
+ vmovups(zmm0 | k1 | T_z,
+ ptr[CO2 + LDC + offset * SIZE]);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(zmm0 | k2 | T_z,
+ ptr[CO1 + LDC + offset * SIZE]);
+ else
+ vmovups(zmm0 | k2 | T_z,
+ ptr[CO2 + LDC + offset * SIZE]);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(zmm0 | k3 | T_z,
+ ptr[CO1 + LDC + offset * SIZE]);
+ else
+ vmovups(zmm0 | k3 | T_z,
+ ptr[CO2 + LDC + offset * SIZE]);
+ break;
+ }
+ }
+ if (!isBetaN) {
+ vaddps(zmm0, reg, zmm0);
+ } else {
+ vfmadd132ps(zmm0, reg, VBETA);
+ }
+ if (!useScale) {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], zmm0);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], zmm0);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3);
+ break;
+ }
+ } else {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3);
+ break;
+ }
+ }
+ } else {
+ if (!useScale) {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], reg);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], reg);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], reg | k1);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], reg | k1);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], reg | k2);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], reg | k2);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], reg | k3);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], reg | k3);
+ break;
+ }
+ } else {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], reg);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], reg);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3);
+ break;
+ }
+ }
+ }
+ vpxorq(reg, reg, reg);
+ };
+
+        // Issue the FMAs for B columns 2 .. unroll_n-1 (columns 0 and 1 are
+        // handled in innerkernel, which calls this)
+ auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) {
+ for (int i = 2; i < unroll_n; i++) {
+ if (ver == ver_avx512_core) {
+ if (!isTransB) {
+ switch (i) {
+ case 2:
+ vbroadcastss(
+ zmm3,
+ ptr[BO1 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 3:
+ vbroadcastss(
+ zmm3,
+ ptr[BO1 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 4:
+ vbroadcastss(zmm3,
+ ptr[BO2 + (iteration - OFFSET) * SIZE]);
+ break;
+ case 5:
+ vbroadcastss(
+ zmm3,
+ ptr[BO2 + LDB * 1
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 6:
+ vbroadcastss(
+ zmm3,
+ ptr[BO2 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 7:
+ vbroadcastss(
+ zmm3,
+ ptr[BO2 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ }
+ } else {
+ vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
+ }
+ vfmadd231ps(regs[i], zmm3, zmm0);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm3, zmm1);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm3, zmm2);
+ } else {
+ if (!isTransB) {
+ switch (i) {
+ case 2:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO1 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO1 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO1 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 3:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO1 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO1 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO1 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 4:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO2 + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO2 + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO2 + (iteration - OFFSET) * SIZE]);
+ break;
+ case 5:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO2 + LDB * 1
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO2 + LDB * 1
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO2 + LDB * 1
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 6:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO2 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO2 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO2 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 7:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO2 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO2 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO2 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ }
+ } else {
+ vfmadd231ps(
+ regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ }
+ }
+ }
+ };
+
+ // Innerkernel; called by kernel
+ auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect,
+ bool isCopy, bool doCPrefetch, bool isUnmasked = true) {
+ for (int i = 0; i < 8; i++) {
+ if (!isDirect) {
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET)
+ * SIZE]);
+ if (unroll_m >= 32)
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET)
+ * SIZE]);
+ if (unroll_m >= 48)
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]);
+ if (unroll_m >= 32)
+ prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]);
+ if (unroll_m >= 48)
+ prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]);
+ }
+
+ if (!isDirect) {
+ if (i != 0) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0,
+ ptr[AO1
+ + (unroll_m * i + 0 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1
+ + (unroll_m * i + 0 * 16 - OFFSET)
+ * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1
+ + (unroll_m * i + 1 * 16
+ - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1
+ + (unroll_m * i + 1 * 16
+ - OFFSET)
+ * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1
+ + (unroll_m * i + 2 * 16
+ - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1
+ + (unroll_m * i + 2 * 16
+ - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+ } else {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ if (ver == ver_avx512_core) {
+ if (!isTransB) {
+ vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ vfmadd231ps(regs[0], zmm3, zmm0);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[0 + 8], zmm3, zmm1);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[0 + 16], zmm3, zmm2);
+ } else {
+ if (!isTransB) {
+ vfmadd231ps(regs[0], zmm0,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[0 + 8], zmm1,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[0 + 16], zmm2,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ } else {
+ vfmadd231ps(regs[0], zmm0,
+ zword_b[BO1 + (0 - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[0 + 8], zmm1,
+ zword_b[BO1 + (0 - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[0 + 16], zmm2,
+ zword_b[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ }
+
+ if (unroll_n >= i + 1) {
+ if (!isTransB) {
+ switch (i) {
+ case 0:
+ prefetcht0(
+ ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 1:
+ prefetcht0(ptr[BO1 + LDB
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 2:
+ prefetcht0(ptr[BO1 + LDB * 2
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 3:
+ prefetcht0(ptr[BO1 + LDB3
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 4:
+ prefetcht0(
+ ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 5:
+ prefetcht0(ptr[BO2 + LDB
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 6:
+ prefetcht0(ptr[BO2 + LDB * 2
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 7:
+ prefetcht0(ptr[BO2 + LDB3
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ }
+ }
+ }
+
+ if (unroll_n >= 2) {
+ if (ver == ver_avx512_core) {
+ if (!isTransB) {
+ vbroadcastss(zmm3,
+ ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ vfmadd231ps(regs[1], zmm3, zmm0);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[1 + 8], zmm3, zmm1);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[1 + 16], zmm3, zmm2);
+ } else {
+ if (!isTransB) {
+ vfmadd231ps(regs[1], zmm0,
+ zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[1 + 8], zmm1,
+ zword_b[BO1 + LDB * 1
+ + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[1 + 16], zmm2,
+ zword_b[BO1 + LDB * 1
+ + (i - OFFSET) * SIZE]);
+ } else {
+ vfmadd231ps(regs[1], zmm0,
+ zword_b[BO1 + (1 - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[1 + 8], zmm1,
+ zword_b[BO1 + (1 - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[1 + 16], zmm2,
+ zword_b[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ }
+ }
+
+ if (isCopy) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 0 * 16 - OFFSET)
+ * SIZE],
+ zmm0);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 0 * 16 - OFFSET)
+ * SIZE],
+ zmm0 | k1);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1 | k2);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2 | k3);
+ }
+ }
+ if (i == 7)
+ sub(LDA4, -unroll_m * 8 * SIZE);
+ }
+ fmaloop(unroll_m, unroll_n, i);
+
+ if (i == 1) {
+ if (doCPrefetch) {
+ if (ver == ver_avx512_core)
+ prefetchw(ptr[CO2 + 0 * 16 * SIZE]);
+ else
+ prefetcht0(ptr[CO2 + 0 * 16 * SIZE]);
+ }
+ }
+ if (i == 3) {
+ if (doCPrefetch && unroll_m >= 32) {
+ if (ver == ver_avx512_core)
+ prefetchw(ptr[CO2 + 1 * 16 * SIZE]);
+ else
+ prefetcht0(ptr[CO2 + 1 * 16 * SIZE]);
+ }
+ if (!isTransA) {
+ if (ver == ver_avx512_core)
+ prefetcht0(ptr[AA + 16 * 0 * SIZE]);
+ else
+ prefetcht2(ptr[AA + 16 * 0 * SIZE]);
+ }
+ }
+ if (i == 5) {
+ if (doCPrefetch) {
+ if (unroll_m >= 48) {
+ if (ver == ver_avx512_core)
+ prefetchw(ptr[CO2 + 2 * 16 * SIZE]);
+ else
+ prefetcht0(ptr[CO2 + 2 * 16 * SIZE]);
+ }
+ add(CO2, LDC);
+ }
+ if (!isTransA) {
+ if (unroll_m >= 32) {
+ if (ver == ver_avx512_core)
+ prefetcht0(ptr[AA + 16 * 1 * SIZE]);
+ else
+ prefetcht2(ptr[AA + 16 * 1 * SIZE]);
+ }
+ }
+ }
+
+ if (isTransB) {
+ prefetcht0(ptr[BO1 + BO2]);
+ add(BO1, LDB);
+ }
+ } // end of for loop
+
+ if (!isTransB) {
+ sub(BO1, -8 * SIZE);
+ if (unroll_n >= 4)
+ sub(BO2, -8 * SIZE);
+ }
+ if (!isTransA) {
+ if (unroll_m >= 48) {
+ if (ver == ver_avx512_core)
+ prefetcht0(ptr[AA + 16 * 2 * SIZE]);
+ else
+ prefetcht2(ptr[AA + 16 * 2 * SIZE]);
+ }
+ lea(AA, ptr[AA + LDA]);
+ }
+
+ if (!isDirect) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0,
+ ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1
+ + (unroll_m * 8 + 1 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1
+ + (unroll_m * 8 + 1 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1
+ + (unroll_m * 8 + 2 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1
+ + (unroll_m * 8 + 2 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ sub(AO1, -unroll_m * 8 * SIZE);
+ }
+
+ sub(LL, 1);
+ };
+
+ // Main kernel; does prefetching and calls innerkernel
+ // After calculating results in registers, writes back to C matrix by
+ // calling update
+ auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
+ bool isCopy, bool isUnmasked = true) {
+ if (!isDirect) {
+ lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
+ } else {
+ mov(AO1, A);
+ }
+
+ if (isCopy) {
+ lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]);
+ } else {
+ auto step = ver == ver_avx512_core ? 2 : 4;
+ lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]);
+ }
+
+ if (isTransB) {
+ lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]);
+ }
+
+ if (!isDirect) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0,
+ ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1
+ + (unroll_m * 0 + 1 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1
+ + (unroll_m * 0 + 1 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1
+ + (unroll_m * 0 + 2 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1
+ + (unroll_m * 0 + 2 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+
+ Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18;
+
+ mov(LL, K);
+ sar(LL, 3);
+ sub(LL, SECOND_FETCH);
+ jle(kernel13, T_NEAR);
+ align(16);
+
+ L(kernel12);
+ innerkernel(
+ unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked);
+ jg(kernel12, T_NEAR);
+ align(16);
+
+ L(kernel13);
+ lea(CO2, ptr[CO1 + (16 - 1) * SIZE]);
+ add(LL, unroll_n);
+ jle(kernel15, T_NEAR);
+ align(16);
+
+ L(kernel14);
+ innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked);
+ jg(kernel14, T_NEAR);
+ align(16);
+
+ L(kernel15);
+ mov(LL, K);
+ and_(LL, 7);
+ jle(kernel18, T_NEAR);
+ align(16);
+
+ L(kernel16);
+ if (isDirect) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ for (int i = 0; i < unroll_n; i++) {
+ if (!isTransB) {
+ switch (i) {
+ case 0:
+ vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ break;
+ case 1:
+ vbroadcastss(
+ zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ break;
+ case 2:
+ vbroadcastss(
+ zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ break;
+ case 3:
+ vbroadcastss(
+ zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]);
+ break;
+ case 4:
+ vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]);
+ break;
+ case 5:
+ vbroadcastss(
+ zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ break;
+ case 6:
+ vbroadcastss(
+ zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ break;
+ case 7:
+ vbroadcastss(
+ zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]);
+ break;
+ }
+ } else {
+ vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
+ }
+ vfmadd231ps(regs[i], zmm3, zmm0);
+ if (unroll_m >= 32) {
+ vfmadd231ps(regs[i + 8], zmm3, zmm1);
+ }
+ if (unroll_m >= 48) {
+ vfmadd231ps(regs[i + 16], zmm3, zmm2);
+ }
+ }
+
+ if (isCopy) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
+ zmm0);
+ } else {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
+ zmm0 | k1);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1 | k2);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2 | k3);
+ }
+ }
+ sub(LDA4, -unroll_m * SIZE);
+ }
+
+ if (!isDirect) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0,
+ ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1
+ + (unroll_m * 1 + 1 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1
+ + (unroll_m * 1 + 1 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1
+ + (unroll_m * 1 + 2 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1
+ + (unroll_m * 1 + 2 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ sub(AO1, -unroll_m * SIZE);
+ }
+
+ if (!isTransB) {
+ sub(BO1, -SIZE);
+ if (unroll_n >= 4) {
+ sub(BO2, -SIZE);
+ }
+ } else {
+ add(BO1, LDB);
+ }
+
+ sub(LL, 1);
+ jg(kernel16, T_NEAR);
+ align(16);
+
+ L(kernel18);
+ vbroadcastss(VALPHA, ALPHA);
+
+ if (isBetaN) {
+ vbroadcastss(VBETA, BETA);
+ }
+
+ // Write back the results; all beta cases need to be handled
+ if (hasBias) {
+ mov(BIAS1, BIAS);
+ if (isUnmasked || unroll_m > 16)
+ vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
+ else
+ vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]);
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32)
+ vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]);
+ else
+ vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]);
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked)
+ vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]);
+ else
+ vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]);
+ }
+ }
+
+ for (int i = 0; i < unroll_n; i++) {
+ bool useScale = i % 2 != 0;
+ bool useCO1 = i < 2;
+ if (i == 2)
+ lea(CO2, ptr[CO1 + LDC * 2]);
+ if (i == 4 || i == 6)
+ lea(CO2, ptr[CO2 + LDC * 2]);
+ if (hasBias)
+ vaddps(regs[i], VBIAS1, regs[i]);
+ if (isUnmasked || unroll_m > 16) {
+ update(regs[i], useCO1, 0, 0, useScale);
+ } else {
+ update(regs[i], useCO1, 0, 1, useScale);
+ }
+ if (unroll_m >= 32) {
+ if (hasBias)
+ vaddps(regs[i + 8], VBIAS2, regs[i + 8]);
+ if (isUnmasked || unroll_m > 32) {
+ update(regs[i + 8], useCO1, 16, 0, useScale);
+ } else {
+ update(regs[i + 8], useCO1, 16, 2, useScale);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (hasBias)
+ vaddps(regs[i + 16], VBIAS3, regs[i + 16]);
+ if (isUnmasked) {
+ update(regs[i + 16], useCO1, 32, 0, useScale);
+ } else {
+ update(regs[i + 16], useCO1, 32, 3, useScale);
+ }
+ }
+ }
+
+ switch (unroll_n) {
+ case 1: add(CO1, LDC); break;
+ case 2: lea(CO1, ptr[CO1 + LDC * 2]); break;
+ case 3: lea(CO1, ptr[CO2 + LDC * 1]); break;
+ case 4: lea(CO1, ptr[CO2 + LDC * 2]); break;
+ case 5: lea(CO1, ptr[CO2 + LDC * 1]); break;
+ case 6: lea(CO1, ptr[CO2 + LDC * 2]); break;
+ case 7: lea(CO1, ptr[CO2 + LDC * 1]); break;
+ case 8: lea(CO1, ptr[CO2 + LDC * 2]); break;
+ }
+
+ // Compute next address of B
+ if (!isTransB) {
+ lea(rax, ptr[K * SIZE]);
+ switch (unroll_n) {
+ case 1:
+ add(BO1, LDB);
+ add(BO2, LDB);
+ break;
+ case 2:
+ lea(BO1, ptr[BO1 + LDB * 2]);
+ lea(BO2, ptr[BO2 + LDB * 2]);
+ break;
+ case 3:
+ lea(BO1, ptr[BO1 + LDB3]);
+ lea(BO2, ptr[BO2 + LDB3]);
+ break;
+ case 4:
+ lea(BO1, ptr[BO1 + LDB * 4]);
+ lea(BO2, ptr[BO2 + LDB * 4]);
+ break;
+ case 5:
+ lea(BO1, ptr[BO1 + LDB * 4]);
+ add(BO1, LDB);
+ lea(BO2, ptr[BO2 + LDB * 4]);
+ add(BO2, LDB);
+ break;
+ case 6:
+ lea(BO1, ptr[BO1 + LDB3 * 2]);
+ lea(BO2, ptr[BO2 + LDB3 * 2]);
+ break;
+ case 7:
+ lea(BO1, ptr[BO1 + LDB * 8]);
+ sub(BO1, LDB);
+ lea(BO2, ptr[BO2 + LDB * 8]);
+ sub(BO2, LDB);
+ break;
+ case 8:
+ lea(BO1, ptr[BO1 + LDB * 8]);
+ lea(BO2, ptr[BO2 + LDB * 8]);
+ break;
+ }
+ sub(BO1, rax);
+ sub(BO2, rax);
+ } else {
+ mov(rax, LDB);
+ imul(rax, K);
+ sub(BO1, rax);
+ add(BO1, unroll_n * SIZE);
+ }
+ };
+
+ // High-level subroutine; does packing if needed, then splits C matrix.
+ // Operates on chunks of 48 rows, 8 columns at a time (handling tail
+ // cases appropriately by doing 32 or 16 rows, and/or with masking,
+ // and/or fewer columns).
+ auto subloop = [&](int unroll_m) {
+ Label l_subloop_20x[8], l_subloop_mask_20x[8];
+ Label l_subloop_30x[8], l_subloop_mask_30x[8];
+
+ Label subloop11, subloop11mask;
+ Label subloop30, subloop30mask;
+ Label subloop31, subloop31mask;
+ Label subloop96;
+ Label subloop98, subloop98mask;
+ Label subloop99;
+
+ // Create mask
+ mov(BO1, rcx);
+ mov(rcx, M);
+ sub(rcx, unroll_m - 16);
+ mov(CO1, 16);
+ cmp(rcx, 16);
+
+ cmovg(rcx, CO1);
+ mov(rax, 1);
+ sal(rax, cl);
+ sub(rax, 1);
+ mov(rcx, 0xffff);
+
+ if (unroll_m == 16) {
+ kmovw(k1, eax);
+ } else if (unroll_m == 32) {
+ kmovw(k1, ecx);
+ kmovw(k2, eax);
+ } else {
+ kmovw(k1, ecx);
+ kmovw(k2, ecx);
+ kmovw(k3, eax);
+ }
+ mov(rcx, BO1);
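+        // k1/k2/k3 now hold the lane masks for the three 16-wide row blocks:
+        // full blocks get 0xffff and the tail block gets (1 << r) - 1, where
+        // r = min(M - (unroll_m - 16), 16) is the number of live tail rows.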
+
+ and_(rax, 0xffff);
+ cmp(rax, 0xffff);
+ jne(subloop96, T_NEAR);
+
+ if (isTransA) {
+ do_pack(unroll_m);
+ }
+
+ mov(CO1, C);
+ add(C, unroll_m * SIZE);
+
+ mov(BO1, B);
+ if (!isTransB) {
+ lea(BO2, ptr[B + LDB * 4]);
+ }
+
+ if (!isTransA) {
+ lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
+ cmp(M, UNROLL_M);
+ jg(subloop98, T_NEAR);
+
+ mov(AA, ORIG_A);
+ lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
+ L(subloop98);
+ }
+
+ mov(LL, N);
+ mov(I, LL);
+ if (!isTransA) {
+ // If N is too small, skip copy operation
+ cmp(LL, UNROLL_N * 3);
+ jle(subloop30, T_NEAR);
+
+ // If A is not aligned to cache line
+ cmp(FLAG, 0);
+ je(subloop30, T_NEAR);
+ } else {
+ cmp(LL, UNROLL_N);
+ jl(l_subloop_20x[1], T_NEAR);
+ }
+ align(16);
+
+ if (!isTransA) {
+ kernel(unroll_m, UNROLL_N, true, true);
+ } else {
+ kernel(unroll_m, UNROLL_N, false, false);
+ }
+
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jl(l_subloop_20x[1], T_NEAR);
+ align(16);
+
+ L(subloop11);
+ kernel(unroll_m, UNROLL_N, false, false);
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop11, T_NEAR);
+ align(16);
+
+ for (int i = 1; i <= 7; i++) {
+ L(l_subloop_20x[i]);
+ cmp(I, i);
+ if (i < 7) {
+ jne(l_subloop_20x[i + 1], T_NEAR);
+ } else {
+ jne(subloop99, T_NEAR);
+ }
+ kernel(unroll_m, i, false, false);
+ jmp(subloop99, T_NEAR);
+ align(16);
+ }
+
+ if (!isTransA) {
+ L(subloop30);
+ cmp(I, UNROLL_N);
+ jl(l_subloop_30x[1], T_NEAR);
+ align(16);
+
+ L(subloop31);
+ kernel(unroll_m, UNROLL_N, true, false);
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop31, T_NEAR);
+ align(16);
+
+ for (int i = 1; i <= 7; i++) {
+ L(l_subloop_30x[i]);
+ cmp(I, i);
+ if (i < 7) {
+ jne(l_subloop_30x[i + 1], T_NEAR);
+ } else {
+ jne(subloop99, T_NEAR);
+ }
+ kernel(unroll_m, i, true, false);
+ if (i < 7)
+ jmp(subloop99, T_NEAR);
+ align(16);
+ }
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop96);
+ if (isTransA) {
+ do_pack(unroll_m);
+ }
+
+ mov(CO1, C);
+ add(C, unroll_m * SIZE);
+ mov(BO1, B);
+ if (!isTransB) {
+ lea(BO2, ptr[B + LDB * 4]);
+ }
+
+ if (!isTransA) {
+ lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
+ cmp(M, UNROLL_M);
+ jg(subloop98mask, T_NEAR);
+ mov(AA, ORIG_A);
+ lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
+ L(subloop98mask);
+ }
+
+ mov(LL, N);
+ mov(I, LL);
+ if (!isTransA) {
+ // If N is too small, skip copy operation
+ cmp(LL, UNROLL_N * 3);
+ jle(subloop30mask, T_NEAR);
+
+ // If A is not aligned to cache line
+ cmp(FLAG, 0);
+ je(subloop30mask, T_NEAR);
+ } else {
+ cmp(LL, UNROLL_N);
+ jl(l_subloop_mask_20x[1], T_NEAR);
+ }
+ align(16);
+
+ if (!isTransA) {
+ kernel(unroll_m, UNROLL_N, true, true, false);
+ } else {
+ kernel(unroll_m, UNROLL_N, false, false, false);
+ }
+
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jl(l_subloop_mask_20x[1], T_NEAR);
+ align(16);
+
+ L(subloop11mask);
+ kernel(unroll_m, UNROLL_N, false, false, false);
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop11mask, T_NEAR);
+ align(16);
+
+ for (int i = 1; i <= 7; i++) {
+ L(l_subloop_mask_20x[i]);
+ cmp(I, i);
+ if (i < 7) {
+ jne(l_subloop_mask_20x[i + 1], T_NEAR);
+ } else {
+ jne(subloop99, T_NEAR);
+ }
+ kernel(unroll_m, i, false, false, false);
+ jmp(subloop99, T_NEAR);
+ align(16);
+ }
+
+ if (!isTransA) {
+ L(subloop30mask);
+ cmp(I, UNROLL_N);
+ jl(l_subloop_mask_30x[1], T_NEAR);
+ align(16);
+
+ L(subloop31mask);
+ kernel(unroll_m, UNROLL_N, true, false, false);
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop31mask, T_NEAR);
+ align(16);
+
+ for (int i = 1; i <= 7; i++) {
+ L(l_subloop_mask_30x[i]);
+ cmp(I, i);
+ if (i < 7) {
+ jne(l_subloop_mask_30x[i + 1], T_NEAR);
+ } else {
+ jne(subloop99, T_NEAR);
+ }
+ kernel(unroll_m, i, true, false, false);
+ if (i < 7)
+ jmp(subloop99, T_NEAR);
+ align(16);
+ }
+ }
+
+ L(subloop99);
+ // Compute address for A
+ if (!isTransA) {
+ add(A, unroll_m * SIZE);
+ } else {
+ mov(rax, LDA);
+ imul(rax, rax, unroll_m);
+ add(A, rax);
+ }
+
+ // Compute next address of BIAS
+ if (hasBias) {
+ add(BIAS, unroll_m * SIZE);
+ }
+ };
+
+ preamble();
+
+ Label buffer_in_ws, buffer_allocated;
+
+ // Get the registers
+ mov(B, ARG_B);
+ mov(LDB, ARG_LDB);
+ mov(r15, ARG_BETA);
+ mov(r12, ARG_C);
+ if (hasBias)
+ mov(r10, ARG_BIAS);
+ mov(LDC, ARG_LDC);
+ mov(rbp, rsp);
+
+ vmovss(xmm0, ptr[ARG_ALPHA]);
+ vmovss(xmm1, ptr[r15]);
+
+#if _WIN32
+ mov(A, ARG_A);
+ mov(LDA, ARG_LDA);
+#endif
+
+ cmp(K, STACK_K_CAPACITY);
+ jg(buffer_in_ws, T_NEAR);
+
+ // Create buffer and align to 4kB page
+ lea(rax, ptr[K * SIZE]);
+ imul(rax, rax, 0x30);
+ add(rax, 256);
+ sub(rsp, rax);
+ and_(rsp, -PAGE_4K);
+ jmp(buffer_allocated, T_NEAR);
+
+ L(buffer_in_ws);
+ mov(rsp, ARG_WS);
+
+ L(buffer_allocated);
+
+ mov(ORIG_SP, rbp);
+ mov(M, ARG_M);
+ mov(N, ARG_N);
+ mov(C, r12);
+ if (hasBias)
+ mov(BIAS, r10);
+ vmovss(ALPHA, xmm0);
+ vmovss(BETA, xmm1);
+ sub(A, -OFFSET * SIZE);
+ sub(B, -OFFSET * SIZE);
+ mov(ORIG_A, A);
+ sal(LDA, BASE_SHIFT);
+ sal(LDB, BASE_SHIFT);
+ sal(LDC, BASE_SHIFT);
+ lea(LDB3, ptr[LDB + LDB * 2]);
+
+ if (isTransA) {
+ vpbroadcastq(zmm2, LDA);
+ vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE);
+ mov(rax, -2);
+ kmovw(k4, eax);
+
+ for (int i = 0; i < 6; i++) {
+ vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
+ kshiftlw(k4, k4, 1);
+ }
+ vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
+ }
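+    // Qword lane j of ZSTRIDE now holds j * LDA (LDA is already scaled to
+    // bytes): the stride vector that vgatherqps uses in do_pack to gather
+    // eight strided elements of A per instruction.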
+
+ // Check A alignment and leading dimension; take copy-based path as
+ // needed
+ mov(rax, LDA);
+ or_(rax, A);
+ and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f);
+ mov(FLAG, rax);
+
+ for (int i = 8; i < 16; i++) {
+ for (int j = 0; j < 3; j++) {
+ vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j));
+ }
+ }
+
+ Label main0, main1, main2, main999;
+
+ cmp(M, 32);
+ jle(main0, T_NEAR);
+ align(16);
+
+ L(main1);
+ subloop(48);
+ sub(M, UNROLL_M);
+ cmp(M, 32);
+ jg(main1, T_NEAR);
+ align(16);
+
+ L(main0);
+ cmp(M, 16);
+ jle(main2, T_NEAR);
+
+ subloop(32);
+ jmp(main999, T_NEAR);
+ align(16);
+
+ L(main2);
+ cmp(M, 0);
+ jle(main999, T_NEAR);
+ subloop(16);
+ align(16);
+
+ L(main999);
+ // Restore original stack
+ mov(rsp, ORIG_SP);
+
+ vzeroupper();
+ postamble();
+
+ ker_ = this->getCode<ker_t>();
+ }
+
+ typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws);
+
+ void operator()(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws) const
+ {
+ ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
+ }
+
+private:
+ ker_t ker_;
+};
+
+const xbyak_gemm *get_xbyak_gemm(
+ bool isTransA, bool isTransB, float beta, bool hasBias) {
+ auto beta_idx = [](float beta) {
+ return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
+ };
+
+ // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
+ static xbyak_gemm *kernel_table[2][2][2][3];
+ static std::once_flag initialized;
+ std::call_once(initialized, [=]{
+ for (bool isTransA: {false, true})
+ for (bool isTransB: {false, true})
+ for (bool hasBias: {false, true})
+ for (float beta: {0.0f, 1.0f, 2.0f}) {
+ // nocopy sgemm with bias for beta != 0.0 is not supported
+ if (hasBias && beta != 0.0)
+ continue;
+ kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
+ new xbyak_gemm(isTransA, isTransB, beta, hasBias);
+ }
+ });
+
+ return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
+}
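+// For example, a non-transposed sgemm call with beta == 1.0f and no bias
+// resolves to kernel_table[false][false][false][1]; each kernel is generated
+// once on first use and reused by every subsequent call.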
+
+void sgemm_nocopy_driver(const char *transa,
+ const char *transb, int m, int n, int k, const float *alpha,
+ const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
+ float *c, dim_t ldc, const float *bias, float *ws)
+{
+ bool isTransA = (*transa == 'T' || *transa == 't');
+ bool isTransB = (*transb == 'T' || *transb == 't');
+
+ int Bm, sizeM, Bn, sizeN, Bk, sizeK;
+
+ int i, j;
+
+ if ((m <= 0) || (n <= 0))
+ return;
+
+ if ((k <= 0) || (alpha[0] == 0.)) {
+
+ if (beta[0] == 0.) {
+ for (j = 0; j < n; j++)
+ for (i = 0; i < m; i++)
+ c[i + j * ldc] = 0.0;
+ } else if (beta[0] != 1.) {
+ for (j = 0; j < n; j++)
+ for (i = 0; i < m; i++)
+ c[i + j * ldc] *= beta[0];
+ }
+
+ return;
+ }
+
+ assert(IMPLICATION(bias != nullptr, *beta == 0.0));
+
+ // XXX: this happens on every thread...
+ bool hasBias = (bias != nullptr);
+ auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
+ auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
+ auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
+ assert(ker_bn && ker_b1 && ker_b0);
+
+ int BM = 4032, BN, BK;
+ if (mayiuse(avx512_core)) {
+ BN = isTransA ? 384 : 64;
+ BK = 384;
+ } else {
+ BN = isTransA ? 96 : 64;
+ BK = isTransB ? 96 : 192;
+ if (!isTransA && !isTransB)
+ BK = 128;
+ }
+ const float *curA, *curB, *curBias = nullptr;
+ float *curC;
+
+ for (Bk = 0; Bk < k; Bk += sizeK) {
+ sizeK = k - Bk;
+ if (sizeK >= BK * 2)
+ sizeK = BK;
+ else {
+ if (sizeK > BK)
+ sizeK = (sizeK + 1) / 2;
+ }
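+        // e.g. with BK = 384 and k = 500: 500 < 2 * BK, so the remainder is
+        // split evenly as 250 + 250 rather than an unbalanced 384 + 116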
+
+ for (Bm = 0; Bm < m; Bm += sizeM) {
+ sizeM = m - Bm;
+ if (sizeM >= BM * 2)
+ sizeM = BM;
+ else {
+ if (sizeM > BM + BM / 2)
+ sizeM = (sizeM + 1) / 2;
+ }
+
+ for (Bn = 0; Bn < n; Bn += sizeN) {
+ sizeN = n - Bn;
+ if (sizeN >= BN * 2)
+ sizeN = BN;
+ else {
+ if (sizeN > BN + BN / 2)
+ sizeN = (sizeN + 1) / 2;
+ }
+
+ if (!isTransA) {
+ curA = a + Bm + Bk * lda;
+ } else {
+ curA = a + Bk + Bm * lda;
+ }
+ if (!isTransB) {
+ curB = b + Bk + Bn * ldb;
+ } else {
+ curB = b + Bn + Bk * ldb;
+ }
+ curC = c + Bm + (size_t)Bn * ldc;
+ if (bias != nullptr) {
+ if (Bk == 0) {
+ curBias = bias + Bm;
+ } else {
+ curBias = nullptr;
+ }
+ }
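+                // The first K block applies the user's beta (and bias);
+                // subsequent K blocks accumulate into C via the beta == 1
+                // kernel.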
+ if (Bk == 0) {
+ if (*beta == 0.0 && bias == nullptr)
+ (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ else
+ (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ } else {
+ (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ }
+ }
+ }
+ }
+}
+
+} // namespace avx512_common_gemm_f32
+
+mkldnn_status_t jit_avx512_common_gemm_f32(
+ const char *transa, const char *transb,
+ const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
+ const float *A, const int *p_lda, const float *B, const int *p_ldb,
+ const float *p_beta, float *C, const int *p_ldc, const float *bias)
+{
+ using namespace mkldnn::impl::utils;
+ using namespace avx512_common_gemm_f32;
+ using namespace gemm_utils;
+
+ if (*p_beta != 0 && bias)
+ return ref_gemm(transa, transb, p_m, p_n, p_k,
+ p_alpha, A, p_lda, B, p_ldb, p_beta, C, p_ldc, bias);
+
+ int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+
+ int m = *p_m;
+ int n = *p_n;
+ int k = *p_k;
+ dim_t lda = *p_lda;
+ dim_t ldb = *p_ldb;
+ dim_t ldc = *p_ldc;
+ float beta = *p_beta;
+ int MB, NB, KB;
+
+ int nthr_m, nthr_n, nthr_k, nthr_mn;
+
+ // Determine threading partitioning
+ calc_nthr_nocopy_avx512_common(
+ m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
+
+ // May not happen, but just in case
+ if (nthr < nthr_m * nthr_n * nthr_k)
+ nthr = nthr_m * nthr_n * nthr_k;
+
+ nthr_mn = nthr_m * nthr_n;
+
+ unsigned char * ompstatus_ = nullptr;
+ unsigned char volatile *ompstatus = nullptr;
+
+ float *c_buffers = nullptr;
+ float *ws_buffers = nullptr;
+
+ if (nthr_k > 1) {
+ ompstatus_ = (unsigned char *) malloc(
+ nthr * CACHE_LINE_SIZE,
+ CACHE_LINE_SIZE);
+ ompstatus = (unsigned char volatile *) ompstatus_;
+ assert(ompstatus);
+
+ for (int i = 0; i < nthr; i++)
+ ompstatus[i * CACHE_LINE_SIZE] = 0;
+
+ c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
+ * sizeof(float), PAGE_4K);
+ }
+
+ const size_t ws_elems_per_thr = (size_t)k * 48 + 64;
+ const size_t ws_size_per_thr
+ = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
+ if (k > STACK_K_CAPACITY) {
+ ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
+ }
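+ // Sizing note (a reading of the constants above, not a spec): the 48
+ // floats per k step appear to correspond to one packed A panel of the
+ // kernel's unroll width, plus 64 elements of slack, rounded up to whole
+ // 4K pages per thread; the buffer is only allocated when k is too large
+ // for the panel to live in the generated code's own stack frame
+ // (k > STACK_K_CAPACITY).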
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_m, ithr_n, ithr_k, ithr_mn;
+ int m_from, m_to, myM;
+ int n_from, n_to, myN;
+ int k_from, k_to, myK;
+ int cbase, ibase;
+ const float *myA, *myB, *myBias = nullptr;
+ float *myC = C;
+ float myBeta;
+ float *ws = ws_buffers ?
+ ws_buffers + ithr * ws_size_per_thr / sizeof(float) : nullptr;
+ dim_t ld = ldc;
+
+ int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
+
+ if (ithr < nthr_m * nthr_n * nthr_k) {
+
+ ithr_mn = ithr % nthr_mn;
+ ithr_m = ithr_mn % nthr_m;
+ ithr_n = ithr_mn / nthr_m;
+ ithr_k = ithr / nthr_mn;
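+ // Decomposition sketch: ithr unfolds into a 3-D grid with m fastest and
+ // k outermost; e.g. nthr_m = 4, nthr_n = 2, nthr_k = 2 maps ithr = 11 to
+ // ithr_mn = 11 % 8 = 3, ithr_m = 3 % 4 = 3, ithr_n = 3 / 4 = 0, and
+ // ithr_k = 11 / 8 = 1.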
+
+ /* swap ithr_k for performance improvement */
+ if (ithr_k == 0)
+ ithr_k = nthr_k - 1;
+ else if (ithr_k == nthr_k - 1)
+ ithr_k = 0;
+
+ m_from = MB * (ithr_m);
+ m_to = MB * (ithr_m + 1);
+ if (m_to > m)
+ m_to = m;
+ myM = m_to - m_from;
+
+ n_from = NB * (ithr_n);
+ n_to = NB * (ithr_n + 1);
+ if (n_to > n)
+ n_to = n;
+ myN = n_to - n_from;
+
+ k_from = KB * (ithr_k);
+ k_to = KB * (ithr_k + 1);
+ if (k_to > k)
+ k_to = k;
+ myK = k_to - k_from;
+
+ cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+ ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
+
+ if ((myM > 0) && (myN > 0)) {
+
+ if (*transa == 'N' || *transa == 'n') {
+ myA = &(A[m_from + k_from * lda]);
+ } else {
+ myA = &(A[k_from + m_from * lda]);
+ }
+ if (*transb == 'N' || *transb == 'n') {
+ myB = &(B[k_from + n_from * ldb]);
+ } else {
+ myB = &(B[n_from + k_from * ldb]);
+ }
+ if (ithr_k == 0) {
+ myC = &(C[m_from + n_from * ldc]);
+ myBeta = beta;
+ ld = ldc;
+ if (bias)
+ myBias = &(bias[m_from]);
+ } else {
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
+ myBeta = 0.0;
+ ld = MB;
+ myBias = nullptr;
+ }
+
+ sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
+ lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
+
+ if (nthr_k > 1 && !sum_later)
+ ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
+ }
+
+ if (nthr_k > 1 && !sum_later) {
+
+ // sum matrices partitioned along K dimension
+ int n1, n2;
+
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+ if (ithr_k > 0) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
+ /* need to wait until main thread finishes */
+ while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
+ };
+
+ /* my cache is hot */
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ if (ik != ithr_k) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
+
+ while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
+ };
+
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+ }
+ }
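+ // Reduction sketch: partition_unit_diff hands each thread of the
+ // (ithr_m, ithr_n) cell a disjoint column stripe [n1, n1 + n2), and the
+ // loops above fold every partial buffer into that stripe of C:
+ //
+ //     C[m, n] += sum over ik in 1 .. nthr_k-1 of c_buffers[cbase + ik - 1]
+ //
+ // with the spin-waits ensuring each producer has finished before its
+ // buffer is consumed.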
+ }
+ });
+
+ // handle C summation later
+ if (nthr_k > 1 && ompstatus[0] == 0) {
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_m, ithr_n, ithr_k, ithr_mn;
+ int m_from, m_to, myM;
+ int n_from, n_to, myN;
+ int cbase;
+ float *myC = C;
+
+ if (ithr < nthr_m * nthr_n * nthr_k) {
+
+ ithr_mn = ithr % nthr_mn;
+ ithr_m = ithr_mn % nthr_m;
+ ithr_n = ithr_mn / nthr_m;
+ ithr_k = ithr / nthr_mn;
+
+ /* swap ithr_k for performance improvement */
+ if (ithr_k == 0)
+ ithr_k = nthr_k - 1;
+ else if (ithr_k == nthr_k - 1)
+ ithr_k = 0;
+
+ m_from = MB * (ithr_m);
+ m_to = MB * (ithr_m + 1);
+ if (m_to > m)
+ m_to = m;
+ myM = m_to - m_from;
+
+ n_from = NB * (ithr_n);
+ n_to = NB * (ithr_n + 1);
+ if (n_to > n)
+ n_to = n;
+ myN = n_to - n_from;
+
+ cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ if (nthr_k > 1) {
+ // sum matrices partitioned along K dimension
+ int n1, n2;
+
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+ if (ithr_k > 0) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
+
+ /* my cache is hot */
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ if (ik != ithr_k) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
+
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+ }
+ }
+ }
+ });
+ }
+
+ free(c_buffers);
+ free(ompstatus_);
+ free(ws_buffers);
+
+ return mkldnn_success;
+}
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp
new file mode 100644
index 0000000000..d581b7fd71
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp
@@ -0,0 +1,36 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX512_COMMON_GEMM_F32_HPP
+#define JIT_AVX512_COMMON_GEMM_F32_HPP
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t jit_avx512_common_gemm_f32(
+ const char *transa, const char *transb, const int *M,
+ const int *N, const int *K, const float *alpha, const float *A,
+ const int *lda, const float *B, const int *ldb, const float *beta,
+ float *C, const int *ldc, const float *bias = nullptr);
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp
new file mode 100644
index 0000000000..60d4220837
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp
@@ -0,0 +1,2705 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <cmath>
+#include <mutex>
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "ref_gemm_f32.hpp"
+#include "gemm_utils_f32.hpp"
+#include "jit_avx_gemm_f32.hpp"
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+#define CACHE_LINE_SIZE 64
+
+#define STACKSIZE get_size_of_abi_save_regs()
+#ifdef _WIN32
+#define STACK_K_CAPACITY 128
+#else
+#define STACK_K_CAPACITY 8192
+#endif
+#define SIZE 4
+#define OFFSET 32
+#define BASE_SHIFT 2
+#define SECOND_FETCH 14
+
+namespace avx_gemm_f32 {
+using namespace gemm_utils;
+
+struct xbyak_gemm : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
+
+ xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
+ void *code_ptr = nullptr,
+ size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
+ : jit_generator(code_ptr, code_size)
+ {
+ using namespace Xbyak;
+
+ const bool is_avx2 = mayiuse(avx2);
+ assert(IMPLICATION(!is_avx2, mayiuse(avx)));
+
+ const int UNROLL_M = is_avx2 ? 16 : 8;
+ const int UNROLL_N = 6;
+
+ bool isBeta0 = (beta == 0.0);
+ bool isBetaN = (!isBeta0 && beta != 1.0);
+
+ // various definitions for convenience
+ auto ARG_M = abi_param1;
+ auto ARG_N = abi_param2;
+ auto K = abi_param3;
+ auto ARG_ALPHA = abi_param4;
+#ifdef _WIN32
+ auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
+ auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
+ sizeof(float *) + STACKSIZE];
+ const auto stackOffset = OFFSET_SHADOWSPACE +
+ sizeof(float *) + STACKSIZE;
+ auto A = rsi;
+ auto LDA = rdi;
+#else
+ auto ARG_A = r8;
+ auto ARG_LDA = r9;
+ const auto stackOffset = STACKSIZE;
+ auto A = ARG_A;
+ auto LDA = ARG_LDA;
+#endif
+ auto ARG_B = ptr[rsp + 8 + stackOffset];
+ auto ARG_LDB = ptr[rsp + 16 + stackOffset];
+ auto ARG_BETA = ptr[rsp + 24 + stackOffset];
+ auto ARG_C = ptr[rsp + 32 + stackOffset];
+ auto ARG_LDC = ptr[rsp + 40 + stackOffset];
+ auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
+ auto ARG_WS = ptr[rsp + 56 + stackOffset];
+
+ auto B = r11;
+ auto LDB = rbx;
+ auto LDC = r13;
+ auto LL = rax;
+ auto AO1 = abi_param2;
+ auto BO1 = abi_param4;
+ auto BO2 = rbp;
+ auto CO1 = r14;
+ auto CO2 = r15;
+ auto LDB3 = r10;
+ auto LDA4 = abi_param1;
+ auto AA = r12;
+ auto BIAS1 = abi_param1;
+
+ auto M = qword[rsp + 0];
+ auto N = qword[rsp + 8];
+ auto FLAG = qword[rsp + 16];
+ auto I = qword[rsp + 24];
+ auto C = qword[rsp + 32];
+ auto BIAS = qword[rsp + 40];
+ auto ALPHA = qword[rsp + 48];
+ auto BETA = qword[rsp + 64];
+ auto ORIG_A = qword[rsp + 80];
+ auto MASK = dword[rsp + 88];
+ auto STRIDE = qword[rsp + 120];
+ auto ORIG_SP = qword[rsp + 152];
+
+ auto VALPHA = ymm1;
+ auto VBETA = ymm2;
+ auto VMASK = ymm3;
+ auto VBIAS1 = ymm2;
+ auto VBIAS2 = ymm4;
+
+ auto PREFETCHSIZEA = 128;
+ auto PREFETCHSIZEB = (!isTransB) ? -16 : 0;
+
+ // Function for packing A panels if needed (used on the transposed-A path)
+ auto do_pack = [&](
+ int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
+ Label pack2, pack3, pack4, pack10;
+
+ int regIdx;
+ Reg64 reg;
+
+ mov(BO1, A);
+ lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
+
+ if (isTransA) {
+ lea(BO2, ptr[BO1 + LDA * 4]);
+ lea(CO1, ptr[LDA + LDA * 2]);
+ vmovupd(ymm7, STRIDE);
+ }
+
+ mov(LL, K);
+ sar(LL, 2);
+ jle(pack3, T_NEAR);
+ align(16);
+
+ L(pack2);
+ if (!isTransA) {
+ for (int i = 0; i < 4; i++) {
+ regIdx = (i % 2 == 0) ? 4 : 6;
+ if (isLoad1Unmasked) {
+ vmovups(Ymm(regIdx),
+ ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(Ymm(regIdx), VMASK,
+ ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m > 8) {
+ if (isLoad2Unmasked) {
+ vmovups(Ymm(regIdx + 1),
+ ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(Ymm(regIdx + 1), VMASK,
+ ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(BO1, LDA);
+
+ vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
+ Ymm(regIdx));
+ if (unroll_m > 8) {
+ vmovups(ptr[AO1
+ + (unroll_m * i + 1 * 8 - OFFSET)
+ * SIZE],
+ Ymm(regIdx + 1));
+ }
+ }
+
+ } else {
+ if (isLoad1Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ reg = (i % 2 == 0) ? BO1 : BO2;
+ vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]);
+ vmovups(xmm1,
+ ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[reg + LDA * 2]);
+ vunpcklps(xmm4, xmm0, xmm1);
+ vunpckhps(xmm5, xmm0, xmm1);
+ vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovups(xmm1,
+ ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(xmm6, xmm0, xmm1);
+ vunpckhps(xmm2, xmm0, xmm1);
+
+ vunpcklpd(xmm0, xmm4, xmm6);
+ vunpckhpd(xmm1, xmm4, xmm6);
+ vmovups(ptr[AO1
+ + (unroll_m * 0 + i * 4 - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * 1 + i * 4 - OFFSET)
+ * SIZE],
+ xmm1);
+ vunpcklpd(xmm0, xmm5, xmm2);
+ vunpckhpd(xmm1, xmm5, xmm2);
+ vmovups(ptr[AO1
+ + (unroll_m * 2 + i * 4 - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * 3 + i * 4 - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+ } else if (is_avx2) {
+ for (int i = 0; i < 2; i++) {
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm0,
+ ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE],
+ xmm4);
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm1,
+ ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
+ xmm4);
+
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i) + 0 * 4 - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i + 1) + 0 * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+
+ lea(BO2, ptr[BO1 + LDA * 4]);
+
+ for (int i = 0; i < 2; i++) {
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm0,
+ ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
+ xmm4);
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
+ xmm4);
+
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i) + 1 * 4 - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i + 1) + 1 * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ } else {
+ vxorps(xmm4, xmm4, xmm4);
+ lea(BO2, ptr[BO1 + LDA * 4]);
+
+ auto el_cp = [&](int section, int ld_step) {
+ RegExp src_addr = section == 0 ? BO1 : BO2;
+ if (ld_step == 1 || ld_step == 2)
+ src_addr = src_addr + LDA * ld_step;
+ else if (ld_step == 3)
+ src_addr = src_addr + CO1;
+ src_addr = src_addr - OFFSET * SIZE;
+
+ vmovups(Xmm(ld_step % 2), ptr[src_addr]);
+ RegExp dst_addr = AO1
+ + (ld_step + section * 4 - OFFSET) * SIZE;
+ for (int off = 0; off < 4; ++off)
+ pextrd(ptr[dst_addr + unroll_m * off * SIZE],
+ Xmm(ld_step % 2), off);
+ };
+
+ Label l_end;
+ el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
+ el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
+ el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
+ el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
+ el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
+ el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
+ el_cp(1, 2);
+ L(l_end);
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ }
+
+ if (unroll_m >= 16) {
+ assert(is_avx2);
+ if (isLoad2Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovups(xmm1, ptr[BO2 + LDA * 1
+ + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(xmm4, xmm0, xmm1);
+ vunpckhps(xmm5, xmm0, xmm1);
+ vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovups(xmm1, ptr[BO2 + LDA * 1
+ + (0 * 8 - OFFSET) * SIZE]);
+ if (i == 0)
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(xmm6, xmm0, xmm1);
+ vunpckhps(xmm2, xmm0, xmm1);
+
+ vunpcklpd(xmm0, xmm4, xmm6);
+ vunpckhpd(xmm1, xmm4, xmm6);
+ vmovups(ptr[AO1
+ + (unroll_m * 0 + (i + 2) * 4
+ - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * 1 + (i + 2) * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ vunpcklpd(xmm0, xmm5, xmm2);
+ vunpckhpd(xmm1, xmm5, xmm2);
+ vmovups(ptr[AO1
+ + (unroll_m * 2 + (i + 2) * 4
+ - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * 3 + (i + 2) * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+ } else {
+ for (int i = 0; i < 2; i++) {
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm0,
+ ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
+ xmm4);
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7
+ + ((2 * i + 1) - OFFSET) * SIZE],
+ xmm4);
+
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i) + 2 * 4
+ - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i + 1) + 2 * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+
+ for (int i = 0; i < 2; i++) {
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm0,
+ ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
+ xmm4);
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7
+ + ((2 * i + 1) - OFFSET) * SIZE],
+ xmm4);
+
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i) + 3 * 4
+ - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i + 1) + 3 * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ }
+ }
+ add(BO1, (4 * SIZE));
+ }
+
+ add(AO1, unroll_m * 4 * SIZE);
+ sub(LL, 1);
+ jg(pack2, T_NEAR);
+ align(16);
+
+ L(pack3);
+ mov(LL, K);
+ and_(LL, 3);
+ jle(pack10, T_NEAR);
+ align(16);
+
+ L(pack4);
+ if (!isTransA) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m > 8) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm5, VMASK,
+ ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(BO1, LDA);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
+ ymm4);
+ if (unroll_m > 8) {
+ vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
+ ymm5);
+ }
+ } else {
+ if (isLoad1Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ reg = (i % 2 == 0) ? BO1 : BO2;
+ vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]);
+ vmovss(xmm0,
+ ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[reg + LDA * 2]);
+ vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
+ }
+ vunpcklpd(xmm1, xmm1, xmm2);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
+ xmm1);
+
+ for (int i = 0; i < 2; i++) {
+ vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovss(xmm0,
+ ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
+ }
+ vunpcklpd(xmm1, xmm1, xmm2);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
+ xmm1);
+ } else if (is_avx2) {
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE],
+ xmm4);
+ lea(BO2, ptr[BO1 + LDA * 4]);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
+ xmm1);
+
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
+ xmm4);
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
+ xmm1);
+ } else {
+ vxorps(xmm4, xmm4, xmm4);
+ lea(BO2, ptr[BO1 + LDA * 4]);
+
+ auto el_cp = [&](int section, int ld_step) {
+ RegExp src_addr = section == 0 ? BO1 : BO2;
+ if (ld_step == 1 || ld_step == 2)
+ src_addr = src_addr + LDA * ld_step;
+ else if (ld_step == 3)
+ src_addr = src_addr + CO1;
+ src_addr = src_addr - OFFSET * SIZE;
+
+ vmovss(xmm1, ptr[src_addr]);
+ RegExp dst_addr = AO1
+ + (ld_step + section * 4 - OFFSET) * SIZE;
+ movss(ptr[dst_addr], xmm1);
+ };
+
+ Label l_end;
+ el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
+ el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
+ el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
+ el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
+ el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
+ el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
+ el_cp(1, 2);
+ L(l_end);
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ }
+
+ if (unroll_m >= 16) {
+ assert(is_avx2);
+ if (isLoad2Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ vmovss(Xmm(i + 1),
+ ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovss(xmm0, ptr[BO2 + LDA * 1
+ + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
+ }
+ vunpcklpd(xmm1, xmm1, xmm2);
+ } else {
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
+ xmm4);
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ }
+ vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE],
+ xmm1);
+
+ if (isLoad2Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ vmovss(Xmm(i + 1),
+ ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovss(xmm0, ptr[BO2 + LDA * 1
+ + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
+ }
+ vunpcklpd(xmm1, xmm1, xmm2);
+ } else {
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
+ xmm4);
+ }
+ vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE],
+ xmm1);
+ }
+ add(BO1, SIZE);
+ }
+
+ add(AO1, unroll_m * SIZE);
+ sub(LL, 1);
+ jg(pack4, T_NEAR);
+ align(16);
+
+ L(pack10);
+ };
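+
+ // A scalar sketch of the packing above (illustration only; buf stands
+ // for the on-stack panel at rsp + 256, and A is the transposed input):
+ //
+ //     for (int kk = 0; kk < K; ++kk)
+ //         for (int mm = 0; mm < unroll_m; ++mm)
+ //             buf[kk * unroll_m + mm] = A[mm * lda + kk];
+ //
+ // i.e. op(A) is re-laid-out so the compute kernels can stream it with
+ // contiguous unroll_m-wide vector loads per k step.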
+
+ // Fused multiply add; may become one or two instructions
+ auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2,
+ bool overWrite = false) {
+ if (useFma) {
+ if (is_avx2) {
+ vfmadd231ps(reg2, reg1, reg0);
+ } else {
+ assert(UNROLL_M == 8);
+ auto tent_vreg = overWrite ? reg1 : ymm1;
+ vmulps(tent_vreg, reg1, reg0);
+ vaddps(reg2, reg2, tent_vreg);
+ }
+ } else {
+ if (!overWrite) {
+ vmulps(ymm15, reg1, reg0);
+ vaddps(reg2, reg2, ymm15);
+ } else {
+ vmulps(reg1, reg1, reg0);
+ vaddps(reg2, reg2, reg1);
+ }
+ }
+ };
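+
+ // Lowering sketch: with FMA available (AVX2) the lambda emits a single
+ //
+ //     vfmadd231ps reg2, reg1, reg0   // reg2 += reg1 * reg0
+ //
+ // while on plain AVX it falls back to the two-instruction sequence
+ //
+ //     vmulps tmp, reg1, reg0
+ //     vaddps reg2, reg2, tmp
+ //
+ // where tmp is reg1 itself when the caller allows it to be clobbered.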
+
+ // Inner kernel with k=8
+ auto innerkernel8 = [&](int unroll_m, int unroll_n,
+ bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
+ bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
+ Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
+ Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
+ Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
+ Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
+ Ymm reg23) {
+
+ Ymm fmareg;
+
+ if (!isDirect) {
+ prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+
+ for (int i = 0; i < 8; i++) {
+ if (isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg00 : reg12;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg06 : reg18;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ if (i == 0) {
+ if (!isTransB) {
+ prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
+ }
+ }
+ if (unroll_n >= 2) {
+ if (!isTransB) {
+ if (i == 1) {
+ prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg01 : reg13;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg07 : reg19;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (isCopy) {
+ vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
+ ymm0);
+ if (unroll_m >= 16) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 1 * 8 - OFFSET)
+ * SIZE],
+ ymm1);
+ }
+ if (i == 7) {
+ sub(LDA4, -unroll_m * 8 * SIZE);
+ }
+ }
+
+ if (unroll_n >= 3) {
+ if (!isTransB) {
+ if (i == 2) {
+ prefetcht0(
+ ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg02 : reg14;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg08 : reg20;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (i == 7) {
+ if (!isTransB) {
+ sub(BO1, -8 * SIZE);
+ }
+ }
+
+ if (unroll_n >= 4) {
+ if (!isTransB) {
+ if (i == 3) {
+ prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg03 : reg15;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg09 : reg21;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 5) {
+ if (!isTransB) {
+ if (i == 4) {
+ prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg04 : reg16;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg10 : reg22;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 6) {
+ if (!isTransB) {
+ if (i == 5) {
+ prefetcht0(
+ ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg05 : reg17;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg11 : reg23;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+ if (isTransB) {
+ prefetcht0(ptr[BO1 + BO2]);
+ add(BO1, LDB);
+ }
+
+ if (i == 0) {
+ if (unroll_m >= 4) {
+ if (!isDirect) {
+ prefetcht0(
+ ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 1 || i == 2) {
+ if (unroll_m >= 8) {
+ if (!isDirect) {
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + (2 + 2 * i) * 8)
+ * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 3 || i == 4 || i == 5 || i == 6) {
+ if (unroll_m >= 16) {
+ if (!isDirect) {
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + (2 + 2 * i) * 8)
+ * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 7) {
+ if (!isTransB) {
+ if (unroll_n >= 4) {
+ sub(BO2, -8 * SIZE);
+ }
+ }
+ if (!isTransA) {
+ prefetcht2(ptr[AA]);
+ lea(AA, ptr[AA + LDA]);
+ }
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(
+ ymm0, VMASK,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
+ * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1
+ + (unroll_m * (i + 1) + 1 * 8
+ - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 1 * 8
+ - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+ }
+
+ if (!isDirect) {
+ sub(AO1, -unroll_m * 8 * SIZE);
+ }
+ sub(LL, 1);
+
+ };
+
+ // Inner kernel with k=4
+ auto innerkernel4 = [&](int unroll_m, int unroll_n,
+ bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
+ bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
+ Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
+ Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
+ Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
+ Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
+ Ymm reg23) {
+
+ Ymm fmareg;
+
+ if (!isDirect) {
+ prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+
+ for (int i = 0; i < 4; i++) {
+ if (isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg00 : reg12;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg06 : reg18;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ if (i == 0) {
+ if (!isTransB) {
+ prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
+ }
+ }
+ if (unroll_n >= 2) {
+ if (!isTransB) {
+ if (i == 1) {
+ prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg01 : reg13;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg07 : reg19;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (isCopy) {
+ vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
+ ymm0);
+ if (unroll_m >= 16) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 1 * 8 - OFFSET)
+ * SIZE],
+ ymm1);
+ }
+ if (i == 3) {
+ sub(LDA4, -unroll_m * 4 * SIZE);
+ }
+ }
+
+ if (unroll_n >= 3) {
+ if (!isTransB) {
+ if (i == 2) {
+ prefetcht0(
+ ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg02 : reg14;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg08 : reg20;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 4) {
+ if (!isTransB) {
+ if (i == 3) {
+ prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg03 : reg15;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg09 : reg21;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 5) {
+ if (!isTransB) {
+ if (i == 4) {
+ prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg04 : reg16;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg10 : reg22;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 6) {
+ if (!isTransB) {
+ if (i == 5) {
+ prefetcht0(
+ ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg05 : reg17;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg11 : reg23;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+ if (isTransB) {
+ prefetcht0(ptr[BO1 + BO2]);
+ add(BO1, LDB);
+ }
+
+ if (i == 0) {
+ if (unroll_m >= 4) {
+ if (!isDirect) {
+ prefetcht0(
+ ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 1 || i == 2) {
+ if (unroll_m >= 8) {
+ if (!isDirect) {
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + (2 + 2 * i) * 8)
+ * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 3) {
+ if (!isTransB) {
+ sub(BO1, -4 * SIZE);
+ if (unroll_n >= 4) {
+ sub(BO2, -4 * SIZE);
+ }
+ }
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(
+ ymm0, VMASK,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
+ * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1
+ + (unroll_m * (i + 1) + 1 * 8
+ - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 1 * 8
+ - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+ }
+
+ if (!isDirect) {
+ sub(AO1, -unroll_m * 4 * SIZE);
+ }
+
+ };
+
+ // Inner kernel with k=2
+ auto innerkernel2 = [&](int unroll_m, int unroll_n,
+ bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
+ bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
+ Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
+ Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
+ Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
+ Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
+ Ymm reg23) {
+
+ Ymm fmareg;
+
+ for (int i = 0; i < 2; i++) {
+ if (isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ // k-depth 1: the B element address is identical for both layouts
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ fmareg = (i % 2 == 0) ? reg00 : reg12;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg06 : reg18;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ if (unroll_n >= 2) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg01 : reg13;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg07 : reg19;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 3) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg02 : reg14;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg08 : reg20;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 4) {
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg03 : reg15;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg09 : reg21;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 5) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg04 : reg16;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg10 : reg22;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 6) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg05 : reg17;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg11 : reg23;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (isCopy) {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
+ ymm0);
+ if (unroll_m >= 16) {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 1 * 8 - OFFSET)
+ * SIZE],
+ ymm1);
+ }
+ sub(LDA4, -unroll_m * SIZE);
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1
+ + (unroll_m * 1 + 0 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1
+ + (unroll_m * 1 + 0 * 8 - OFFSET)
+ * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1,
+ ptr[AO1
+ + (unroll_m * 1 + 1 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * 1 + 1 * 8 - OFFSET)
+ * SIZE]);
+ }
+ }
+ sub(AO1, -unroll_m * SIZE);
+ }
+
+ if (!isTransB) {
+ sub(BO1, -SIZE);
+ if (unroll_n >= 4) {
+ sub(BO2, -SIZE);
+ }
+ } else {
+ add(BO1, LDB);
+ }
+ }
+
+ };
+
+ // Inner kernel with k=1
+ auto innerkernel1 = [&](int unroll_m, int unroll_n,
+ bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
+ bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
+ Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
+ Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) {
+
+ if (isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ // First B element: the address is identical for both layouts
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ fma(useFma, ymm0, ymm2, reg00);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg06);
+ }
+
+ if (unroll_n >= 2) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg01);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg07);
+ }
+ }
+
+ if (unroll_n >= 3) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg02);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg08);
+ }
+ }
+
+ if (unroll_n >= 4) {
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg03);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg09);
+ }
+ }
+
+ if (unroll_n >= 5) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg04);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg10);
+ }
+ }
+
+ if (unroll_n >= 6) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg05);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg11);
+ }
+ }
+
+ if (isCopy) {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
+ ymm0);
+ if (unroll_m >= 16) {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
+ ymm1);
+ }
+ sub(LDA4, -unroll_m * SIZE);
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0,
+ ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1
+ + (unroll_m * 1 + 1 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * 1 + 1 * 8 - OFFSET)
+ * SIZE]);
+ }
+ }
+ sub(AO1, -unroll_m * SIZE);
+ }
+
+ if (!isTransB) {
+ sub(BO1, -SIZE);
+ if (unroll_n >= 4) {
+ sub(BO2, -SIZE);
+ }
+ } else {
+ add(BO1, LDB);
+ }
+
+ };
+
+ // Main kernel: prefetches and calls innerkernel{1,2,4,8} as
+ // appropriate, then, once the results have been accumulated in
+ // registers, writes them back to the C matrix
+ auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma,
+ Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6),
+ Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9),
+ Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12),
+ Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15),
+ Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6),
+ Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
+ Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
+ Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
+ if (!isDirect) {
+ lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
+ } else {
+ mov(AO1, A);
+ }
+
+ if (isCopy) {
+ lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
+ } else {
+ lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
+ }
+
+ if (isTransB) {
+ lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDB * 2]);
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0,
+ ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1
+ + (unroll_m * 0 + 1 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * 0 + 1 * 8 - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+
+ for (int i = 4; i < 10; i++) {
+ vxorps(Ymm(i), Ymm(i), Ymm(i));
+ vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6));
+ }
+
+ mov(LL, K);
+ sar(LL, 3);
+
+ Label kernel12, kernel13, kernel14, kernel15;
+ Label kernel16, kernel17, kernel18;
+
+ sub(LL, SECOND_FETCH);
+ jle(kernel13, T_NEAR);
+ align(16);
+
+ L(kernel12);
+ innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
+ reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
+ reg21, reg22, reg23);
+ jg(kernel12, T_NEAR);
+ align(16);
+
+ L(kernel13);
+ prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 2)
+ prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 3)
+ prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 4)
+ prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 5)
+ prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 6)
+ prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
+
+ add(LL, SECOND_FETCH);
+ jle(kernel15, T_NEAR);
+ align(16);
+
+ L(kernel14);
+ innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
+ reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
+ reg21, reg22, reg23);
+ jg(kernel14, T_NEAR);
+ align(16);
+
+ L(kernel15);
+ test(K, 4);
+ jle(kernel16, T_NEAR);
+ innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
+ reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
+ reg21, reg22, reg23);
+
+ L(kernel16);
+ test(K, 2);
+ jle(kernel17, T_NEAR);
+ innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
+ reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
+ reg21, reg22, reg23);
+ align(16);
+
+ L(kernel17);
+ if (unroll_m == 16) {
+ if (unroll_n <= 3) {
+ vaddps(reg00, reg00, reg12);
+ vaddps(reg01, reg01, reg13);
+ vaddps(reg02, reg02, reg14);
+ vaddps(reg06, reg06, reg18);
+ vaddps(reg07, reg07, reg19);
+ vaddps(reg08, reg08, reg20);
+ }
+ }
+
+ if (unroll_m <= 8) {
+ vaddps(reg00, reg00, reg12);
+ vaddps(reg01, reg01, reg13);
+ vaddps(reg02, reg02, reg14);
+ vaddps(reg03, reg03, reg15);
+ vaddps(reg04, reg04, reg16);
+ vaddps(reg05, reg05, reg17);
+ }
+
+ test(K, 1);
+ jle(kernel18, T_NEAR);
+ innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11);
+ align(16);
+
+ L(kernel18);
+ vbroadcastss(VALPHA, ALPHA);
+
+ if (isBetaN) {
+ vbroadcastss(VBETA, BETA);
+ }
+
+ // Write back the results; all beta and bias cases need to be
+ // handled
+ switch (unroll_n) {
+ case 1: mov(rax, LDC); break;
+ case 2: lea(rax, ptr[LDC * 2]); break;
+ case 3: lea(rax, ptr[LDC + LDC * 2]); break;
+ case 4: lea(rax, ptr[LDC * 4]); break;
+ case 5:
+ lea(rax, ptr[LDC * 4]);
+ add(rax, LDC);
+ break;
+ case 6:
+ lea(rax, ptr[LDC + LDC * 2]);
+ add(rax, rax);
+ break;
+ }
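+ // The switch above materializes rax = unroll_n * LDC using lea's
+ // 1/2/4/8 scale factors: e.g. 3 * LDC as LDC + LDC * 2, and 6 * LDC as
+ // (LDC + LDC * 2) doubled, since a single lea cannot scale by 3 or 6.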
+
+ if (hasBias) {
+ mov(BIAS1, BIAS);
+ if (isLoad1Unmasked) {
+ vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
+ } else {
+ vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]);
+ }
+ }
+
+ for (int i = 0; i < unroll_n; i++) {
+ vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA);
+ if (!isBeta0) {
+ if (isLoad1Unmasked) {
+ switch (i) {
+ case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break;
+ case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break;
+ case 2:
+ vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]);
+ break;
+ case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break;
+ case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break;
+ case 5:
+ vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]);
+ break;
+ }
+ } else {
+ switch (i) {
+ case 0:
+ vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]);
+ break;
+ case 1:
+ vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]);
+ break;
+ case 2:
+ vmaskmovps(
+ ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]);
+ break;
+ case 3:
+ vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]);
+ break;
+ case 4:
+ vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]);
+ break;
+ case 5:
+ vmaskmovps(
+ ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]);
+ break;
+ }
+ }
+
+ if (!isBetaN) {
+ vaddps(Ymm(i + 4), ymm0, Ymm(i + 4));
+ } else {
+ fma(useFma, VBETA, ymm0, Ymm(i + 4), true);
+ }
+ }
+ if (hasBias) {
+ vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4));
+ }
+ if (isLoad1Unmasked) {
+ switch (i) {
+ case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break;
+ case 1:
+ vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4));
+ break;
+ case 2:
+ vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
+ break;
+ case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break;
+ case 4:
+ vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4));
+ break;
+ case 5:
+ vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
+ break;
+ }
+ } else {
+ switch (i) {
+ case 0:
+ vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4));
+ break;
+ case 1:
+ vmaskmovps(
+ ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
+ break;
+ case 2:
+ vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK,
+ Ymm(i + 4));
+ break;
+ case 3:
+ vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4));
+ break;
+ case 4:
+ vmaskmovps(
+ ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
+ break;
+ case 5:
+ vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK,
+ Ymm(i + 4));
+ break;
+ }
+ }
+
+ if (unroll_m >= 16) {
+ // Re-use ymm4 (VBIAS2)
+ if (i == 0) {
+ if (hasBias) {
+ if (isLoad1Unmasked) {
+ vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]);
+ } else {
+ vmaskmovps(
+ VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]);
+ }
+ }
+ }
+ vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA);
+ if (!isBeta0) {
+ if (isLoad2Unmasked) {
+ switch (i) {
+ case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break;
+ case 1:
+ vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]);
+ break;
+ case 2:
+ vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]);
+ break;
+ case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break;
+ case 4:
+ vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]);
+ break;
+ case 5:
+ vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]);
+ break;
+ }
+ } else {
+ switch (i) {
+ case 0:
+ vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]);
+ break;
+ case 1:
+ vmaskmovps(
+ ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]);
+ break;
+ case 2:
+ vmaskmovps(ymm0, VMASK,
+ ptr[CO1 + LDC * 2 + 8 * SIZE]);
+ break;
+ case 3:
+ vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]);
+ break;
+ case 4:
+ vmaskmovps(
+ ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]);
+ break;
+ case 5:
+ vmaskmovps(ymm0, VMASK,
+ ptr[CO2 + LDC * 2 + 8 * SIZE]);
+ break;
+ }
+ }
+ if (!isBetaN) {
+ vaddps(Ymm(i + 10), ymm0, Ymm(i + 10));
+ } else {
+ fma(useFma, VBETA, ymm0, Ymm(i + 10), true);
+ }
+ }
+ if (hasBias) {
+ vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10));
+ }
+ if (isLoad2Unmasked) {
+ switch (i) {
+ case 0:
+ vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 1:
+ vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 2:
+ vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 3:
+ vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 4:
+ vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 5:
+ vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
+ break;
+ }
+ } else {
+ switch (i) {
+ case 0:
+ vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10));
+ break;
+ case 1:
+ vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK,
+ Ymm(i + 10));
+ break;
+ case 2:
+ vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK,
+ Ymm(i + 10));
+ break;
+ case 3:
+ vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10));
+ break;
+ case 4:
+ vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK,
+ Ymm(i + 10));
+ break;
+ case 5:
+ vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK,
+ Ymm(i + 10));
+ break;
+ }
+ }
+ }
+ if (i == 2)
+ add(CO1, rax);
+ }
+ if (unroll_n >= 4) {
+ add(CO2, rax);
+ }
+
+ // Compute next address of B
+ if (!isTransB) {
+ lea(rax, ptr[K * SIZE]);
+ switch (unroll_n) {
+ case 1:
+ add(BO1, LDB);
+ add(BO2, LDB);
+ break;
+ case 2:
+ lea(BO1, ptr[BO1 + LDB * 2]);
+ lea(BO2, ptr[BO2 + LDB * 2]);
+ break;
+ case 3:
+ lea(BO1, ptr[BO1 + LDB3]);
+ lea(BO2, ptr[BO2 + LDB3]);
+ break;
+ case 4:
+ lea(BO1, ptr[BO1 + LDB * 4]);
+ lea(BO2, ptr[BO2 + LDB * 4]);
+ break;
+ case 5:
+ lea(BO1, ptr[BO1 + LDB * 4]);
+ add(BO1, LDB);
+ lea(BO2, ptr[BO2 + LDB * 4]);
+ add(BO2, LDB);
+ break;
+ case 6:
+ lea(BO1, ptr[BO1 + LDB3 * 2]);
+ lea(BO2, ptr[BO2 + LDB3 * 2]);
+ break;
+ }
+ sub(BO1, rax);
+ sub(BO2, rax);
+ } else {
+ mov(rax, LDB);
+ imul(rax, K);
+ sub(BO1, rax);
+ add(BO1, unroll_n * SIZE);
+ }
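+ // Net effect of the address update above: for non-transposed B the
+ // pointers advance by unroll_n columns and rewind the K * SIZE bytes
+ // the inner kernels consumed, i.e. BO1 += unroll_n * LDB - K * SIZE;
+ // for transposed B, BO1 rewinds the K rows it walked and steps right by
+ // unroll_n elements (BO1 += unroll_n * SIZE - K * LDB).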
+ };
+
+ auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, true);
+ };
+
+ auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, true);
+ };
+
+ auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, true);
+ };
+
+ auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy,
+ bool useFma = true) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
+ Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
+ Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
+ Ymm(13), Ymm(14), Ymm(15));
+ };
+
+ auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, false);
+ };
+
+ auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, false);
+ };
+
+ auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy,
+ bool useFma = true) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
+ Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
+ Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
+ Ymm(15));
+ };
+
+ auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy);
+ };
+
+ auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy);
+ };
+
+ auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy,
+ bool useFma = true) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
+ Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
+ Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
+ Ymm(13), Ymm(14), Ymm(15));
+ };
+
+ auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, false);
+ };
+
+ auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, false);
+ };
+
+ // High-level subroutine; does packing if needed, then splits the C
+ // matrix. Operates on chunks of unroll_m rows by UNROLL_N (6) columns
+ // at a time; masked loads and stores handle the tail when M is not
+ // divisible by 8.
+ auto subloop = [&](
+ int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
+ if (isTransA) {
+ do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
+ }
+
+ Label subloop11, subloop11mask;
+ Label subloop20, subloop21, subloop22, subloop23;
+ Label subloop24, subloop25;
+ Label subloop30, subloop31, subloop32, subloop33;
+ Label subloop34, subloop35;
+ Label subloop98, subloop98mask;
+ Label subloop99, subloop99mask;
+
+ mov(CO1, C);
+ lea(CO2, ptr[CO1 + LDC * 2]);
+ add(CO2, LDC);
+ add(C, unroll_m * SIZE);
+ mov(BO1, B);
+ if (!isTransB) {
+ lea(BO2, qword[B + LDB3]);
+ }
+
+ if (!isTransA) {
+ lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
+ cmp(M, UNROLL_M);
+ jg(subloop98, T_NEAR);
+
+ mov(AA, ORIG_A);
+ lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
+ L(subloop98);
+ }
+
+ mov(LL, N);
+ mov(I, LL);
+ if (!isTransA) {
+ // If N is too small, skip copy operation
+ cmp(LL, UNROLL_N * 3);
+ jle(subloop30, T_NEAR);
+
+ // If A is 32-byte aligned (FLAG == 0), skip the copy path
+ cmp(FLAG, 0);
+ je(subloop30, T_NEAR);
+ } else {
+ cmp(LL, UNROLL_N);
+ jl(subloop20, T_NEAR);
+ }
+ align(16);
+
+ if (!isTransA) {
+ if (unroll_m == 16) {
+ kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, true, true);
+ } else {
+ kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, true, true);
+ }
+ } else {
+ if (unroll_m == 16) {
+ kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, false, false);
+ } else {
+ kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, false, false);
+ }
+ }
+
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jl(subloop20, T_NEAR);
+ align(16);
+
+ L(subloop11);
+ if (unroll_m == 16) {
+ kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, false, false);
+ } else {
+ kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ }
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop11, T_NEAR);
+ align(16);
+
+ L(subloop20);
+ cmp(I, 1);
+ jne(subloop21, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop21);
+ cmp(I, 2);
+ jne(subloop22, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop22);
+ cmp(I, 3);
+ jne(subloop23, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop23);
+ cmp(I, 4);
+ jne(subloop24, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop24);
+ cmp(I, 5);
+ jne(subloop99, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ if (!isTransA) {
+ L(subloop30);
+ cmp(I, UNROLL_N);
+ jl(subloop25, T_NEAR);
+ align(16);
+
+ L(subloop31);
+ if (unroll_m == 16) {
+ kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, true, false);
+ } else {
+ kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, true, false);
+ }
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop31, T_NEAR);
+ align(16);
+
+ L(subloop25);
+ cmp(I, 1);
+ jne(subloop32, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop32);
+ cmp(I, 2);
+ jne(subloop33, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop33);
+ cmp(I, 3);
+ jne(subloop34, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop34);
+ cmp(I, 4);
+ jne(subloop35, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop35);
+ cmp(I, 5);
+ jne(subloop99, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ align(16);
+ }
+
+ L(subloop99);
+ // Compute address for A
+ if (!isTransA) {
+ add(A, unroll_m * SIZE);
+ } else {
+ mov(rax, LDA);
+ imul(rax, rax, unroll_m);
+ add(A, rax);
+ }
+
+ // Compute next address of BIAS
+ if (hasBias) {
+ add(BIAS, unroll_m * SIZE);
+ }
+ };
+
+ preamble();
+
+ Label buffer_in_ws, buffer_allocated;
+
+ // Get the registers
+ mov(B, ARG_B);
+ mov(LDB, ARG_LDB);
+ mov(r15, ARG_BETA);
+ mov(r12, ARG_C);
+ if (hasBias)
+ mov(r10, ARG_BIAS);
+ mov(LDC, ARG_LDC);
+ mov(rbp, rsp);
+
+ vmovss(xmm0, ptr[ARG_ALPHA]);
+ vmovss(xmm1, ptr[r15]);
+
+#ifdef _WIN32
+ mov(A, ARG_A);
+ mov(LDA, ARG_LDA);
+#endif
+
+ cmp(K, STACK_K_CAPACITY);
+ jg(buffer_in_ws, T_NEAR);
+
+ // Create buffer and align to 4kB page
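+        // rax = 16 * K * SIZE + 256: one packed panel of A at the maximum
+        // unroll (16 rows), with slack for the page alignment below.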
+ lea(rax, ptr[K * SIZE]);
+ sal(rax, 4);
+ add(rax, 256);
+ sub(rsp, rax);
+ and_(rsp, -PAGE_4K);
+ jmp(buffer_allocated, T_NEAR);
+
+ L(buffer_in_ws);
+ mov(rsp, ARG_WS);
+
+ L(buffer_allocated);
+
+ mov(ORIG_SP, rbp);
+ mov(M, ARG_M);
+ mov(N, ARG_N);
+ mov(C, r12);
+ if (hasBias)
+ mov(BIAS, r10);
+ vmovss(ALPHA, xmm0);
+ vmovss(BETA, xmm1);
+ sub(A, -OFFSET * SIZE);
+ sub(B, -OFFSET * SIZE);
+ mov(ORIG_A, A);
+ sal(LDA, BASE_SHIFT);
+ sal(LDB, BASE_SHIFT);
+ sal(LDC, BASE_SHIFT);
+ lea(LDB3, ptr[LDB + LDB * 2]);
+
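+        // Store the constant sequence 0..7 at rsp + 88; it is the per-lane
+        // index vector (MASK) that the masked-tail compares test against.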
+ for (int i = 0; i < 8; i++) {
+ mov(dword[rsp + 88 + i * 4], i);
+ }
+
+ if (isTransA && is_avx2) {
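+            // Build STRIDE = {0, LDA, 2*LDA, 3*LDA} (qwords; LDA is already
+            // in bytes), used to gather LDA-strided elements of transposed A.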
+ movq(xmm0, LDA);
+ vpbroadcastq(ymm1, xmm0);
+ vinsertf128(ymm0, ymm0, xmm0, 1);
+ vpermilpd(ymm0, ymm0, 5);
+ vpaddq(ymm1, ymm1, ymm1);
+ vperm2f128(ymm1, ymm1, ymm1, 8);
+ vpaddq(ymm0, ymm0, ymm1);
+ vmovups(STRIDE, ymm0);
+ }
+
+ // Check A alignment and leading dimension; take copy-based path as
+ // needed
+ mov(rax, LDA);
+ or_(rax, A);
+ and_(rax, 0x1f);
+ mov(FLAG, rax);
+
+ Label main0, main1, main2, main3, main999;
+
+ cmp(M, UNROLL_M);
+ jl(main0, T_NEAR);
+ align(16);
+
+ L(main1);
+ subloop(UNROLL_M, true, true);
+ sub(M, UNROLL_M);
+ cmp(M, UNROLL_M);
+ jge(main1, T_NEAR);
+ align(16);
+
+ L(main0);
+ cmp(M, 0);
+ jle(main999, T_NEAR);
+
+ if (UNROLL_M > 8) {
+ cmp(M, 8);
+ jle(main2, T_NEAR);
+
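+            // 8 < M < 16: run the 16-row subloop with the second 8-row load
+            // masked so only the remaining M - 8 lanes are touched.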
+ sub(M, 8);
+ vbroadcastss(VMASK, M);
+ vpcmpgtd(VMASK, VMASK, MASK);
+
+ subloop(16, true, false);
+ jmp(main999, T_NEAR);
+ align(16);
+
+ L(main2);
+ cmp(M, 8);
+ jne(main3, T_NEAR);
+ subloop(8, true, true);
+ jmp(main999, T_NEAR);
+ }
+
+ align(16);
+
+ L(main3);
+ vbroadcastss(VMASK, M);
+ if (is_avx2) {
+ vpcmpgtd(VMASK, VMASK, MASK);
+ } else {
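+            // Plain AVX has no 256-bit integer compare, so build the mask in
+            // two 128-bit halves and stitch them together.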
+ auto xmask = Xmm(VMASK.getIdx());
+ auto xmm_tmp = xmm4;
+
+ vextractf128(xmm_tmp, VMASK, 1);
+ vpcmpgtd(xmask, xmask, MASK);
+ vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4
+ vinsertf128(VMASK, VMASK, xmm_tmp, 1);
+ }
+ subloop(8, false, false);
+ align(16);
+
+ L(main999);
+ // Restore original stack
+ mov(rsp, ORIG_SP);
+
+ vzeroupper();
+ postamble();
+
+ ker_ = this->getCode<ker_t>();
+ }
+
+ typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws);
+
+ void operator()(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws) const
+ {
+ ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
+ }
+
+private:
+ ker_t ker_;
+};
+
+const xbyak_gemm *get_xbyak_gemm(
+ bool isTransA, bool isTransB, float beta, bool hasBias) {
+ auto beta_idx = [](float beta) {
+ return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
+ };
+
+ // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
+ static xbyak_gemm *kernel_table[2][2][2][3];
+ static std::once_flag initialized;
+ std::call_once(initialized, [=]{
+ for (bool isTransA: {false, true})
+ for (bool isTransB: {false, true})
+ for (bool hasBias: {false, true})
+ for (float beta: {0.0f, 1.0f, 2.0f}) {
+ // nocopy sgemm with bias for beta != 0.0 is not supported
+ if (hasBias && beta != 0.0)
+ continue;
+ kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
+ new xbyak_gemm(isTransA, isTransB, beta, hasBias);
+ }
+ });
+
+ return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
+}
+
+void sgemm_nocopy_driver(const char *transa,
+ const char *transb, int m, int n, int k, const float *alpha,
+ const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
+ float *c, dim_t ldc, const float *bias, float *ws)
+{
+ bool isTransA = (*transa == 'T' || *transa == 't');
+ bool isTransB = (*transb == 'T' || *transb == 't');
+
+ int Bm, sizeM, Bn, sizeN, Bk, sizeK;
+
+ int i, j;
+
+ if ((m <= 0) || (n <= 0))
+ return;
+
+ if ((k <= 0) || (alpha[0] == 0.)) {
+
+ if (beta[0] == 0.) {
+ for (j = 0; j < n; j++)
+ for (i = 0; i < m; i++)
+ c[i + j * ldc] = 0.0;
+ } else if (beta[0] != 1.) {
+ for (j = 0; j < n; j++)
+ for (i = 0; i < m; i++)
+ c[i + j * ldc] *= beta[0];
+ }
+
+ return;
+ }
+
+ assert(IMPLICATION(bias != nullptr, *beta == 0.0));
+
+ // XXX: this happens on every thread...
+ bool hasBias = (bias != nullptr);
+ auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
+ auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
+ auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
+ assert(ker_bn && ker_b1 && ker_b0);
+
+ int BM = 4032;
+ int BN = isTransA ? 96 : 48;
+ int BK = isTransB ? 96 : 256;
+ const float *curA, *curB, *curBias = nullptr;
+ float *curC;
+
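+    // Block K, M and N so each sub-problem fits in cache; a remainder
+    // between one and two blocks is split in half to balance the tails.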
+ for (Bk = 0; Bk < k; Bk += sizeK) {
+ sizeK = k - Bk;
+ if (sizeK >= BK * 2)
+ sizeK = BK;
+ else {
+ if (sizeK > BK)
+ sizeK = (sizeK + 1) / 2;
+ }
+
+ for (Bm = 0; Bm < m; Bm += sizeM) {
+ sizeM = m - Bm;
+ if (sizeM >= BM * 2)
+ sizeM = BM;
+ else {
+ if (sizeM > BM + BM / 2)
+ sizeM = (sizeM + 1) / 2;
+ }
+
+ for (Bn = 0; Bn < n; Bn += sizeN) {
+ sizeN = n - Bn;
+ if (sizeN >= BN * 2)
+ sizeN = BN;
+ else {
+ if (sizeN > BN + BN / 2)
+ sizeN = (sizeN + 1) / 2;
+ }
+
+ if (!isTransA) {
+ curA = a + Bm + Bk * lda;
+ } else {
+ curA = a + Bk + Bm * lda;
+ }
+ if (!isTransB) {
+ curB = b + Bk + Bn * ldb;
+ } else {
+ curB = b + Bn + Bk * ldb;
+ }
+ curC = c + Bm + (size_t)Bn * ldc;
+ if (bias != nullptr) {
+ if (Bk == 0) {
+ curBias = bias + Bm;
+ } else {
+ curBias = nullptr;
+ }
+ }
+ if (Bk == 0) {
+ if (*beta == 0.0 && bias == nullptr)
+ (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ else
+ (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ } else {
+ (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ }
+ }
+ }
+ }
+}
+
+}
+
+mkldnn_status_t jit_avx_gemm_f32(
+ const char *transa, const char *transb,
+ const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
+ const float *A, const int *p_lda, const float *B, const int *p_ldb,
+ const float *p_beta, float *C, const int *p_ldc, const float *bias)
+{
+ using namespace mkldnn::impl::utils;
+ using namespace avx_gemm_f32;
+ using namespace gemm_utils;
+
+ if (*p_beta != 0 && bias)
+ return ref_gemm(transa, transb, p_m, p_n, p_k,
+                p_alpha, A, p_lda, B, p_ldb, p_beta, C, p_ldc, bias);
+
+ int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+
+ int m = *p_m;
+ int n = *p_n;
+ int k = *p_k;
+ dim_t lda = *p_lda;
+ dim_t ldb = *p_ldb;
+ dim_t ldc = *p_ldc;
+ float beta = *p_beta;
+ int MB, NB, KB;
+
+ int nthr_m, nthr_n, nthr_k, nthr_mn;
+
+ // Determine threading partitioning
+ calc_nthr_nocopy_avx(
+ m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
+
+    // Should not happen, but guard against an undersized thread count
+ if (nthr < nthr_m * nthr_n * nthr_k)
+ nthr = nthr_m * nthr_n * nthr_k;
+
+ nthr_mn = nthr_m * nthr_n;
+
+ unsigned char * ompstatus_ = nullptr;
+ unsigned char volatile *ompstatus = nullptr;
+
+ float *c_buffers = nullptr;
+ float *ws_buffers = nullptr;
+
+ if (nthr_k > 1) {
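+        // K is split across threads: non-primary K-slices accumulate into
+        // private c_buffers, and ompstatus keeps one cache-line-spaced
+        // completion flag per thread so reducers can spin-wait on producers.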
+ ompstatus_ = (unsigned char *) malloc(
+ nthr * CACHE_LINE_SIZE,
+ CACHE_LINE_SIZE);
+ ompstatus = (unsigned char volatile *) ompstatus_;
+ assert(ompstatus);
+
+ for (int i = 0; i < nthr; i++)
+ ompstatus[i * CACHE_LINE_SIZE] = 0;
+
+ c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
+ * sizeof(float), PAGE_4K);
+ }
+
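+    // Per-thread workspace for packing A (16 rows of length k plus slack),
+    // needed only when K exceeds the kernel's on-stack capacity.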
+ const size_t ws_elems_per_thr = (size_t)k * 16 + 64;
+ const size_t ws_size_per_thr
+ = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
+ if (k > STACK_K_CAPACITY) {
+ ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
+ }
+
+ parallel_nd(nthr, [&](const int ithr) {
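+        // Decompose the flat thread id into an (m, n, k) grid position:
+        // ithr = ithr_k * nthr_mn + ithr_n * nthr_m + ithr_m.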
+ int ithr_m, ithr_n, ithr_k, ithr_mn;
+ int m_from, m_to, myM;
+ int n_from, n_to, myN;
+ int k_from, k_to, myK;
+ int cbase, ibase;
+ const float *myA, *myB, *myBias = nullptr;
+ float *myC = C, myBeta;
+ float *ws = ws_buffers ?
+ ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
+ dim_t ld = ldc;
+
+ int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
+
+ if (ithr < nthr_m * nthr_n * nthr_k) {
+
+ ithr_mn = ithr % nthr_mn;
+ ithr_m = ithr_mn % nthr_m;
+ ithr_n = ithr_mn / nthr_m;
+ ithr_k = ithr / nthr_mn;
+
+ /* swap ithr_k for performance improvement */
+ if (ithr_k == 0)
+ ithr_k = nthr_k - 1;
+ else if (ithr_k == nthr_k - 1)
+ ithr_k = 0;
+
+ m_from = MB * (ithr_m);
+ m_to = MB * (ithr_m + 1);
+ if (m_to > m)
+ m_to = m;
+ myM = m_to - m_from;
+
+ n_from = NB * (ithr_n);
+ n_to = NB * (ithr_n + 1);
+ if (n_to > n)
+ n_to = n;
+ myN = n_to - n_from;
+
+ k_from = KB * (ithr_k);
+ k_to = KB * (ithr_k + 1);
+ if (k_to > k)
+ k_to = k;
+ myK = k_to - k_from;
+
+ cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+ ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
+
+ if ((myM > 0) && (myN > 0)) {
+
+ if (*transa == 'N' || *transa == 'n') {
+ myA = &(A[m_from + k_from * lda]);
+ } else {
+ myA = &(A[k_from + m_from * lda]);
+ }
+ if (*transb == 'N' || *transb == 'n') {
+ myB = &(B[k_from + n_from * ldb]);
+ } else {
+ myB = &(B[n_from + k_from * ldb]);
+ }
+ if (ithr_k == 0) {
+ myC = &(C[m_from + n_from * ldc]);
+ myBeta = beta;
+ ld = ldc;
+ if (bias)
+ myBias = &(bias[m_from]);
+ } else {
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
+ myBeta = 0.0;
+ ld = MB;
+ myBias = nullptr;
+ }
+
+ sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
+ lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
+
+ if (nthr_k > 1 && !sum_later)
+ ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
+ }
+
+ if (nthr_k > 1 && !sum_later) {
+
+ // sum matrices partitioned along K dimension
+ int n1, n2;
+
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+ if (ithr_k > 0) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
+ /* need to wait until main thread finishes */
+ while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
+ };
+
+ /* my cache is hot */
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ if (ik != ithr_k) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
+
+ while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
+ };
+
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+ }
+ }
+ }
+ });
+
+ // handle C summation later
+ if (nthr_k > 1 && ompstatus[0] == 0) {
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_m, ithr_n, ithr_k, ithr_mn;
+ int m_from, m_to, myM;
+ int n_from, n_to, myN;
+ int cbase;
+ float *myC = C;
+
+ if (ithr < nthr_m * nthr_n * nthr_k) {
+
+ ithr_mn = ithr % nthr_mn;
+ ithr_m = ithr_mn % nthr_m;
+ ithr_n = ithr_mn / nthr_m;
+ ithr_k = ithr / nthr_mn;
+
+ /* swap ithr_k for performance improvement */
+ if (ithr_k == 0)
+ ithr_k = nthr_k - 1;
+ else if (ithr_k == nthr_k - 1)
+ ithr_k = 0;
+
+ m_from = MB * (ithr_m);
+ m_to = MB * (ithr_m + 1);
+ if (m_to > m)
+ m_to = m;
+ myM = m_to - m_from;
+
+ n_from = NB * (ithr_n);
+ n_to = NB * (ithr_n + 1);
+ if (n_to > n)
+ n_to = n;
+ myN = n_to - n_from;
+
+ cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ if (nthr_k > 1) {
+ // sum matrices partitioned along K dimension
+ int n1, n2;
+
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+ if (ithr_k > 0) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
+
+ /* my cache is hot */
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ if (ik != ithr_k) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
+
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+ }
+ }
+ }
+ });
+ }
+
+ free(c_buffers);
+ free(ompstatus_);
+ free(ws_buffers);
+
+ return mkldnn_success;
+}
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp
new file mode 100644
index 0000000000..aabf520a3c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp
@@ -0,0 +1,37 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX_GEMM_F32_HPP
+#define JIT_AVX_GEMM_F32_HPP
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t jit_avx_gemm_f32(
+ const char *transa, const char *transb, const int *M,
+ const int *N, const int *K, const float *alpha, const float *A,
+ const int *lda, const float *B, const int *ldb, const float *beta,
+ float *C, const int *ldc, const float *bias = nullptr);
+
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp
new file mode 100644
index 0000000000..5147885a89
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp
@@ -0,0 +1,346 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_types.h"
+
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "jit_generator.hpp"
+
+#include "gemm_utils_f32.hpp"
+#include "ref_gemm_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace gemm_utils;
+
+namespace {
+
+template <typename data_t>
+void copy_A(
+ bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) {
+ for (int k = 0; k < K; k++) {
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < unroll_factor<data_t>::m; i++) {
+ ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
+ }
+ ws += unroll_factor<data_t>::m;
+ }
+}
+
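+// Register-blocked micro-kernel: accumulate an unroll_m x unroll_n tile of C
+// in a local array over the full K extent, then apply alpha/beta once.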
+template <typename data_t, bool isTransA, bool isTransB>
+void kernel_mxn(int K, const data_t *A, const dim_t lda,
+ const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc,
+ const data_t alpha, const data_t beta) {
+ data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] =
+ { static_cast<data_t>(0.) };
+ for (int k = 0; k < K; k++) {
+ for (int j = 0; j < unroll_factor<data_t>::n; j++) {
+ data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb];
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < unroll_factor<data_t>::m; i++) {
+ data_t a = isTransA ? A[i * lda + k] : A[i + lda * k];
+ c[i + unroll_factor<data_t>::m * j] += a * b;
+ }
+ }
+ }
+ for (int j = 0; j < unroll_factor<data_t>::n; j++) {
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < unroll_factor<data_t>::m; i++) {
+ C[i + j * ldc] = (beta == static_cast<data_t>(0.))
+ ? alpha * c[i + unroll_factor<data_t>::m * j]
+ : alpha * c[i + unroll_factor<data_t>::m * j]
+ + beta * C[i + j * ldc];
+ }
+ }
+}
+
+template <typename data_t, bool isTransA, bool isTransB>
+void block_ker(const int M, const int N, const int K,
+ const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
+ data_t *C, const dim_t ldc, const data_t alpha, const data_t beta,
+ data_t *ws, bool do_copy) {
+ int Nu = rnd_dn(N, unroll_factor<data_t>::n);
+ int Mu = rnd_dn(M, unroll_factor<data_t>::m);
+ for (int i = 0; i < Mu; i += unroll_factor<data_t>::m) {
+ for (int j = 0; j < Nu; j += unroll_factor<data_t>::n) {
+ const data_t *b = isTransB ? &B[j] : &B[j * ldb];
+ const data_t *a = isTransA ? &A[i * lda] : &A[i];
+ if (do_copy) {
+ if (j == 0) {
+ copy_A<data_t>(isTransA, K, a, lda, ws);
+ }
+ kernel_mxn<data_t, false, isTransB>(
+ K, ws, unroll_factor<data_t>::m, b, ldb,
+ &C[i + j * ldc], ldc, alpha, beta);
+ } else {
+ kernel_mxn<data_t, isTransA, isTransB>(
+ K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta);
+ }
+ }
+ }
+ // tail processing
+ for (int i = 0; i < M; i++) {
+ for (int j = Nu; j < N; j++) {
+ data_t c = beta == static_cast<data_t>(0.)
+ ? static_cast<data_t>(0.)
+ : beta * C[i + j * ldc];
+ for (int p = 0; p < K; p++) {
+ data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
+ data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
+ c += alpha * a * b;
+ }
+ C[i + j * ldc] = c;
+ }
+ }
+ for (int i = Mu; i < M; i++) {
+ for (int j = 0; j < Nu; j++) {
+ data_t c = beta == static_cast<data_t>(0.)
+ ? static_cast<data_t>(0.)
+ : beta * C[i + j * ldc];
+ for (int p = 0; p < K; p++) {
+ data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
+ data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
+ c += alpha * a * b;
+ }
+ C[i + j * ldc] = c;
+ }
+ }
+}
+
+template <typename data_t, bool isTransA, bool isTransB>
+void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
+ const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
+ const data_t beta, data_t *C, const dim_t ldc, bool do_copy,
+ data_t *ws) {
+ constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM;
+ constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN;
+ constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK;
+
+ const data_t *curA;
+ const data_t *curB;
+ data_t *curC;
+
+ if ((M <= 0) || (N <= 0))
+ return;
+
+ if ((K <= 0) || (alpha == static_cast<data_t>(0))) {
+        dim_t MN = (dim_t)N * M;
+ if (beta == static_cast<data_t>(0.)) {
+ for (dim_t j = 0; j < MN; j++)
+ C[j] = static_cast<data_t>(0.);
+ } else if (beta != static_cast<data_t>(1.)) {
+ for (dim_t j = 0; j < MN; j++)
+ C[j] *= beta;
+ }
+ return;
+ }
+
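+    // Three-level cache blocking; after the first K-block the partial
+    // products must accumulate, so beta is forced to 1 for Bk > 0.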
+ for (int Bk = 0; Bk < K; Bk += BK) {
+ int kb = nstl::min(K - Bk, BK);
+ for (int Bm = 0; Bm < M; Bm += BM) {
+ int mb = nstl::min(M - Bm, BM);
+ for (int Bn = 0; Bn < N; Bn += BN) {
+ int nb = nstl::min(N - Bn, BN);
+ curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda;
+ curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb;
+ curC = C + Bm + Bn * ldc;
+ if (Bk == 0) {
+ block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
+ curB, ldb, curC, ldc, alpha, beta, ws, do_copy);
+ } else {
+ block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
+ curB, ldb, curC, ldc, alpha, static_cast<data_t>(1.0),
+ ws, do_copy);
+ }
+ }
+ }
+ }
+}
+
+}
+
+template <typename data_t>
+mkldnn_status_t ref_gemm(
+ const char *transa_, const char *transb_, const int *M_,
+ const int *N_, const int *K_, const data_t *alpha_, const data_t *A,
+ const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_,
+ data_t *C, const int *ldc_, const data_t *bias) {
+
+ bool isTransA = (*transa_ == 'T' || *transa_ == 't');
+ bool isTransB = (*transb_ == 'T' || *transb_ == 't');
+ const int M = *M_, N = *N_, K = *K_;
+ const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_;
+ const data_t alpha = *alpha_, beta = *beta_;
+
+ int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
+ int nthr_m, nthr_n, nthr_k;
+ int MB, NB, KB;
+ // thread balancing over M, N, K & size of blocking dimensions
+ calc_nthr_nocopy_avx(
+ M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
+
+ data_t *c_buffers = nullptr;
+ data_t *ws_buffers = nullptr;
+ if (nthr_k > 1) {
+ c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
+ * sizeof(data_t), PAGE_4K);
+ if (!c_buffers) {
+ nthr_k = 1;
+ KB = K;
+ }
+ }
+
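+    // Pack A into a contiguous buffer only when each packed panel is reused
+    // across enough N-blocks to amortize the copy.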
+ bool do_copy = (NB / unroll_factor<data_t>::n > 3);
+ const int nthr_mn = nthr_m * nthr_n;
+ const int nthr = nthr_mn * nthr_k;
+ const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m;
+ const size_t ws_size_per_thr
+ = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
+ if (do_copy) {
+ ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K);
+ if (!ws_buffers)
+ do_copy = false;
+ }
+
+ auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N,
+ int ithr) {
+ from = NB * (ithr);
+ to = NB * (ithr + 1);
+ if (to > N)
+ to = N;
+ myN = to - from;
+ };
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_mn = ithr % nthr_mn;
+ int ithr_m = ithr_mn % nthr_m;
+ int ithr_n = ithr_mn / nthr_m;
+ int ithr_k = ithr / nthr_mn;
+
+ int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ data_t *ws = do_copy
+ ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t)
+ : nullptr;
+
+ int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
+ k_from = 0, k_to = 0, myK = 0;
+
+ get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
+ get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
+ get_thr_block(k_from, k_to, myK, KB, K, ithr_k);
+
+ if (myM > 0 && myN > 0) {
+ data_t myBeta, *myC;
+ dim_t ld;
+ if (ithr_k == 0) {
+ myC = &(C[m_from + n_from * ldc]);
+ myBeta = beta;
+ ld = ldc;
+ } else {
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
+ myBeta = 0.0f;
+ ld = MB;
+ }
+ const data_t *myA = isTransA
+ ? &(A[k_from + m_from * lda])
+ : &(A[m_from + k_from * lda]);
+ const data_t *myB = isTransB
+ ? &(B[n_from + k_from * ldb])
+ : &(B[k_from + n_from * ldb]);
+
+ if (!isTransA) {
+ if (!isTransB) {
+ gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
+ } else {
+ gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
+ }
+ } else {
+ if (!isTransB) {
+ gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
+ } else {
+ gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
+ }
+ }
+ }
+ });
+
+ if (nthr_k > 1) {
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_mn = ithr % nthr_mn;
+ int ithr_m = ithr_mn % nthr_m;
+ int ithr_k = ithr / nthr_mn;
+ int ithr_n = ithr_mn / nthr_m;
+
+ int n_from = 0, n_to = 0, myN = 0;
+ int m_from = 0, m_to = 0, myM = 0;
+
+ int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
+ get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
+
+ // sum matrices partitioned along K dimension
+ int offset = 0, block = 0;
+ gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset,
+ &block);
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ data_t *myC = c_buffers
+ + MB * ((dim_t)NB * (cbase + ik - 1) + offset);
+
+ gemm_utils::sum_two_matrices(myM, block, myC, MB,
+ &C[m_from + (n_from + offset) * ldc], ldc);
+ }
+ });
+ }
+
+ if (bias) {
+ parallel_nd(N, M, [&](int i, int j) {
+ C[i*ldc + j] += bias[j];
+ });
+ }
+
+ free(ws_buffers);
+ free(c_buffers);
+
+ return mkldnn_success;
+}
+
+template mkldnn_status_t ref_gemm<float>(
+ const char *transa_, const char *transb_,
+ const int *M_, const int *N_, const int *K_, const float *alpha_,
+ const float *A, const int *lda_, const float *B, const int *ldb_,
+ const float *beta_, float *C, const int *ldc_, const float *bias);
+
+template mkldnn_status_t ref_gemm<double>(
+ const char *transa_, const char *transb_,
+ const int *M_, const int *N_, const int *K_, const double *alpha_,
+ const double *A, const int *lda_, const double *B, const int *ldb_,
+ const double *beta_, double *C, const int *ldc_, const double *bias);
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp
new file mode 100644
index 0000000000..7c90ba6277
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp
@@ -0,0 +1,36 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef REF_GEMM_F32_HPP
+#define REF_GEMM_F32_HPP
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <typename data_t>
+mkldnn_status_t ref_gemm(const char *transa, const char *transb, const int *M,
+ const int *N, const int *K, const data_t *alpha, const data_t *A,
+ const int *lda, const data_t *B, const int *ldb, const data_t *beta,
+ data_t *C, const int *ldc, const data_t *bias);
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp
new file mode 100644
index 0000000000..3dbe07d743
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp
@@ -0,0 +1,280 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "mkldnn_traits.hpp"
+#include "nstl.hpp"
+
+#include "jit_generator.hpp"
+
+#include "gemm.hpp"
+
+#include "f32/jit_avx512_common_gemm_f32.hpp"
+#include "f32/jit_avx_gemm_f32.hpp"
+#include "f32/ref_gemm_f32.hpp"
+
+#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp"
+#include "s8x8s32/simple_gemm_s8s8s32.hpp"
+#include "s8x8s32/ref_gemm_s8x8s32.hpp"
+
+#include "os_blas.hpp"
+
+/* USE_MKL USE_CBLAS effect
+ * ------- --------- ------
+ * yes yes use Intel(R) MKL CBLAS
+ * yes no use jit
+ * no yes system-dependent CBLAS
+ * no no use jit
+ */
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
+ const int *M, const int *N, const int *K, const int *lda,
+ const int *ldb, const int *ldc, const float *alpha, const float *beta,
+ const bool with_bias) {
+ if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta))
+ return mkldnn_invalid_arguments;
+ if (with_bias && *beta != 0)
+ return mkldnn_unimplemented;
+ bool consistency = true
+ && utils::one_of(*transa, 'T', 't', 'N', 'n')
+ && utils::one_of(*transb, 'T', 't', 'N', 'n')
+ && *M >= 0
+ && *N >= 0
+ && *K >= 0;
+
+ if (!consistency)
+ return mkldnn_invalid_arguments;
+ bool isTransA = utils::one_of(*transa, 'T', 't');
+ bool isTransB = utils::one_of(*transb, 'T', 't');
+ int nrowA = isTransA ? *K : *M;
+ int nrowB = isTransB ? *N : *K;
+ consistency = true
+ && *lda >= nstl::max(1, nrowA)
+ && *ldb >= nstl::max(1, nrowB)
+ && *ldc >= nstl::max(1, *M);
+ if (!consistency)
+ return mkldnn_invalid_arguments;
+
+ return mkldnn_success;
+}
+
+mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc,
+ const char *transa, const char *transb, const int *M, const int *N,
+ const int *K, const int *lda, const int *ldb, const int *ldc,
+ const float *alpha, const float *beta, const bool with_bias) {
+ if (offsetc == nullptr)
+ return mkldnn_invalid_arguments;
+ if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
+ return mkldnn_invalid_arguments;
+
+ return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha,
+ beta, with_bias);
+}
+
+mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
+ const int *M, const int *N, const int *K, const float *alpha,
+ const float *A, const int *lda, const float *B, const int *ldb,
+ const float *beta, float *C, const int *ldc,
+ const float *bias, const bool force_jit_gemm) {
+ mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K,
+ lda, ldb, ldc, alpha, beta, bias != nullptr);
+ if (status != mkldnn_success)
+ return status;
+
+#ifdef USE_CBLAS
+ if (!force_jit_gemm) {
+ bool trA = *transa == 't' || *transa == 'T';
+ bool trB = *transb == 't' || *transb == 'T';
+ CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
+ CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
+ cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB,
+ *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
+
+ if (bias) {
+ // Add bias if necessary (bias is applied to columns of C)
+ cblas_int incx = 1, incy = 1;
+ parallel_nd(*N, [&](int n) {
+ ptrdiff_t offset = (ptrdiff_t)n * (*ldc);
+ cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy);
+ });
+ }
+ return mkldnn_success;
+ }
+#endif
+
+ if (mayiuse(avx512_common))
+ return jit_avx512_common_gemm_f32(transa, transb,
+ M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
+ else if (mayiuse(avx))
+ return jit_avx_gemm_f32(transa, transb,
+ M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
+ else
+ return ref_gemm<float>(transa, transb,
+ M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
+}
+
+template <typename b_dt>
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co) {
+ mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb,
+ M, N, K, LDA, LDB, LDC, alpha, beta, false);
+ if (status != mkldnn_success)
+ return status;
+
+ if (*M == 0 || *N == 0 || *K == 0)
+ return mkldnn_success;
+
+#if USE_MKL_IGEMM
+ bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
+ bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
+ bool AisN = (*transa == 'N' || *transa == 'n');
+ bool BisN = (*transb == 'N' || *transb == 'n');
+
+ if (data_traits<b_dt>::data_type == data_type::u8) {
+ CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans;
+ CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans;
+ CBLAS_OFFSET Cblas_offsetc =
+ OCisR
+ ? CblasRowOffset
+ : OCisC
+ ? CblasColOffset
+ : CblasFixOffset;
+ cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc,
+ *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo,
+ *beta, C, *LDC, co);
+ return mkldnn_success;
+ } else {
+ assert(data_traits<b_dt>::data_type == data_type::s8);
+ // TODO CBLAS implementation of gemm_s8s8s32 goes here.
+ // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
+ if (utils::everyone_is(0, *ao, *bo)) {
+ return simple_gemm_s8s8s32(transa, transb, offsetc, M,
+ N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
+ C, LDC, co);
+ } else {
+ return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
+ alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
+ }
+ }
+#else
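+    // JIT path: prefer the VNNI variant when available, then plain
+    // avx512_core; older ISAs fall back to the reference implementation.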
+ cpu_isa_t isa = isa_any;
+ if (mayiuse(avx512_core_vnni)) {
+ isa = avx512_core_vnni;
+ } else if (mayiuse(avx512_core)) {
+ isa = avx512_core;
+ }
+
+ if (data_traits<b_dt>::data_type == data_type::u8) {
+ switch (isa) {
+ case avx512_core:
+ case avx512_core_vnni:
+ return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M,
+ N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta,
+ C, LDC, co);
+ default:
+ return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
+ alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
+ }
+ } else {
+ assert(data_traits<b_dt>::data_type == data_type::s8);
+ // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
+ if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni))
+ && *ao == 0 && *bo == 0) {
+ return simple_gemm_s8s8s32(transa, transb, offsetc, M,
+ N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
+ C, LDC, co);
+ } else {
+ return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
+ alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
+ }
+ }
+#endif
+}
+
+template
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const int8_t *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co);
+
+template
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const uint8_t *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co);
+
+}
+}
+}
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::cpu;
+
+mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb,
+ const int64_t *M, const int64_t *N, const int64_t *K, const float *alpha,
+ const float *A, const int64_t *lda, const float *B, const int64_t *ldb,
+ const float *beta, float *C, const int64_t *ldc) {
+ int M_s32 = (int)*M;
+ int N_s32 = (int)*N;
+ int K_s32 = (int)*K;
+ int lda_s32 = (int)*lda;
+ int ldb_s32 = (int)*ldb;
+ int ldc_s32 = (int)*ldc;
+
+ return extended_sgemm(transa, transb, &M_s32, &N_s32, &K_s32,
+ alpha, A, &lda_s32, B, &ldb_s32, beta, C, &ldc_s32);
+}
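+
+// Illustrative sketch only (not part of the library sources): a minimal
+// column-major call through this wrapper, assuming caller-allocated buffers
+// A (M x K), B (K x N) and C (M x N); the buffer names are hypothetical.
+//
+//     int64_t M = 64, N = 32, K = 128, lda = M, ldb = K, ldc = M;
+//     float alpha = 1.0f, beta = 0.0f;
+//     mkldnn_sgemm("N", "N", &M, &N, &K, &alpha, A, &lda, B, &ldb,
+//             &beta, C, &ldc);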
+
+mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb,
+ const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K,
+ const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao,
+ const uint8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta,
+ int32_t *C, const int64_t *ldc, const int32_t *co) {
+ int M_s32 = (int)*M;
+ int N_s32 = (int)*N;
+ int K_s32 = (int)*K;
+ int lda_s32 = (int)*lda;
+ int ldb_s32 = (int)*ldb;
+ int ldc_s32 = (int)*ldc;
+ return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32,
+ alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co);
+}
+
+mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb,
+ const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K,
+ const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao,
+ const int8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta,
+ int32_t *C, const int64_t *ldc, const int32_t *co) {
+ int M_s32 = (int)*M;
+ int N_s32 = (int)*N;
+ int K_s32 = (int)*K;
+ int lda_s32 = (int)*lda;
+ int ldb_s32 = (int)*ldb;
+ int ldc_s32 = (int)*ldc;
+
+ return gemm_s8x8s32<int8_t>(transa, transb, offsetc, &M_s32, &N_s32, &K_s32,
+ alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co);
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp
new file mode 100644
index 0000000000..dc15ff7130
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp
@@ -0,0 +1,58 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GEMM_HPP
+#define GEMM_HPP
+
+#include "mkldnn_types.h"
+#include "os_blas.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
+ const int *M, const int *N, const int *K, const float *alpha,
+ const float *A, const int *lda, const float *B, const int *ldb,
+ const float *beta, float *C, const int *ldc,
+ const float *bias = nullptr, bool force_jit_gemm = false);
+
+template <typename b_dt>
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
+ const b_dt *B, const int *ldb, const int8_t *bo, const float *beta,
+ int32_t *c, const int *ldc, const int32_t *co);
+
+#ifdef USE_CBLAS
+#define GEMM_IMPL_STR "gemm:blas"
+#else
+#define GEMM_IMPL_STR "gemm:jit"
+#endif
+
+#if USE_MKL_IGEMM
+#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:blas"
+#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:blas"
+#else
+#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:jit"
+#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:jit"
+#endif
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp
new file mode 100644
index 0000000000..4d34ede0bd
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp
@@ -0,0 +1,86 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef OS_BLAS_HPP
+#define OS_BLAS_HPP
+
+/** \file
+ * Common stuff respecting USE_MKL and USE_CBLAS compile flags
+ *
+ * USE_MKL USE_CBLAS effect
+ * ------- --------- ------
+ * yes yes normal compile: jit *may* be preferred over Intel(R) MKL CBLAS
+ * yes no jit calls OK; assert if cblas is ever called
+ * no yes system-dependent CBLAS
+ * no no gemm convolution (or other blas) N/A; create stubs
+ */
+
+#if defined(USE_MKL)
+
+#include "mkl_version.h"
+
+#define USE_MKL_PACKED_GEMM (INTEL_MKL_VERSION >= 20190001)
+#define USE_MKL_IGEMM \
+ (INTEL_MKL_VERSION >= 20180000 && __INTEL_MKL_BUILD_DATE >= 20170628)
+
+#include "mkl_cblas.h"
+#if !defined(USE_CBLAS)
+#define cblas_sgemm(...) assert(!"CBLAS is unavailable")
+#endif
+
+#else /* defined(USE_MKL) */
+
+#define USE_MKL_PACKED_GEMM 0
+#define USE_MKL_IGEMM 0
+
+#if defined(_SX)
+/* TODO: _SX should also define USE_CBLAS in case the latter is available */
+extern "C" {
+#include "cblas.h" // CHECK: does SX also have a fortran API sgemm?
+}
+
+#elif defined(USE_CBLAS)
+#include "cblas.h" // Maybe a system/cmake cblas works for you?
+#else
+/* put the stubs to make a code compilable but not workable */
+#define cblas_sgemm(...) assert(!"CBLAS is unavailable")
+#endif /* defined(_SX) */
+
+#endif /* defined(USE_MKL) */
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+#if defined(USE_MKL) && defined(USE_CBLAS)
+typedef MKL_INT cblas_int;
+
+#elif defined(USE_CBLAS)
+typedef int cblas_int;
+
+#if defined(_SX)
+/* this cblas.h is peculiar... */
+typedef CBLAS_ORDER CBLAS_LAYOUT;
+#endif
+#endif
+
+}
+}
+}
+
+#endif /* OS_BLAS_HPP */
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp
new file mode 100644
index 0000000000..dde72f4a17
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp
@@ -0,0 +1,206 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef COMMON_H
+#define COMMON_H
+
+#define GEMM_CODE_SIZE (4096L * 32)
+
+#define AVX512_UNROLL_M 48
+#define AVX512_UNROLL_N 8
+#define AVX512_UNROLL_K 1
+#define AVX512_BM 9984
+#define AVX512_BN 384
+#define AVX512_BK 768
+#define AVX512_BK_VNNI 1536
+#define AVX512_BK_TRADITIONAL 384
+#define AVX512_BLOCKING_SMALL_K 48
+#define AVX512_BN_SMALL_K 24
+
+
+#define PAGESIZE 4096
+
+#define PADD_BYTESIZE_ONPAGE(x, size) ((((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE)
+#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size) / (size))
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+enum {
+ PARTITION_1D_ROW,
+ PARTITION_1D_COL,
+ PARTITION_2D_COL_MAJOR,
+ PARTITION_2D = PARTITION_2D_COL_MAJOR,
+};
+
+enum {
+ COPY_NONE,
+ COPY_A,
+};
+
+enum {
+ NO_OFFSET,
+ FIX_OFFSET,
+ COL_OFFSET,
+ ROW_OFFSET,
+};
+
+// Alias for any dimension related variable.
+typedef long long int dim_t;
+
+typedef struct {
+ // Interface arguments.
+ int transa, transb, offsetc;
+ dim_t m, n, k;
+ dim_t lda, ldb, ldc;
+ const int8_t *a;
+ const uint8_t *b;
+ int32_t *c;
+ const float *alpha, *beta;
+
+ int8_t ao, bo;
+ const int32_t *co;
+
+ // Kernel parameters.
+ dim_t um, un, uk, bm, bn, bk;
+ dim_t bn_small_k, bk_traditional, blocking_small_k;
+
+ int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ // Gemv kernels
+ void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const float,
+ const int8_t*, const dim_t, const uint8_t*,
+ const float, int32_t*);
+
+ void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float,
+ const uint8_t*, const dim_t, const int8_t*,
+ const float, int32_t*);
+
+ // Gemv parameters
+ int swap;
+
+} blas_t;
+
+
+class jit_avx512_core_u8_copy_an_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern);
+
+ public:
+ jit_avx512_core_u8_copy_an_kern();
+};
+
+class jit_avx512_core_u8_copy_at_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern);
+
+ public:
+ jit_avx512_core_u8_copy_at_kern();
+};
+
+class jit_avx512_core_u8_copy_bn_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern);
+
+ public:
+ jit_avx512_core_u8_copy_bn_kern();
+};
+
+class jit_avx512_core_u8_copy_bt_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern);
+
+ public:
+ jit_avx512_core_u8_copy_bt_kern();
+};
+
+class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern);
+
+ public:
+ jit_avx512_core_u8_copy_sum_an_kern();
+};
+
+class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern);
+
+ public:
+ jit_avx512_core_u8_copy_sum_at_kern();
+};
+
+class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern);
+
+ public:
+ jit_avx512_core_u8_copy_sum_bn_kern();
+};
+
+class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern);
+
+ public:
+ jit_avx512_core_u8_copy_sum_bt_kern();
+};
+
+}
+}
+}
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp
new file mode 100644
index 0000000000..db9dd9ef97
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp
@@ -0,0 +1,28 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg);
+int gemv_threading_driver(blas_t *arg);
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp
new file mode 100644
index 0000000000..e4b8e1cde2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp
@@ -0,0 +1,1409 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <cstdint>
+#include <mutex>
+
+#include "common.hpp"
+#include "mkldnn_types.h"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "jit_avx512_core_gemm_s8u8s32.hpp"
+#include "jit_avx512_core_gemm_s8u8s32_kern.hpp"
+#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp"
+#include "gemv.hpp"
+
+#if defined(_MSC_VER)
+#include <malloc.h>
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+typedef struct {
+ int nthrs_m, nthrs_n;
+ int partition;
+ int copy_type;
+} blas_thread_t;
+
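+// Round to the nearest integer, ties away from zero, saturating to the
+// int32_t range.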
+static inline void round_to_nearest(int32_t *rounded_val, double fp_val) {
+ if (fp_val >= 0.) {
+ fp_val += 0.5;
+ if (fp_val > INT32_MAX) {
+ fp_val = INT32_MAX;
+ }
+ } else {
+ fp_val -= 0.5;
+ if (fp_val < INT32_MIN) {
+ fp_val = INT32_MIN;
+ }
+ }
+ *rounded_val = (int32_t) fp_val;
+}
+
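+// Merge a partial integer GEMM result into C: apply alpha/beta in double
+// precision with rounding, then add the user-requested C offset.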
+static inline void add_results(const dim_t m, const dim_t n, const dim_t k,
+ const float alpha, const float beta, const int32_t *c_partial_sum,
+ const dim_t ldcp, int32_t *c_data, const dim_t ldc,
+ const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao,
+ const int8_t bo, const int32_t *co, const int offsetc)
+{
+ for (dim_t j = 0; j < n; ++j) {
+ for (dim_t i = 0; i < m; ++i) {
+ int32_t ctemp = c_partial_sum[i + j * ldcp];
+
+ if (alpha == 1.0f) {
+ if (beta == 0.0f) {
+ c_data[i + j * ldc] = ctemp;
+ } else {
+ double c_float = (double) beta
+ * (double) c_data[i + j * ldc];
+ c_float += (double) ctemp;
+ round_to_nearest(&c_data[i + j * ldc], c_float);
+ }
+ } else if (alpha == -1.0f) {
+ if (beta == 0.0f) {
+ c_data[i + j * ldc] = -ctemp;
+ } else {
+ double c_float = (double) beta
+ * (double) c_data[i + j * ldc];
+ c_float -= (double) ctemp;
+ round_to_nearest(&c_data[i + j * ldc], c_float);
+ }
+ } else {
+ if (beta == 0.0f) {
+ double c_float = alpha * (double) ctemp;
+ round_to_nearest(&c_data[i + j * ldc], c_float);
+ } else {
+ double c_float = alpha * (double) ctemp +
+ beta * (double) c_data[i + j * ldc];
+ round_to_nearest(&c_data[i + j * ldc], c_float);
+ }
+ }
+
+ if (offsetc == FIX_OFFSET) {
+ c_data[i + j * ldc] += co[0];
+ } else if (offsetc == ROW_OFFSET) {
+ c_data[i + j * ldc] += co[j];
+ } else if (offsetc == COL_OFFSET) {
+ c_data[i + j * ldc] += co[i];
+ }
+ }
+ }
+}
+
+// TODO Find a better place for these functions.
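+// Pad a leading dimension (in elements) up to a multiple of 2 KB plus one
+// cache line, which helps avoid cache-set conflicts between columns.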
+static inline dim_t ld_padd(const dim_t x)
+{
+ return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t)))
+ * (2048 / sizeof(int32_t)) + (64 / sizeof(int32_t));
+}
+
+void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k,
+ const int8_t *a, const uint8_t *b, float beta, int32_t *c,
+ const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum,
+ const int32_t *co, const int offsetc, const blas_t *arg)
+{
+ int8_t ao = arg->ao;
+ int8_t bo = arg->bo;
+ int32_t co_0 = (offsetc == NO_OFFSET)? 0 : co[0];
+
+    // m and n are bounded by the blocking above, so these on-stack arrays
+    // total at most ~32 kB and should not overflow the stack.
+#if !defined(_MSC_VER)
+ int32_t col_offset[m];
+ int32_t row_offset[n];
+#else
+ int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m);
+ int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n);
+#endif
+
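+    // col_offset[i] is applied to every element of row i of C and
+    // row_offset[j] to every element of column j; fold the ao/bo zero-point
+    // compensation terms and the user C offsets into whichever is needed.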
+ int col_req = 0;
+ int row_req = 0;
+
+ if ((bo != 0) || (offsetc == COL_OFFSET))
+ col_req = 1;
+ if ((ao != 0) || (offsetc == ROW_OFFSET))
+ row_req = 1;
+
+    // The kernel needs at most one of the column or row offset arrays, never both
+ if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) {
+ if ((col_req == 0) && (row_req == 0)) {
+ if (m <= n) {
+ col_req = 1;
+ } else {
+ row_req = 1;
+ }
+ }
+ }
+
+ if (col_req) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] = 0;
+
+ if (offsetc == COL_OFFSET) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] += co[i];
+ }
+
+ if (bo != 0) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] += bo * a_row_sum[i];
+ }
+ }
+
+ if (row_req) {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] = 0;
+
+ if (offsetc == ROW_OFFSET) {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] += co[i];
+ }
+
+ if (ao != 0) {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] += ao * b_col_sum[i];
+ }
+ }
+
+ if ((offsetc == FIX_OFFSET) && (co_0 != 0)) {
+ if (col_req) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] += co_0;
+ } else {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] += co_0;
+ }
+ }
+
+ if ((ao != 0) && (bo != 0)) {
+ if (col_req) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] += (int32_t) k * ao * bo;
+ } else {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] += (int32_t) k * ao * bo;
+ }
+ }
+
+ if (col_req == 0) {
+ if (row_req == 0) {
+ if (beta == 0.0) {
+ arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ } else {
+ arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ }
+ } else {
+ if (beta == 0.0) {
+ arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ } else {
+ arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ }
+ }
+ } else {
+ if (row_req == 0) {
+ if (beta == 0.0) {
+ arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ } else {
+ arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ }
+ } else {
+ if (beta == 0.0) {
+ arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ } else {
+ arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ }
+ }
+ }
+}
+
+static inline void *align(void *ptr, size_t alignment)
+{
+ return (void *) utils::rnd_up((uintptr_t) ptr, alignment);
+}
+
+static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k,
+ const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co,
+ const blas_t *arg)
+{
+ dim_t lda = arg->lda;
+ dim_t ldb = arg->ldb;
+ dim_t ldc = arg->ldc;
+ int8_t ao = arg->ao;
+ int8_t bo = arg->bo;
+ float alpha = *arg->alpha;
+ float beta = *arg->beta;
+
+ if (m <= 0 || n <= 0) {
+ return 0;
+ }
+
+ // Padding along K dimension.
+ dim_t k_padd = 0;
+ if (k <= arg->bk_traditional) {
+ k_padd = utils::rnd_up(k, arg->uk);
+ k_padd = nstl::max(128LL, k_padd);
+ } else if (k < 2 * arg->bk) {
+ k_padd = utils::rnd_up(k / 2, arg->uk);
+ } else {
+ k_padd = arg->bk;
+ }
+
+ // Padding along M dimension.
+ dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
+ arg->um);
+
+ // Padding along N dimension.
+ dim_t n_padd = 0;
+ if (k < arg->blocking_small_k) {
+ n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
+ arg->bn_small_k), arg->un);
+ } else {
+ n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
+ arg->un);
+ }
+
+ // Padding for temporary buffer for C
+ dim_t ldc_buf = ld_padd(m_padd);
+
+ dim_t strideAm = (arg->transa == 0)? 1 : lda;
+ dim_t strideAn = (arg->transa != 0)? 1 : lda;
+ dim_t strideBm = (arg->transb == 0)? 1 : ldb;
+ dim_t strideBn = (arg->transb != 0)? 1 : ldb;
+
+ size_t a_buf_nelems = m_padd * k_padd;
+ size_t b_buf_nelems = k_padd * n_padd;
+ size_t a_row_sum_nelems = m_padd;
+ size_t b_col_sum_nelems = n_padd;
+
+ size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K
+ + b_buf_nelems * sizeof(*b) + PAGE_4K
+ + a_row_sum_nelems * sizeof(*c) + PAGE_4K
+ + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
+
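+    // The JIT kernels are generated for alpha == 1 and beta in {0, 1};
+    // any other scaling must go through a temporary int32 C buffer and
+    // be finished by add_results().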
+ bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
+ if (need_c_buffer) {
+ size_t c_buf_nelems = ldc_buf * n_padd;
+ mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
+ }
+
+ char *mem = (char *) malloc(mem_size, 128);
+
+ if (!mem) {
+ return -1;
+ }
+
+ int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
+ uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K);
+ int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
+ int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems,
+ PAGE_4K);
+
+ int32_t *bufferC = NULL;
+ if (need_c_buffer) {
+ bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
+ }
+
+ float beta_saved = beta;
+
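+    // Loop order: M blocks outermost, then K blocks, then N blocks. The
+    // packed A block is produced during the first N iteration of each
+    // (M, K) block and reused for the remaining N blocks, as tracked by
+    // a_block_copied.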
+ int a_block_copied = 0;
+ dim_t sizeM = 0;
+ for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
+ sizeM = m - Bm;
+ if (sizeM > m_padd)
+ sizeM = m_padd;
+
+ dim_t sizeK = 0;
+ for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
+ sizeK = k - Bk;
+ if (sizeK > k_padd)
+ sizeK = k_padd;
+
+            // Scale C blocks by beta only for the first k-block.
+ if (Bk == 0)
+ beta = beta_saved;
+ else
+ beta = 1.0f;
+
+            // Apply the C offset to the last k-block of the partial sum.
+ int offsetc = NO_OFFSET;
+ if (Bk + sizeK == k)
+ offsetc = arg->offsetc;
+
+ dim_t sizeN = 0;
+ for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
+ sizeN = n - Bn;
+ if (sizeN > n_padd)
+ sizeN = n_padd;
+
+ const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn;
+ arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL,
+ NULL, b_col_sum);
+
+ dim_t sizeUM = 0;
+ for (dim_t Um = 0; Um < sizeM; Um += sizeUM) {
+ sizeUM = sizeM - Um;
+ if (sizeUM > arg->um)
+ sizeUM = arg->um;
+
+ /*
+ * Use the whole A buffer only if we have multiple B blocks
+ * for k-dimension, otherwise we are wasting cache to store
+ * B and C blocks.
+ */
+ dim_t Um_forA = 0;
+ if (sizeN < n)
+ Um_forA = Um;
+
+ const int8_t *a_block = a + (Bm + Um) * strideAm
+ + Bk * strideAn;
+ if (!a_block_copied) {
+ arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL,
+ bufferA + Um_forA * sizeK, NULL, NULL,
+ a_row_sum + Um_forA);
+ }
+
+ int32_t *c_block = c + (Bm + Um) + Bn * ldc;
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = Bn;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = Bm + Um;
+ }
+ if (need_c_buffer) {
+ igemm_inner_kernel(sizeUM, sizeN, sizeK,
+ bufferA + Um_forA * sizeK, bufferB, 0.0f,
+ bufferC + Um, ldc_buf, a_row_sum + Um_forA,
+ b_col_sum, NULL, NO_OFFSET, arg);
+
+ // Finish the block adding the necessary alpha, beta
+ // and offsets.
+ add_results(sizeUM, sizeN, sizeK, alpha, beta,
+ bufferC + Um, ldc_buf, c_block, ldc,
+ a_row_sum + Um_forA, b_col_sum, ao, bo,
+ co + co_stride, offsetc);
+ } else {
+ igemm_inner_kernel(sizeUM, sizeN, sizeK,
+ bufferA + Um_forA * sizeK, bufferB, beta,
+ c_block, ldc, a_row_sum + Um_forA, b_col_sum,
+ co + co_stride, offsetc, arg);
+ }
+ }
+ a_block_copied = 1;
+ }
+ a_block_copied = 0;
+ }
+ }
+
+ free(mem);
+
+ return 0;
+}
+
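+// Variant of gemm_kernel_driver for the parallel-copy path: A (and its
+// row sums) has already been packed into bufferA, so only B blocks are
+// packed here.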
+static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n,
+ const dim_t k, const int8_t *bufferA, const uint8_t *b,
+ const float beta, int32_t *c, const int offsetc, const int32_t *co,
+ const int32_t *a_row_sum, const blas_t *arg)
+{
+ dim_t ldb = arg->ldb;
+ dim_t ldc = arg->ldc;
+ int8_t ao = arg->ao;
+ int8_t bo = arg->bo;
+ float alpha = *arg->alpha;
+
+ if (m <= 0 || n <= 0) {
+ return 0;
+ }
+
+ // Padding along N dimension.
+ dim_t n_padd = 0;
+ if (k < arg->blocking_small_k) {
+ n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
+ arg->bn_small_k), arg->un);
+ } else {
+ n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
+ arg->un);
+ }
+
+ // Padding for temporary buffer for C
+ dim_t ldc_buf = ld_padd(m);
+
+ dim_t strideBn = (arg->transb != 0)? 1 : ldb;
+
+ size_t b_buf_nelems = k * n_padd;
+ size_t b_col_sum_nelems = n_padd;
+
+ size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K
+ + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
+
+ bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
+ if (need_c_buffer) {
+ size_t c_buf_nelems = ldc_buf * n_padd;
+ mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
+ }
+
+ char *mem = (char *) malloc(mem_size, 128);
+
+ if (!mem) {
+ return -1;
+ }
+
+ uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K);
+ int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
+
+ int32_t *bufferC = NULL;
+ if (need_c_buffer) {
+ bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
+ }
+
+ dim_t sizeN = 0;
+ for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
+ sizeN = n - Bn;
+ if (sizeN > n_padd)
+ sizeN = n_padd;
+
+        // Pack the next block of B and compute its column sums.
+ const uint8_t *b_block = b + Bn * strideBn;
+ arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL,
+ b_col_sum);
+
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = Bn;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = 0;
+ }
+ int32_t *c_block = c + Bn * ldc;
+ if (need_c_buffer) {
+ igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC,
+ ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg);
+
+ // Finish the block adding the necessary alpha, beta and offsets.
+ add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block,
+ ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride,
+ offsetc);
+ } else {
+ igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block,
+ ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg);
+ }
+ }
+
+ free(mem);
+
+ return 0;
+}
+
+#define N2D_MAX_AVX512 384
+#define M2D_MIN_AVX512 384
+#define VECLEN 16
+#define NCONS 1
+static inline void set_thread_opts_avx512(int *p_nthrs,
+ blas_thread_t *thread_info, const blas_t *arg)
+{
+ int nthrs = *p_nthrs;
+ dim_t m = arg->m;
+ dim_t n = arg->n;
+
+ thread_info->nthrs_m = 0;
+ thread_info->nthrs_n = 0;
+ thread_info->copy_type = COPY_NONE; // By default don't do parallel copy.
+
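+    // 2D partitioning of the B source pays off only when the aspect
+    // ratio m/n lies between nthrs/256 and 256/nthrs, i.e. the problem
+    // is neither extremely tall nor extremely wide.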
+ int condition_2D_bsrc = -1;
+ if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) {
+ condition_2D_bsrc = 1;
+ } else {
+ condition_2D_bsrc = 0;
+ }
+
+ int condition_1D_copya = 0;
+ if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) {
+ condition_2D_bsrc = 0;
+ condition_1D_copya = 1;
+ }
+
+    // If any offset is in use (or offsetc is not fixed), force the 1D
+    // copy-A algorithm to reduce the C update overhead.
+ if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0
+ || arg->offsetc != FIX_OFFSET) {
+ condition_2D_bsrc = 0;
+ condition_1D_copya = 1;
+ }
+
+ if (condition_2D_bsrc == 1) {
+ int nthrs_m = 1;
+ int nthrs_n = nthrs;
+
+ while ((nthrs_n % 2 == 0) &&
+ (n / nthrs > N2D_MAX_AVX512 ||
+ n / nthrs_n <= N2D_MAX_AVX512 / 2) &&
+ (m / nthrs_m >= 2 * M2D_MIN_AVX512) &&
+ (nthrs_m < 4)) {
+ nthrs_m *= 2;
+ nthrs_n /= 2;
+ }
+
+ thread_info->nthrs_m = nthrs_m;
+ thread_info->nthrs_n = nthrs_n;
+ thread_info->partition = PARTITION_2D;
+
+ // Reset the total number of threads that will be used.
+ *p_nthrs = nthrs_m * nthrs_n;
+
+ } else if (condition_1D_copya && mkldnn_thr_syncable()) {
+ // Use parallel copy A algorithm
+ thread_info->copy_type = COPY_A;
+ thread_info->partition = PARTITION_1D_COL;
+ } else {
+ if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) {
+ thread_info->partition = PARTITION_1D_ROW;
+ } else {
+ thread_info->partition = PARTITION_1D_COL;
+ }
+ }
+}
+#undef N2D_MAX_AVX512
+#undef M2D_MIN_AVX512
+#undef VECLEN
+#undef NCONS
+
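+// Splits the range [0, n) into nthrs nearly equal bands; thread ithr
+// receives *t_offset and *t_block. The last thread absorbs the tail,
+// and threads whose offset falls past n get an empty band.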
+static inline void partition_1d(const int ithr, const int nthrs, const dim_t n,
+ dim_t *t_offset, dim_t *t_block)
+{
+ dim_t band = n / nthrs;
+
+ dim_t tail = n - (nthrs - 1) * band;
+ if (tail > (band + 1))
+ band++;
+ tail = n - (nthrs - 1) * band;
+
+ if (ithr < (nthrs - 1))
+ *t_block = band;
+ else
+ *t_block = tail;
+
+ *t_offset = ithr * band;
+
+ if (*t_offset >= n) {
+ *t_block = 0;
+ *t_offset = 0;
+ } else if ((*t_offset + *t_block) > n) {
+ *t_block = n - *t_offset;
+ }
+}
+
+static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i,
+ const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m,
+ const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp,
+ dim_t *p_n_band)
+{
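+    // 2D analogue of partition_1d: threads form an nthrs_m x nthrs_n
+    // grid. Early groups may get bands one element larger to absorb the
+    // remainder, and *nthrs shrinks when trailing grid cells would be
+    // empty.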
+ dim_t m_disp = 0, n_disp = 0;
+ dim_t m_band = 0, n_band = 0;
+
+ int mdiv = nthrs_m;
+ int ndiv = nthrs_n;
+
+ dim_t m_bandt = m / mdiv; /* size per thread */
+ dim_t n_bandt = n / ndiv; /* size per thread */
+ int firstmgroup = mdiv - 1;
+ int firstngroup = ndiv - 1;
+ dim_t firstmval = m_bandt;
+ dim_t firstnval = n_bandt;
+
+ int mthr_used = mdiv;
+ if (m - (mdiv - 1) * m_bandt > m_bandt + 1) {
+ if (m - (mdiv - 1) * m_bandt > mdiv)
+ ++m_bandt;
+
+ firstmval = m_bandt + 1;
+ mthr_used = (int) (m / firstmval);
+
+ if (mthr_used * firstmval < m)
+ ++mthr_used;
+
+ firstmgroup = mthr_used - 1;
+ }
+
+ int nthr_used = ndiv;
+ if (n - (ndiv - 1) * n_bandt > n_bandt + 1) {
+ firstnval = n_bandt + 1;
+ nthr_used = (int) (n / firstnval);
+
+ if (nthr_used * firstnval < n)
+ ++nthr_used;
+
+ firstngroup = nthr_used - 1;
+ }
+
+ *nthrs = mthr_used * nthr_used;
+
+ if (ithr < *nthrs) {
+ if (ithr_i < firstmgroup) {
+ m_band = firstmval;
+ m_disp = ithr_i * firstmval;
+ } else if (ithr_i <= mthr_used - 2) {
+ m_band = m_bandt;
+ m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt;
+ } else {
+ m_disp = firstmgroup * firstmval
+ + (mthr_used - 1 - firstmgroup) * m_bandt;
+ m_band = nstl::max(0LL, m - m_disp);
+ }
+
+ if (ithr_j < firstngroup) {
+ n_band = firstnval;
+ n_disp = ithr_j * firstnval;
+ } else if (ithr_j <= nthr_used - 2) {
+ n_band = n_bandt;
+ n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt;
+ } else {
+ n_disp = firstngroup * firstnval
+ + (nthr_used - 1 - firstngroup) * n_bandt;
+ n_band = nstl::max(0LL, n - n_disp);
+ }
+ m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL);
+ n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL);
+ }
+
+ if (ithr < *nthrs) {
+ *p_m_disp = m_disp;
+ *p_n_disp = n_disp;
+ *p_m_band = m_band;
+ *p_n_band = n_band;
+ } else {
+ *p_m_disp = 0;
+ *p_n_disp = 0;
+ *p_m_band = 0;
+ *p_n_band = 0;
+ }
+
+ return;
+}
+
+static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m,
+ dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c,
+ const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg)
+{
+ dim_t strideAm = (arg->transa == 0)? 1 : arg->lda;
+ dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb;
+ int offsetc = arg->offsetc;
+
+ switch (thread_info->partition) {
+ case PARTITION_1D_ROW:
+ {
+ dim_t offset = 0;
+ dim_t block = 0;
+ partition_1d(ithr, *nthrs, arg->m, &offset, &block);
+
+ *m = block;
+ *n = arg->n;
+ *k = arg->k;
+
+ // Set matrix A.
+ *a = arg->a + offset * strideAm;
+
+ // Set matrix B.
+ *b = arg->b;
+
+ // Set matrix C.
+ *c = arg->c + offset;
+
+ // Set offset vector for C matrix
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = offset;
+ }
+ *co = arg->co + co_stride;
+ break;
+ }
+
+ case PARTITION_1D_COL:
+ {
+ dim_t offset = 0;
+ dim_t block = 0;
+ partition_1d(ithr, *nthrs, arg->n, &offset, &block);
+
+ *m = arg->m;
+ *n = block;
+ *k = arg->k;
+
+ // Set matrix A.
+ *a = arg->a;
+
+ // Set matrix B.
+ *b = arg->b + offset * strideBn;
+
+ // Set matrix C.
+ *c = arg->c + offset * arg->ldc;
+
+ // Set offset vector for C matrix
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = offset;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = 0;
+ }
+ *co = arg->co + co_stride;
+ break;
+ }
+
+ case PARTITION_2D_COL_MAJOR:
+ {
+ int nthrs_m = thread_info->nthrs_m;
+ int nthrs_n = thread_info->nthrs_n;
+ int ithr_i = ithr % nthrs_m;
+ int ithr_j = ithr / nthrs_m;
+
+ dim_t m_disp = 0;
+ dim_t m_band = 0;
+ dim_t n_disp = 0;
+ dim_t n_band = 0;
+
+ partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n,
+ arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band);
+
+ *m = m_band;
+ *n = n_band;
+ *k = arg->k;
+
+ // Set matrix A.
+ *a = arg->a + m_disp * strideAm;
+
+ // Set matrix B.
+ *b = arg->b + n_disp * strideBn;
+
+ // Set matrix C.
+ *c = arg->c + m_disp + n_disp * arg->ldc;
+
+ // Set offset vector for C matrix
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = n_disp;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = m_disp;
+ }
+ *co = arg->co + co_stride;
+ break;
+ }
+ }
+}
+
+#define MULTIPLIER 10
+static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m,
+ const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b,
+ int32_t *c, const int32_t *co, const blas_t *arg,
+ char **p_shared_mem)
+{
+ const dim_t lda = arg->lda;
+ const dim_t ldb = arg->ldb;
+ const dim_t strideAm = (arg->transa == 0)? 1 : lda;
+ const dim_t strideAn = (arg->transa != 0)? 1 : lda;
+ const dim_t strideBm = (arg->transb == 0)? 1 : ldb;
+
+ // Padding along M dimension.
+ dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
+ arg->um);
+
+ // Padding along K dimension.
+ dim_t k_padd = 0;
+ if (k <= arg->bk_traditional) {
+ k_padd = utils::rnd_up(k, arg->uk);
+ k_padd = nstl::max(128LL, k_padd);
+ } else if (k < 2 * arg->bk) {
+ k_padd = utils::rnd_up(k / 2, arg->uk);
+ } else {
+ k_padd = arg->bk;
+ }
+
+ m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs;
+ if (m_padd > m) {
+ m_padd = utils::rnd_up(m, arg->um);
+ }
+
+ size_t a_buf_nelems = m_padd * k_padd;
+
+ // Allocate shared memory for A and its row sum buffers in master thread.
+    if (ithr == 0) { // Master thread
+ size_t a_row_sum_nelems = m_padd;
+
+ size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K)
+ + a_row_sum_nelems * sizeof(*c) + PAGE_4K;
+
+ *p_shared_mem = (char *) malloc(mem_size, 128);
+ }
+ mkldnn_thr_barrier();
+
+    char *mem = *p_shared_mem;
+    if (!mem) {
+        return -1;
+    }
+
+    int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
+    int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K);
+
+ int result = 0; // Return status
+
+ dim_t sizeK = 0;
+ for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
+ sizeK = k - Bk;
+ if (sizeK > k_padd)
+ sizeK = k_padd;
+
+ // Scale C blocks by beta only for the first term of partial sum.
+ float beta = 1.0f;
+ if (Bk == 0)
+ beta = *(arg->beta);
+
+ // Apply C offset for the last k-block of the partial sum.
+ int offsetc = NO_OFFSET;
+ if (Bk + sizeK == k)
+ offsetc = arg->offsetc;
+
+ dim_t sizeM = 0;
+ for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
+ sizeM = m - Bm;
+ if (sizeM > m_padd)
+ sizeM = m_padd;
+
+ if (ithr < nthrs) {
+ dim_t band = (sizeM + nthrs - 1) / nthrs;
+ band = utils::rnd_up(band, arg->um);
+
+ dim_t offset = band * ithr;
+
+ // If offset is too large don't use that thread for copying.
+ if (offset >= sizeM) {
+ offset = 0;
+ band = 0;
+ }
+
+ // Handle the tail of the copy.
+ if (offset + band > sizeM) {
+ band = sizeM - offset;
+ }
+
+ if (band > 0) {
+ const int8_t *a_block = a + (Bm + offset) * strideAm
+ + Bk * strideAn;
+ arg->copyA(&sizeK, &band, a_block, &lda, NULL,
+ bufferA + offset * sizeK, NULL, NULL,
+ a_row_sum + offset);
+ }
+ }
+            mkldnn_thr_barrier(); // Wait for the parallel copy to finish.
+
+ const uint8_t *b_block = b + Bk * strideBm;
+ int32_t *c_block = c + Bm;
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = Bm;
+ }
+
+ result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK,
+ bufferA, b_block, beta, c_block, offsetc, co + co_stride,
+ a_row_sum, arg);
+
+ mkldnn_thr_barrier(); // Wait for kernel computations to finish.
+ }
+ }
+
+ // Free memory allocated in master thread
+ if (ithr == 0) {
+ free(mem);
+ }
+
+ return result;
+}
+#undef MULTIPLIER
+
+static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k,
+ double fp_per_cycle, int *nthrs)
+{
+ double omp_overhead_small_core = 3.0e+3;
+ double omp_intercept_big_core = 4.0e+3;
+ double omp_slope_big_core = 5.0e+2;
+
+ double gemm_cycles = 8.0 * m * n * k / fp_per_cycle;
+
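+    // Keep i threads only while the modeled OMP overhead, omp_cycles * i,
+    // stays below the modeled work saved, gemm_cycles * (i - 1).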
+ int i = *nthrs;
+
+ // Use a different model for omp overheads if nthrs is <= 4
+ if (*nthrs <= 4 && omp_overhead_small_core > 0) {
+ double omp_cycles = omp_overhead_small_core;
+ if (gemm_cycles < omp_cycles) {
+ *nthrs = 1;
+ return;
+ } else {
+ while (i > 1) {
+ if (omp_cycles * i < gemm_cycles * (i - 1)) break;
+ --i;
+ }
+ }
+ } else {
+ if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) {
+ *nthrs = 1;
+ return;
+ }
+
+        // Adaptive decrement to march faster.
+ while (i > 1) {
+ double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core;
+ if (omp_cycles * i < gemm_cycles * (i - 1))
+ break;
+
+ if (i < 10)
+ i -= 2;
+ else if (i < 30)
+ i -= 4;
+ else
+ i -= 8;
+ }
+ }
+
+ if (i < 1)
+ i = 1;
+
+ *nthrs = i;
+}
+
+#define CACHE_LINE_SIZE 64
+static int gemm_threading_driver(blas_t *arg)
+{
+ if ((arg->m <= 0) || (arg->n <= 0))
+ return mkldnn_success;
+
+ if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) {
+ return mkldnn_success;
+ }
+
+ int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+ get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr);
+
+ if (nthr == 1) {
+ return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b,
+ arg->c, arg->co, arg);
+ }
+
+ int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE,
+ PAGE_4K);
+
+ if (!results) {
+ return -1;
+ }
+
+ for (int i = 0; i < nthr; i++) {
+ results[i * CACHE_LINE_SIZE] = 0; // Initialize to success
+ }
+
+ char *shared_mem = NULL;
+
+ parallel(nthr, [&](const int ithr, const int nthr) {
+ int nthrs = nthr;
+ if (nthrs == 1) {
+ results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a,
+ arg->b, arg->c, arg->co, arg);
+ } else {
+ blas_thread_t thread_info;
+ set_thread_opts_avx512(&nthrs, &thread_info, arg);
+
+ const int8_t *a = NULL;
+ const uint8_t *b = NULL;
+ int32_t *c = NULL;
+ const int32_t *co = NULL;
+ dim_t m = -1;
+ dim_t n = -1;
+ dim_t k = -1;
+ decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co,
+ &thread_info, arg);
+
+ if (ithr < nthrs) {
+ switch (thread_info.copy_type) {
+ case COPY_A:
+ results[ithr * CACHE_LINE_SIZE] =
+ parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg,
+ &shared_mem);
+ break;
+
+ default:
+ case COPY_NONE:
+ results[ithr * CACHE_LINE_SIZE] =
+ gemm_kernel_driver(m, n, k, a, b, c, co, arg);
+ break;
+ }
+ }
+ }
+ });
+
+ int result = 0; // Initialize to success
+ for (int i = 0; i < nthr; i++) {
+        if (results[i * CACHE_LINE_SIZE] != 0) {
+ result = results[i * CACHE_LINE_SIZE];
+ break;
+ }
+ }
+
+ free(results);
+
+ return result;
+}
+#undef CACHE_LINE_SIZE
+
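+// JIT kernel instances, created once on first use. Suffixes encode the
+// variant: _b0 for beta == 0; _r, _c and _b for kernels that also apply
+// row, column, or both C offsets.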
+static jit_avx512_core_u8_copy_an_kern *copy_an;
+static jit_avx512_core_u8_copy_at_kern *copy_at;
+static jit_avx512_core_u8_copy_bn_kern *copy_bn;
+static jit_avx512_core_u8_copy_bt_kern *copy_bt;
+static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an;
+static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at;
+static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn;
+static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_r;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_c;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c;
+static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel;
+static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel;
+
+static void jit_init(blas_t *arg)
+{
+ static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float,
+ const int8_t*, const dim_t, const uint8_t*,
+ const float, int32_t*);
+
+ static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float,
+ const uint8_t*, const dim_t, const int8_t*,
+ const float, int32_t*);
+
+ if (mayiuse(avx512_core_vnni)) {
+ arg->um = AVX512_UNROLL_M;
+ arg->un = AVX512_UNROLL_N;
+ arg->uk = AVX512_UNROLL_K;
+ arg->bm = AVX512_BM;
+ arg->bn = AVX512_BN;
+ arg->bk = AVX512_BK_VNNI;
+
+ arg->bk_traditional = AVX512_BK_TRADITIONAL;
+ arg->bn_small_k = AVX512_BN_SMALL_K;
+ arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
+ } else {
+ arg->um = AVX512_UNROLL_M;
+ arg->un = AVX512_UNROLL_N;
+ arg->uk = AVX512_UNROLL_K;
+ arg->bm = AVX512_BM;
+ arg->bn = AVX512_BN;
+ arg->bk = AVX512_BK;
+
+ arg->bk_traditional = AVX512_BK_TRADITIONAL;
+ arg->bn_small_k = AVX512_BN_SMALL_K;
+ arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
+ }
+
+ static std::once_flag initialized;
+ std::call_once(initialized, []{
+
+ copy_an = new jit_avx512_core_u8_copy_an_kern();
+ copy_at = new jit_avx512_core_u8_copy_at_kern();
+ copy_bn = new jit_avx512_core_u8_copy_bn_kern();
+ copy_bt = new jit_avx512_core_u8_copy_bt_kern();
+
+ copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern();
+ copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern();
+ copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern();
+ copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern();
+
+ kernel = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false);
+ kernel_b = new jit_avx512_core_gemm_s8u8s32_kern(false, true, true);
+ kernel_r = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true);
+ kernel_c = new jit_avx512_core_gemm_s8u8s32_kern(false, true, false);
+ kernel_b0 = new jit_avx512_core_gemm_s8u8s32_kern(true, false, false);
+ kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true, true, true);
+ kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true, false, true);
+ kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true, true, false);
+
+ gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
+ gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
+
+ copyAn = copy_an->getCode<int (*)(const dim_t *, const dim_t *,
+ const int8_t *, const dim_t *, const int8_t *, int8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copyAt = copy_at->getCode<int (*)(const dim_t *, const dim_t *,
+ const int8_t *, const dim_t *, const int8_t *, int8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copyBn = copy_bn->getCode<int (*)(const dim_t *, const dim_t *,
+ const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copyBt = copy_bt->getCode<int (*)(const dim_t *, const dim_t *,
+ const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copySumAn = copy_sum_an->getCode<int (*)(const dim_t *, const dim_t *,
+ const int8_t *, const dim_t *, const int8_t *, int8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copySumAt = copy_sum_at->getCode<int (*)(const dim_t *, const dim_t *,
+ const int8_t *, const dim_t *, const int8_t *, int8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copySumBn = copy_sum_bn->getCode<int (*)(const dim_t *, const dim_t *,
+ const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copySumBt = copy_sum_bt->getCode<int (*)(const dim_t *, const dim_t *,
+ const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ kern = kernel->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b = kernel_b->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_r = kernel_r->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_c = kernel_c->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b0 = kernel_b0->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b0_b = kernel_b0_b->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b0_r = kernel_b0_r->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b0_c = kernel_b0_c->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+        gemv_s8u8s32_kern = gemv_s8u8s32_kernel->generate<
+                jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(
+                mayiuse(avx512_core_vnni));
+        gemv_u8s8s32_kern = gemv_u8s8s32_kernel->generate<
+                jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>(
+                mayiuse(avx512_core_vnni));
+ });
+
+ if (arg->bo == 0) { // No need to compute A row sum if bo is zero
+ if (arg->transa == 0) {
+ arg->copyA = copyAn;
+ } else {
+ arg->copyA = copyAt;
+ }
+ } else {
+ if (arg->transa == 0) {
+ arg->copyA = copySumAn;
+ } else {
+ arg->copyA = copySumAt;
+ }
+ }
+
+ if (arg->ao == 0) { // No need to compute B column sum if ao is zero
+ if (arg->transb == 0) {
+ arg->copyB = copyBn;
+ } else {
+ arg->copyB = copyBt;
+ }
+ } else {
+ if (arg->transb == 0) {
+ arg->copyB = copySumBn;
+ } else {
+ arg->copyB = copySumBt;
+ }
+ }
+
+ arg->kernel = kern;
+ arg->kernel_b = kern_b;
+ arg->kernel_r = kern_r;
+ arg->kernel_c = kern_c;
+ arg->kernel_b0 = kern_b0;
+ arg->kernel_b0_b = kern_b0_b;
+ arg->kernel_b0_r = kern_b0_r;
+ arg->kernel_b0_c = kern_b0_c;
+    arg->gemv_s8u8s32_kernel = gemv_s8u8s32_kern;
+    arg->gemv_u8s8s32_kernel = gemv_u8s8s32_kern;
+}
+
+mkldnn_status_t jit_avx512_core_gemm_s8u8s32(
+ const char *transA, const char *transB, const char *offsetC,
+ const int *m, const int *n, const int *k,
+ const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
+ const uint8_t *b, const int *ldb, const int8_t *ob,
+ const float *beta, int32_t *c, const int *ldc, const int32_t *oc)
+{
+ char transa = *transA;
+ char transb = *transB;
+ char offsetc = *offsetC;
+
+ blas_t args;
+
+ // Initialize blas structure
+ args.m = *m;
+ args.n = *n;
+ args.k = *k;
+ args.alpha = alpha;
+ args.a = a;
+ args.lda = *lda;
+ args.b = b;
+ args.ldb = *ldb;
+ args.beta = beta;
+ args.c = c;
+ args.ldc = *ldc;
+ args.transa = (transa == 'N' || transa == 'n') ? 0 : 1;
+ args.transb = (transb == 'N' || transb == 'n') ? 0 : 1;
+ args.um = 0;
+ args.un = 0;
+ args.bm = 0;
+ args.bn = 0;
+ args.bk = 0;
+ args.copyA = NULL;
+ args.copyB = NULL;
+ args.kernel = NULL;
+ args.kernel_b0 = NULL;
+ args.ao = *oa;
+ args.bo = *ob;
+ args.co = oc;
+
+ if (offsetc == 'F' || offsetc == 'f') {
+ args.offsetc = FIX_OFFSET;
+ } else if (offsetc == 'R' || offsetc == 'r') {
+ args.offsetc = ROW_OFFSET;
+ } else { // offsetc == 'C' || offsetc == 'c'
+ args.offsetc = COL_OFFSET;
+ }
+
+ jit_init(&args);
+ int result = gemm_threading_driver(&args);
+
+ return (result < 0) ? mkldnn_out_of_memory : mkldnn_success;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp
new file mode 100644
index 0000000000..b2e2902a12
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp
@@ -0,0 +1,38 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX512_CORE_GEMM_S8U8S32_HPP
+#define JIT_AVX512_CORE_GEMM_S8U8S32_HPP
+
+#include <cstdint>
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t jit_avx512_core_gemm_s8u8s32(
+ const char *transA, const char *transB, const char *offsetC,
+ const int *m, const int *n, const int *k,
+ const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
+ const uint8_t *b, const int *ldb, const int8_t *ob,
+ const float *beta, int32_t *c, const int *ldc, const int32_t *oc);
+
+}
+}
+}
+
+#endif // JIT_AVX512_CORE_GEMM_S8U8S32_HPP
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp
new file mode 100644
index 0000000000..57554a1852
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp
@@ -0,0 +1,539 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_avx512_core_gemm_s8u8s32_kern.hpp"
+
+
+#ifdef _WIN32
+static const bool is_windows = 1;
+#else
+static const bool is_windows = 0;
+#endif
+
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace Xbyak;
+
+// Convert between vector register lengths.
+static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); }
+static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); }
+
+// Load from or store to C.
+void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst,
+ const Xbyak::Address &src, int nelems)
+{
+ switch (nelems) {
+ default: vmovups(dst, src); break;
+ case 8: vmovups(make_ymm(dst), src); break;
+ case 4: vmovups(make_xmm(dst), src); break;
+ case 2: vmovlps(make_xmm(dst), src); break;
+ case 1: vmovss(make_xmm(dst), src); break;
+ }
+}
+void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst,
+ const Xbyak::Xmm &src, int nelems)
+{
+ switch (nelems) {
+ default: vmovups(dst, src); break;
+ case 8: vmovups(dst, make_ymm(src)); break;
+ case 4: vmovups(dst, make_xmm(src)); break;
+ case 2: vmovsd(dst, make_xmm(src)); break;
+ case 1: vmovss(dst, make_xmm(src)); break;
+ }
+}
+
+// Perform length-4 dot product accumulations of unsigned and signed bytes
+// in parallel.
+// Use vpdpbusd if VNNI available, otherwise emulate.
+void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst,
+ const Xmm &src1, const Xmm &src2)
+{
+ if (vnni)
+ vpdpbusd(dst, src1, src2);
+ else {
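+        // Emulation: vpmaddubsw forms saturating 16-bit sums of adjacent
+        // u8*s8 products, then vpmaddwd against a vector of ones widens
+        // and pairs them into 32-bit lanes before accumulation.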
+ vpmaddubsw(dp_scratch, src1, src2);
+ vpmaddwd(dp_scratch, ones, dp_scratch);
+ vpaddd(dst, dst, dp_scratch);
+ }
+}
+
+// Inner kernel.
+void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n,
+ bool cfetch)
+{
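+    // Each zmm accumulator covers 16 rows of C, so um_vecs vectors span
+    // unroll_m rows; one loop iteration consumes 16 k values (the h loop
+    // performs four dot-product groups of four bytes each).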
+ int um_vecs = (unroll_m + 15) >> 4;
+ Label label_kernel_loop;
+
+ L_aligned(label_kernel_loop); {
+ for (int h = 0; h < 4; h++) {
+ for (int j = 0; j < unroll_n; j++) {
+ const Zmm b = b_regs[j & 1];
+
+ vpbroadcastd(b, ptr[BO + isize *
+ (2 * j + 2 * h * unroll_n - offset_b)]);
+ dot_product(c_regs[0][j], b, a_regs[0]);
+
+ if (j == 1 && !(h & 1))
+ prefetch_b(ptr[BO + isize * (prefetch_size_b
+ + 2 * h * unroll_n - offset_b)]);
+ else if (j % 3 == 0)
+ prefetch_a(ptr[AO + isize * (prefetch_size_a
+ + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]);
+
+ for (int i = 1; i < um_vecs; i++)
+ dot_product(c_regs[i][j], b, a_regs[i]);
+
+ if (cfetch && (j == std::min(1, unroll_n - 1))) {
+ if (h == 3)
+ lea(CO2, ptr[CO2 + LDC]);
+ else if (h < um_vecs)
+ prefetch_c(ptr[CO2 + (16 * h * size)]);
+ }
+
+ if (h == 3 && j == std::min(3, unroll_n - 1))
+ lea(AA, ptr[AA + (32 * isize)]);
+ }
+
+ for (int i = 0; i < um_vecs; i++)
+ vmovups(a_regs[i], ptr[AO + isize *
+ (32 * i + 2 * (h + 1) * unroll_m - offset_a)]);
+
+ if (h == 2)
+ prefetch_x(ptr[AA - (offset_a * isize)]);
+ }
+
+ add(AO, 8 * isize * unroll_m);
+ add(BO, 8 * isize * unroll_n);
+ sub(LoopCount, 1);
+ jg(label_kernel_loop, T_NEAR);
+ }
+}
+
+// k remainder loop for kernel.
+void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m,
+ int unroll_n, int unroll_k, int bwidth)
+{
+ if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
+ || (unroll_m < 0) || (unroll_n < 0))
+ return;
+
+ int um_vecs = (unroll_m + 15) >> 4;
+
+ for (int h = 0; h < unroll_k; h++) {
+ for (int j = 0; j < unroll_n; j++) {
+ Zmm b = b_regs[j & 1];
+ auto b_src = ptr[BO + (-isize * offset_b
+ + bwidth * (j + h * unroll_n))];
+
+ switch (bwidth) {
+ case 4:
+ vpbroadcastd(b, b_src);
+ break;
+ case 2:
+ vpbroadcastw(b, b_src);
+ break;
+ case 1:
+ vpbroadcastb(b, b_src);
+ break;
+ }
+ for (int i = 0; i < um_vecs; i++)
+ dot_product(c_regs[i][j], b, a_regs[i]);
+ }
+
+ if (unroll_k > 1) {
+ for (int i = 0; i < um_vecs; i++)
+ vmovups(a_regs[i], ptr[AO + isize * (32 * i
+ + (h + 1) * 2 * unroll_m - offset_a)]);
+ }
+ }
+
+ add(AO, unroll_k * unroll_m * bwidth);
+ add(BO, unroll_k * unroll_n * bwidth);
+}
+
+// Inner loop.
+void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n)
+{
+ if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
+ || (unroll_m < 0) || (unroll_n < 0))
+ return;
+
+ int um_vecs = (unroll_m + 15) >> 4;
+ int stage1 = unroll_n, stage2 = unroll_n;
+
+ Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2;
+ Label label_k_main_loop_3, label_kernel_loop_3;
+ Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2;
+ Label label_k_rem_1, label_update_begin;
+
+ mov(AO, A);
+ for (int i = 0; i < um_vecs; i++)
+ vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]);
+
+ mov(LoopCount, K);
+ sar(LoopCount, 4);
+ jle(label_k_remainder_loop_begin, T_NEAR);
+
+ // Main k loops, broken into three parts to time C prefetching.
+ sub(LoopCount, stage1 + stage2);
+ jle(label_k_main_loop_2, T_NEAR);
+
+ kernel_loop(unroll_m, unroll_n, false);
+
+ L_aligned(label_k_main_loop_2);
+ lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
+ add(LoopCount, stage1);
+ jle(label_k_main_loop_3, T_NEAR);
+
+ kernel_loop(unroll_m, unroll_n, true);
+
+ L_aligned(label_k_main_loop_3);
+ lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
+ add(LoopCount, stage2);
+ jle(label_k_remainder_loop_begin, T_NEAR);
+
+ kernel_loop(unroll_m, unroll_n, true);
+
+ // k remainder handling
+ L_aligned(label_k_remainder_loop_begin);
+ mov(LoopCount, K);
+ test(LoopCount, 8);
+ je(label_k_rem_4, T_NEAR);
+
+ remainder_kernel(unroll_m, unroll_n, 2, 4);
+
+ L_aligned(label_k_rem_4);
+ mov(LoopCount, K);
+ test(LoopCount, 4);
+ je(label_k_rem_2, T_NEAR);
+
+ remainder_kernel(unroll_m, unroll_n, 1, 4);
+
+ L_aligned(label_k_rem_2);
+ mov(LoopCount, K);
+ test(LoopCount, 2);
+ je(label_k_rem_1, T_NEAR);
+
+ Zmm zero = zmm6;
+ Zmm tmp = zmm5;
+
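+    // k-remainder of 2: the packed A block stores 16-bit element pairs,
+    // so interleave with zero and shuffle so that each 32-bit lane has
+    // the layout the dot-product kernel expects.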
+ vpxorq(zero, zero, zero);
+ for (int i = 0; i < um_vecs; i++) {
+ Zmm a = a_regs[i];
+ vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]);
+ vpunpcklwd(tmp, a, zero);
+ vpunpckhwd(a, a, zero);
+ vshufi32x4(a, tmp, a, 0x44);
+ vshufi32x4(a, a, a, 0xD8);
+ }
+
+ remainder_kernel(unroll_m, unroll_n, 1, 2);
+
+ L_aligned(label_k_rem_1);
+ mov(LoopCount, K);
+ test(LoopCount, 1);
+ je(label_update_begin, T_NEAR);
+
+ vpxorq(zero, zero, zero);
+ for (int i = 0; i < um_vecs; i++) {
+ Zmm a = a_regs[i];
+ vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]);
+ vpunpcklbw(tmp, a, zero);
+ vpunpckhbw(a, a, zero);
+ vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1);
+ vpunpcklwd(tmp, a, zero);
+ vpunpckhwd(a, a, zero);
+ vshufi32x4(a, tmp, a, 0x44);
+ vshufi32x4(a, a, a, 0xD8);
+ }
+
+ remainder_kernel(unroll_m, unroll_n, 1, 1);
+
+ // Add offsets and update C.
+ L_aligned(label_update_begin);
+
+ if (enable_offset_r) {
+ // Add row offsets.
+ mov(rax, coffset_ry);
+ for (int j = 0; j < unroll_n; j++) {
+ Zmm row_offset = zmm0;
+
+ vbroadcastss(row_offset, ptr[rax + size * j]);
+
+ for (int i = 0; i < um_vecs; i++)
+ vpaddd(c_regs[i][j], c_regs[i][j], row_offset);
+ }
+ add(coffset_ry, size * unroll_n);
+ }
+
+ if (enable_offset_c) {
+ // Add column offsets.
+ mov(rax, coffset_cy);
+ for (int i = 0; i < um_vecs; i++) {
+ Zmm col_offset = zmm0;
+
+ c_load(col_offset, ptr[rax + size * 16 * i], unroll_m);
+
+ for (int j = 0; j < unroll_n; j++)
+ vpaddd(c_regs[i][j], c_regs[i][j], col_offset);
+ }
+ }
+
+ Reg64 LDC3 = rax;
+ lea(LDC3, ptr[LDC + LDC * 2]);
+
+ // C updates.
+ int c_off_j = 0;
+ for (int j = 0; j < unroll_n; j++) {
+ if (j > 0 && (j & 3) == 0) {
+ lea(CO1, ptr[CO1 + LDC * 4]);
+ c_off_j += 4;
+ }
+
+ int jj = j - c_off_j;
+
+ for (int i = 0; i < um_vecs; i++) {
+ Zmm c = c_regs[i][j];
+ Zmm c_old = zmm0;
+ decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj;
+
+ auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i];
+
+ if (beta_zero)
+ c_store(c_mem, c, unroll_m);
+ else {
+ c_load(c_old, c_mem, unroll_m);
+ vpaddd(c_old, c, c_old);
+ c_store(c_mem, c_old, unroll_m);
+ }
+
+ vpxorq(c, c, c);
+ }
+ }
+
+ lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]);
+}
+
+// Outer loop.
+void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y,
+ Label *&cur_outerloop_label)
+{
+ Label label_m_loop, label_n_loop, label_n_remainder_loops[6];
+
+ L(*cur_outerloop_label);
+ cur_outerloop_label++;
+ if (unroll_x >= IGEMM_UNROLL_M) {
+ mov(J, M);
+ cmp(J, unroll_x);
+ jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label.
+ } else {
+ test(J, unroll_x);
+ jle(*cur_outerloop_label, T_NEAR);
+ }
+
+ L_aligned(label_m_loop); {
+ mov(CO1, C);
+ add(C, unroll_x * size);
+
+ mov(BO, B);
+
+ mov(AA, K);
+ imul(AA, AA, unroll_x * isize);
+ lea(AA, ptr[A + AA + isize * prefetch_size_a]);
+
+ if (enable_offset_c) {
+ mov(rax, coffset_cx);
+ mov(coffset_cy, rax);
+ add(rax, unroll_x * size);
+ mov(coffset_cx, rax);
+ }
+
+ if (enable_offset_r) {
+ mov(rax, coffset_rx);
+ mov(coffset_ry, rax);
+ }
+
+ mov(I, N);
+ cmp(I, unroll_y);
+ jl(label_n_remainder_loops[0], T_NEAR);
+
+ L_aligned(label_n_loop); {
+ innerloop(unroll_x, unroll_y);
+ sub(I, unroll_y);
+ cmp(I, unroll_y);
+ jge(label_n_loop, T_NEAR);
+ }
+
+ align(16);
+
+ int label_idx = 0;
+ for (int uy = 16; uy > 0; uy >>= 1) {
+ L(label_n_remainder_loops[label_idx++]);
+ if (unroll_y > uy) {
+ test(I, uy);
+ jle(label_n_remainder_loops[label_idx], T_NEAR);
+
+ innerloop(unroll_x, uy);
+ align(16);
+ }
+ }
+ L(label_n_remainder_loops[label_idx]);
+
+ mov(A, AO);
+ if (unroll_x >= IGEMM_UNROLL_M) {
+ sub(J, unroll_x);
+ cmp(J, unroll_x);
+ jge(label_m_loop);
+ }
+ }
+
+ align(16);
+}
+
+void jit_avx512_core_gemm_s8u8s32_kern::generate()
+{
+ // Prologue
+ preamble();
+ sub(rsp, stack_alloc_size);
+
+ if (is_windows) {
+ mov(A, arg_a);
+ mov(B, arg_b);
+ }
+
+ mov(C, arg_c);
+ mov(LDC, arg_ldc);
+
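+    // Pre-bias A and B by the packing offsets so the inner loops can
+    // address around the current position with small signed displacements.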
+ sub(A, -offset_a * isize);
+ sub(B, -offset_b * isize);
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(K, qword[K]);
+
+ lea(LDC, ptr[LDC * size]);
+
+ if (enable_offset_c) {
+ mov(rax, arg_coffset_c);
+ mov(coffset_cx, rax);
+ }
+ if (enable_offset_r) {
+ mov(rax, arg_coffset_r);
+ mov(coffset_rx, rax);
+ }
+
+ for (int i = 0; i < (max_unroll_m >> 4); i++) {
+ for (int j = 0; j < max_unroll_n; j++) {
+ auto &c = c_regs[i][j];
+ vpxorq(c, c, c);
+ }
+ }
+
+ if (!vnni) {
+ mov(rax, 1);
+ movq(make_xmm(ones), rax);
+ vpbroadcastw(ones, make_xmm(ones));
+ }
+
+ Label outerloop_labels[8];
+ Label *cur_outerloop_label = &outerloop_labels[0];
+
+ // Main m loop.
+ outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label);
+
+ // m remainder loops.
+ for (int um = 32; um > 0; um >>= 1)
+ if (IGEMM_UNROLL_M > um)
+ outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label);
+
+ L(*cur_outerloop_label);
+
+ // Epilogue.
+ add(rsp, stack_alloc_size);
+ postamble();
+}
+
+
+jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(
+        bool beta_zero_, bool enable_offset_c_, bool enable_offset_r_) :
+ jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0), arg_ldc(0),
+ arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0), coffset_cy(0),
+ coffset_rx(0), coffset_ry(0)
+{
+ beta_zero = beta_zero_;
+ enable_offset_c = enable_offset_c_;
+ enable_offset_r = enable_offset_r_;
+ vnni = mayiuse(avx512_core_vnni);
+
+ // Assign integer registers
+ M = is_windows ? rcx : rdi;
+ N = is_windows ? rdx : rsi;
+ K = is_windows ? r8 : rdx;
+ A = is_windows ? rsi : r8;
+ B = r9;
+ C = r10;
+ LDC = r11;
+ I = r12;
+ J = r13;
+ LoopCount = rax;
+ AO = r14;
+ BO = r15;
+ CO1 = rbx;
+ CO2 = rbp;
+ AA = is_windows ? rdi : rcx;
+
+ // Assign vector registers
+ dp_scratch = zmm6;
+ ones = zmm7;
+ for (int i = 0; i < (max_unroll_m >> 4); i++)
+ a_regs[i] = Zmm(i);
+ b_regs[0] = zmm4;
+ b_regs[1] = zmm5;
+
+ int rn = 0;
+ for (int i = 0; i < (max_unroll_m >> 4); i++)
+ for (int j = 0; j < max_unroll_n; j++)
+ c_regs[i][j] = Zmm(8 + rn++);
+
+ // Assign stack variables.
+ stack_alloc_size = 32;
+ auto args_offset = stack_alloc_size + get_size_of_abi_save_regs()
+ + 8 + (is_windows ? 48 : 0);
+
+ arg_a = ptr[rsp + (args_offset - 16)];
+ arg_b = ptr[rsp + (args_offset - 8)];
+ arg_c = ptr[rsp + (args_offset + 0)];
+ arg_ldc = ptr[rsp + (args_offset + 8)];
+ arg_coffset_c = ptr[rsp + (args_offset + 16)];
+ arg_coffset_r = ptr[rsp + (args_offset + 24)];
+
+ coffset_cx = qword[rsp + 0];
+ coffset_cy = qword[rsp + 8];
+ coffset_rx = qword[rsp + 16];
+ coffset_ry = qword[rsp + 24];
+
+ generate();
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp
new file mode 100644
index 0000000000..e8efcc1cc8
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp
@@ -0,0 +1,101 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef IGEMM_KERNEL_GENERATOR_HPP
+#define IGEMM_KERNEL_GENERATOR_HPP
+
+#include "jit_generator.hpp"
+
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+class jit_avx512_core_gemm_s8u8s32_kern : public jit_generator {
+public:
+ jit_avx512_core_gemm_s8u8s32_kern(bool beta_zero_, bool enable_offset_c_,
+ bool enable_offset_r_);
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_s8u8s32_kern);
+
+protected:
+ bool beta_zero;
+ bool enable_offset_c, enable_offset_r;
+ bool vnni;
+
+ void prefetch_a(const Xbyak::Address &src) {
+ prefetcht0(src);
+ }
+ void prefetch_b(const Xbyak::Address &src) {
+ prefetcht0(src);
+ }
+ void prefetch_c(const Xbyak::Address &src) {
+ prefetchw(src);
+ }
+ void prefetch_x(const Xbyak::Address &src) {
+ prefetcht0(src);
+ }
+
+ void c_load(const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems);
+ void c_store(const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems);
+
+ void dot_product(const Xbyak::Xmm &dst, const Xbyak::Xmm &src1,
+ const Xbyak::Xmm &src2);
+ void kernel_loop(int unroll_m, int unroll_n, bool cfetch);
+ void remainder_kernel(int unroll_m, int unroll_n, int unroll_k, int bwidth);
+ void innerloop(int unroll_m, int unroll_n);
+ void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label);
+
+ void generate();
+
+private:
+ static const int IGEMM_UNROLL_M = 48;
+ static const int IGEMM_UNROLL_N = 8;
+
+ static const int isize = 2;
+ static const int size = 4;
+
+ // Prefetch configuration
+ static const int prefetch_size_a = 32 * 5;
+ static const int prefetch_size_b = 32 * 4;
+
+ static const int offset_a = 256, offset_b = 256;
+ static const int max_unroll_m = 48, max_unroll_n = 8;
+
+ // Integer register assignments
+ Xbyak::Reg64 M, N, K, A, B, C, LDC, I, J, LoopCount;
+ Xbyak::Reg64 AO, BO, CO1, CO2, AA;
+
+ // Vector register assignments
+ Xbyak::Zmm dp_scratch, ones, a_regs[max_unroll_m >> 4], b_regs[2];
+ Xbyak::Zmm c_regs[max_unroll_m >> 4][max_unroll_n];
+
+ // Stack variable assignments
+ int stack_alloc_size;
+ Xbyak::Address arg_a, arg_b, arg_c, arg_ldc, arg_coffset_c, arg_coffset_r;
+ Xbyak::Address coffset_cx, coffset_cy, coffset_rx, coffset_ry;
+
+ void L_aligned(Xbyak::Label &label, int alignment = 16) {
+ align(alignment);
+ L(label);
+ }
+};
+
+}
+}
+}
+
+#endif /* header guard */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp
new file mode 100644
index 0000000000..4f0b10dadd
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp
@@ -0,0 +1,290 @@
+/*******************************************************************************
+ * Copyright 2019 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include "gemv.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg) {
+
+ blas_t arg_gemv = *arg;
+
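+    // Degenerate shapes (n == 1 or m == 1) with zero offsets, alpha == 1
+    // and beta in {0, 1} are rerouted to the GEMV driver.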
+ if ((arg -> offsetc == FIX_OFFSET) && // Fix offset
+ (arg -> ao == 0) &&
+ (arg -> bo == 0) &&
+ (arg -> co[0] == 0) &&
+ (*(arg -> alpha) == 1.0f) &&
+ ((*(arg -> beta) == 1.0f) || *(arg -> beta) == 0.0f)) {
+
+ if (arg -> n == 1) {
+
+ if (arg -> transa == 1) { // A transpose
+ arg_gemv.n = arg -> k;
+ arg_gemv.ldc = 1;
+ arg_gemv.swap = 0;
+ if (arg -> transb == 0) { // B non transpose
+ arg_gemv.ldb = 1;
+ }
+                // If B is transposed, arg_gemv.ldb stays arg -> ldb.
+ gemv_threading_driver(&arg_gemv);
+ return 1;
+ }
+ }
+
+ if (arg -> m == 1) {
+
+ if (arg -> transb == 0) { // B non transpose
+ arg_gemv.transa = 1;
+ arg_gemv.m = arg -> n;
+ arg_gemv.n = arg -> k;
+ arg_gemv.a = (int8_t *) arg -> b;
+ arg_gemv.lda = arg -> ldb;
+ arg_gemv.b = (uint8_t *) arg -> a;
+ arg_gemv.swap = 1;
+ if (arg -> transa == 0) { // A non transpose
+ arg_gemv.ldb = arg -> lda;
+ }
+ else { // A transpose
+ arg_gemv.ldb = 1;
+ }
+ gemv_threading_driver(&arg_gemv);
+ return 1;
+ }
+ }
+ }
+
+ return 0;
+}
+
+
+int gemv_kernel_driver(blas_t *arg) {
+
+ dim_t m = arg -> m;
+ dim_t n = arg -> n;
+ uint8_t *a = (uint8_t *) arg -> a;
+ dim_t lda = arg -> lda;
+ int8_t *b = (int8_t *) arg -> b;
+ float beta = *(arg -> beta);
+
+ if (arg -> swap) {
+ arg -> gemv_u8s8s32_kernel(m, n, 1.0f, a, lda, b, beta, arg -> c);
+ }
+ else {
+ arg -> gemv_s8u8s32_kernel(arg -> m, arg -> n, 1.0f, arg -> a,
+ arg -> lda, arg -> b, *(arg -> beta), arg -> c);
+ }
+
+ return 0;
+}
+
+int gemv_threading_driver(blas_t *arg) {
+
+ dim_t nthr_m, nthr_n = 1;
+ dim_t MB, NB, UM = 16, UN = 64;
+ dim_t BLOCKM = 192, BLOCKN = 3072;
+ int status;
+ dim_t i;
+
+ dim_t nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+
+ uint8_t *new_x = NULL;
+ int32_t *tmp_y = NULL, *new_y = NULL;
+
+ dim_t m = arg -> m, n = arg -> n;
+
+ blas_t arg_seq = *arg;
+ float zero = 0.0f;
+
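+    // Block the m dimension: MB is m / nthr_m rounded up to a multiple
+    // of UM, and nthr_m is then recomputed as ceil(m / MB); n is blocked
+    // the same way with NB and UN below.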
+ nthr_m = std::min(std::max(m / BLOCKM, (dim_t) 1), nthr);
+ MB = m / nthr_m;
+ MB = (((MB / UM) * UM) == MB) ? MB : (MB / UM) * UM + UM;
+ nthr_m = (((m / MB) * MB) == m) ? m / MB : m / MB + 1;
+ nthr_m = std::min(std::max(nthr_m, (dim_t) 1), nthr);
+
+ while ((nthr_m * (nthr_n + 1) <= nthr) && ((n / (nthr_n + 1)) >= BLOCKN)) {
+ nthr_n++;
+ }
+
+ NB = n / nthr_n;
+ NB = (((NB / UN) * UN) == NB) ? NB : (NB / UN) * UN + UN;
+ nthr_n = (((n / NB) * NB) == n) ? n / NB : n / NB + 1;
+ nthr_n = std::min(std::max(nthr_n, (dim_t) 1), nthr / nthr_m);
+
+ nthr = nthr_m * nthr_n;
+
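+    // The GEMV kernels require unit-stride vectors: a strided x is
+    // gathered into new_x, a strided y is staged through new_y and
+    // written back, and tmp_y holds per-thread partial sums when the
+    // n dimension is split.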
+ if (arg -> ldb != 1) {
+ new_x = (uint8_t *)malloc(n, 64);
+ if (new_x == NULL)
+ return 1;
+ for (i = 0; i < n; i++) {
+ new_x[i] = (arg -> b)[i * arg -> ldb];
+ }
+ arg_seq.b = new_x;
+ arg_seq.ldb = 1;
+ }
+ else new_x = (uint8_t *) arg -> b;
+
+ if (arg -> ldc != 1) {
+ new_y = (int32_t *) malloc(nthr_m * PADD_BYTESIZE_ONPAGE(MB, sizeof(int32_t)), 64);
+ if (new_y == NULL) {
+ if (arg -> ldb != 1) {
+ free(new_x);
+ }
+ return 1;
+ }
+ }
+
+ // GEMV computation
+ if (nthr == 1) {
+
+ if (arg -> ldc != 1) {
+ if (*(arg -> beta) != 0.0f) {
+ for (i = 0; i < m; i++) {
+ new_y[i] = arg -> c[i * arg -> ldc];
+ }
+ }
+ }
+
+ status = gemv_kernel_driver(&arg_seq);
+
+ if (arg -> ldc != 1) {
+ for (i = 0; i < m; i++) {
+ arg -> c[i * arg -> ldc] = new_y[i];
+ }
+ }
+
+ if (arg -> ldb != 1) {
+ free(new_x);
+ }
+ if (arg -> ldc != 1) {
+ free(new_y);
+ }
+ return status;
+ }
+
+ if (nthr_n > 1) {
+ tmp_y = (int32_t *) malloc((nthr_n - 1) * PADD_BYTESIZE_ONPAGE(m, sizeof(int32_t)), PAGESIZE);
+ if (tmp_y == NULL) {
+ if (arg -> ldb != 1) {
+ free(new_x);
+ }
+ return 1;
+ }
+ }
+
+ parallel_nd((int) nthr, [&](const dim_t ithr) {
+
+ dim_t m_from, m_to, myM;
+ dim_t n_from, n_to, myN;
+
+ dim_t n_id, m_id;
+ dim_t loc_incy = 1;
+ int32_t *loc_y;
+
+ blas_t arg_loc = arg_seq;
+ int j;
+
+ m_id = ithr / nthr_n;
+ n_id = ithr % nthr_n;
+
+ m_from = MB * m_id;
+ m_to = MB * (m_id + 1);
+ if ((m_to > m) || (m_id == nthr_m - 1))
+ m_to = m;
+
+ myM = m_to - m_from;
+
+ n_from = NB * n_id;
+ n_to = NB * (n_id + 1);
+ if ((n_to > n) || (n_id == nthr_n - 1))
+ n_to = n;
+
+ myN = n_to - n_from;
+
+ if (n_id != 0) {
+ arg_loc.beta = &zero;
+ loc_y = tmp_y + (NEXT_THR_STRIDE(m, sizeof(int32_t))) * (n_id - 1) + m_from;
+ }
+ else {
+ if (arg -> ldc == 1) {
+ loc_y = arg_seq.c + m_from;
+ }
+ else {
+ // need to copy the block of c into new_y
+ loc_y = new_y + m_id * NEXT_THR_STRIDE(MB, sizeof(int32_t));
+ if (*(arg -> beta) != 0.0f) {
+ for (j = 0; j < myM; j++) {
+ loc_y[j] = arg -> c[(m_from + j) * arg -> ldc];
+ }
+ }
+ }
+ }
+
+ arg_loc.m = myM;
+ arg_loc.n = myN;
+ arg_loc.a = arg_seq.a + m_from * arg_seq.lda + n_from;
+ arg_loc.b = arg_seq.b + n_from;
+ arg_loc.c = loc_y;
+ arg_loc.ldc = loc_incy;
+
+ gemv_kernel_driver(&arg_loc);
+
+ if ((n_id == 0) && (arg -> ldc != 1)) {
+ for (j = 0; j < myM; j++) {
+ arg -> c[(m_from + j) * arg -> ldc] = loc_y[j];
+ }
+ }
+
+ });
+
+ if (nthr_n > 1) {
+ parallel_nd((int) nthr_m, [&](const dim_t ithr) {
+
+ dim_t j, j_from, j_to, ii;
+ int32_t acc;
+
+ j_from = MB * ithr;
+ j_to = MB * (ithr + 1);
+ if ((j_to > m) || (ithr == nthr - 1))
+ j_to = m;
+
+ for (j = j_from; j < j_to; j++) {
+ acc = 0;
+ for (ii = 0; ii < nthr_n - 1; ii++) {
+ acc += tmp_y[ii * NEXT_THR_STRIDE(m, sizeof(int32_t)) + j];
+ }
+ (arg -> c)[j * arg -> ldc] += acc;
+ }
+ });
+ free(tmp_y);
+ }
+
+ if (arg -> ldb != 1) {
+ free(new_x);
+ }
+
+ if (arg -> ldc != 1) {
+ free(new_y);
+ }
+
+ return 0;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp
new file mode 100644
index 0000000000..c57a8c1d12
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp
@@ -0,0 +1,411 @@
+/*******************************************************************************
+ * Copyright 2019 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp"
+
+#ifdef _WIN32
+#define is_windows 1
+#else
+#define is_windows 0
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b,
+ Xbyak::Zmm a, Xbyak::Zmm tmp,
+ Xbyak::Zmm one, bool swap,
+ int use_vnni) {
+
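+ // vpdpbusd multiplies unsigned bytes (first source) by signed bytes (second source);
+ // swap selects which operand holds the u8 data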
+ if (use_vnni) {
+ if (swap)
+ vpdpbusd(acc, a, b);
+ else
+ vpdpbusd(acc, b, a);
+ }
+
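+ // no VNNI: emulate with vpmaddubsw (u8*s8 -> s16 pairs), vpmaddwd by ones (s16 -> s32), then add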
+ else {
+ if (swap)
+ vpmaddubsw(tmp, a, b);
+ else
+ vpmaddubsw(tmp, b, a);
+ vpmaddwd(tmp, tmp, one);
+ vpaddd(acc, tmp, acc);
+ }
+
+}
+
+void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx,
+ int b_idx, int nreg_acc,
+ Xbyak::Reg64 A, Xbyak::Reg64 lda,
+ Xbyak::Reg64 X, Xbyak::Zmm tmp,
+ Xbyak::Zmm one, bool swap, int use_vnni,
+ int use_mask, Xbyak::Opmask mask_n) {
+
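+ // one 64-column strip: load x once, then dot it with nreg_acc rows of A,
+ // loading the rows in two batches so loads overlap the multiply-adds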
+ int i;
+ int nreg_A = nreg_acc / 2 + (nreg_acc % 2);
+
+ // load X + j
+ if (use_mask)
+ vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]);
+ else
+ vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]);
+
+ xor_(r14, r14);
+ // load values of A
+ for (i = 0; i < nreg_A; i++) {
+ if (use_mask)
+ vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
+ else
+ vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
+ add(r14, lda);
+ }
+
+ for (i = 0; i < nreg_A; i++) {
+ // vnni (acc, b, a, tmp, one, swap, use_vnni)
+ vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx),
+ Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
+ }
+
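+ // reuse the A registers for the remaining accumulators, overlapping these loads
+ // with the multiply-adds issued above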
+ for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
+ if (use_mask)
+ vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
+ else
+ vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
+ add(r14, lda);
+ }
+
+ for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
+ vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx),
+ Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
+ }
+
+}
+
+void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A,
+ Xbyak::Zmm B, Xbyak::Zmm C,
+ Xbyak::Zmm D) {
+
+ vshufi32x4(dest, A, C, 0x44);
+ vshufi32x4(A, A, C, 0xEE);
+ vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3
+
+ vshufi32x4(dest, B, D, 0x44);
+ vshufi32x4(B, B, D, 0xEE);
+ vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3
+
+ vshufi32x4(A, C, D, 0x88);
+ vshufi32x4(B, C, D, 0xDD);
+ vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi
+
+}
+
+void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y,
+ int start_a_idx, int start_acc_idx,
+ Xbyak::Xmm beta, int use_mask,
+ Xbyak::Opmask mask_m) {
+
+ int l, i, k, j, last_it;
+ Xbyak::Label store_label;
+
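+ // each accumulator holds 16 dword partial sums for one row of A;
+ // shuffle_and_add and vphaddd reduce groups of 8 accumulators to a Ymm of 8 row totals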
+ l = 0;
+ for (k = 0; k < nreg_acc; k += 8) {
+ for (i = 0, j = k; i < 8; i += 4, j += 2) {
+ if (j < nreg_acc) {
+ // shuffle per block of 4 registers
+ shuffle_and_add(Xbyak::Zmm(start_a_idx + l), // dest
+ Xbyak::Zmm(start_acc_idx + j), // A = acc0
+ Xbyak::Zmm(start_acc_idx + 1 + j), // B = acc1
+ Xbyak::Zmm(start_acc_idx + 4 + j), // C = acc4
+ Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5
+
+ // extract low and high from dest and hadd
+ vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0);
+ vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1);
+ vphaddd(Xbyak::Ymm(start_a_idx + l),
+ Xbyak::Ymm(start_a_idx + l + 1),
+ Xbyak::Ymm(start_a_idx + l + 2));
+ }
+ l++;
+ }
+
+ vphaddd(Xbyak::Ymm(start_a_idx + l),
+ Xbyak::Ymm(start_a_idx + l - 2),
+ Xbyak::Ymm(start_a_idx + l - 1));
+
+ l++;
+ }
+
+ // if beta is nonzero, add the old C values before storing the result
+ vxorps(Xbyak::Ymm(start_a_idx),
+ Xbyak::Ymm(start_a_idx),
+ Xbyak::Ymm(start_a_idx));
+ vucomiss(beta, Xbyak::Ymm(start_a_idx));
+ je(store_label, T_NEAR);
+
+ // beta == 1: load Y and add it to the computed sums
+ for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
+ // load Y and add
+ last_it = (k + 8) > nreg_acc;
+ if (use_mask && last_it)
+ vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]);
+ else
+ vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]);
+
+ vpaddd(Xbyak::Ymm(start_a_idx + l),
+ Xbyak::Ymm(start_a_idx + l),
+ Xbyak::Ymm(start_a_idx + k / 8));
+ }
+
+ // store
+ aligned_label(store_label);
+ for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
+ last_it = (k + 8) > nreg_acc;
+ if (use_mask && last_it)
+ vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m);
+ else
+ vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l));
+ }
+
+}
+
+template <typename T>
+T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) {
+
+ Xbyak::Opmask mask_n = k1, mask_m = k2;
+ Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label;
+ Xbyak::Label n_tail_label, update_c_label, end_label;
+ constexpr unsigned int n_labels = (1 << unroll_m) - 1;
+ Xbyak::Label m_tail_label_case[n_labels];
+ Xbyak::Label n_loop_label_case[n_labels];
+ Xbyak::Label n_tail_label_case[n_labels];
+ Xbyak::Label update_c_label_case[n_labels];
+
+ int i, ii;
+
+ Xbyak::Zmm one, tmp;
+ Xbyak::Reg64 n = abi_param2, m = abi_param1;
+ Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3;
+ Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4;
+ Xbyak::Reg64 X = is_windows ? rdi : r8;
+ Xbyak::Xmm beta = xmm1;
+ Xbyak::Reg64 Y = is_windows ? rsi : r9;
+
+ bool swap = !std::is_same<T, gemv_s8u8s32_kernel_t>::value;
+
+ // Windows: lda, X, beta and Y are passed on the stack and read below
+
+ int zmm_idx = 1;
+ int nreg_acc = 1 << unroll_m;
+ int nreg_A = 1 << (unroll_m - 1);
+ int nreg_A_acc = nreg_acc + nreg_A;
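+ // zmm layout: [zmm_idx, zmm_idx + nreg_A) holds rows of A,
+ // [zmm_idx + nreg_A, zmm_idx + nreg_A_acc) holds accumulators, zmm_idx + nreg_A_acc holds x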
+
+ if (!use_vnni) {
+ // set a zmm register to one
+ tmp = Xbyak::Zmm(0);
+ one = Xbyak::Zmm(zmm_idx + 1);
+ zmm_idx += 2; // one + tmp
+ }
+ else {
+ beta = xmm0;
+ }
+
+ preamble();
+
+ if (is_windows) {
+ mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]);
+ mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]);
+ movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]);
+ mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]);
+ }
+
+ if (use_vnni && !is_windows) {
+ movaps(beta, xmm1);
+ }
+
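+ // k3 gets a full 2^unroll_n-bit mask; mask_n gets the low (n % 2^unroll_n) bits for the n tail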
+ mov(rax, (1 << unroll_n) - 1);
+ kmovq(k3, rax);
+
+ and_(rax, n); // rax contains n & ((1 << unroll_n) - 1)
+ mov(rbx, 1);
+ shlx(rbx, rbx, rax);
+ sub(rbx, 1);
+ kmovq(mask_n, rbx);
+ // mask_n set (AVX512 only), can use rax and rbx again
+
+ // set mask_m for update of the C matrix
+ // load/store on the C matrix use Ymm so tail according to Ymm size
+ mov(rax, 7); // a Ymm holds 8 int32 elements (8 * 32 = 256 bits), so the tail is m & 7
+ and_(rax, m); // rax contains m & 7
+ mov(rbx, 1);
+ shlx(rbx, rbx, rax);
+ sub(rbx, 1);
+ kmovq(mask_m, rbx);
+ // mask_m set (AVX512 only), can use rax and rbx again
+
+ // setup register of ones when VNNI instructions not available
+ if (!use_vnni) {
+ vmovdqu16(one, ptr[rip + one_label]);
+ }
+
+ // M loop
+ // base pointer for A: rax contains a + i * lda
+ // the loop stops when rax >= a + (m & mask_um) * lda = rbx
+ // loop increment: r10 = um * lda
+ // rbp = Y + i
+ mov(rax, A); // i = 0
+ mov(rbx, m);
+ and_(rbx, mask_um);
+ imul(rbx, lda);
+ add(rbx, A);
+ mov(r10, lda);
+ sal(r10, unroll_m);
+ mov(rbp, Y);
+
+ // N loop
+ // base pointer for X: r11 contains x + j
+ // the loop stops when r11 >= x + (n & mask_un) = r12
+ // loop increment: un
+ // r13 = rax + j = A + i * lda + j
+ mov(r12, n);
+ and_(r12, mask_un);
+ add(r12, X);
+
+ // M loop
+ aligned_label(m_loop_label);
+ cmp(rax, rbx);
+ jge(m_tail_label, T_NEAR);
+
+ // enter M loop
+ for(i = 0; i < nreg_acc; i++) {
+ vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
+ Xbyak::Zmm(i + zmm_idx + nreg_A),
+ Xbyak::Zmm(i + zmm_idx + nreg_A));
+ }
+
+ // N loop
+ mov(r11, X); // j = 0
+ mov(r13, rax);
+ aligned_label(n_loop_label);
+ cmp(r11, r12);
+ jge(n_tail_label, T_NEAR);
+
+ // enter N loop
+
+ n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
+ r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n);
+
+ // increment r11 and r13 by un
+ add(r11, 1 << unroll_n);
+ add(r13, 1 << unroll_n);
+ jmp(n_loop_label, T_NEAR);
+ // end N loop
+
+ // N tail
+ aligned_label(n_tail_label);
+
+ ktestq(mask_n, k3);
+ je(update_c_label, T_NEAR);
+ n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
+ r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n);
+
+ // update C matrix
+ aligned_label(update_c_label);
+
+ update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m);
+
+ // increment rax by um * lda and rbp by um * sizeof(int32_t)
+ add(rax, r10);
+ add(rbp, 1 << (unroll_m + 2));
+ jmp(m_loop_label, T_NEAR);
+ // end M loop
+
+ // M tail
+ aligned_label(m_tail_label);
+
+ // r10 will contain m_tail = m % (1 << unroll_m) = m & ((1 << unroll_m) - 1)
+ mov(r10, m);
+ and_(r10, (1 << unroll_m) - 1);
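+ // emit a specialized copy of the N loop for each possible m_tail value 1 .. 2^unroll_m - 1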
+ for (ii = 1; ii < 1 << unroll_m; ii++) {
+ aligned_label(m_tail_label_case[ii-1]);
+ cmp(r10, ii);
+ if (ii == (1 << unroll_m) - 1)
+ jne(end_label, T_NEAR);
+ else
+ jne(m_tail_label_case[ii], T_NEAR);
+
+ // m_tail = i, use i accumulators
+
+ for(i = 0; i < ii; i++) {
+ vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
+ Xbyak::Zmm(i + zmm_idx + nreg_A),
+ Xbyak::Zmm(i + zmm_idx + nreg_A));
+ }
+
+ // N loop
+ mov(r11, X); // j = 0
+ mov(r13, rax);
+ aligned_label(n_loop_label_case[ii - 1]);
+ cmp(r11, r12);
+ jge(n_tail_label_case[ii - 1], T_NEAR);
+
+ n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
+ lda, r11, tmp, one, swap, use_vnni, 0, mask_n);
+
+ // increment r11 and r13 by un
+ add(r11, 1 << unroll_n);
+ add(r13, 1 << unroll_n);
+ jmp(n_loop_label_case[ii - 1], T_NEAR);
+ // end N loop
+
+ // N tail
+ aligned_label(n_tail_label_case[ii - 1]);
+ ktestq(mask_n, k3);
+ je(update_c_label_case[ii - 1], T_NEAR);
+ n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
+ lda, r11, tmp, one, swap, use_vnni, 1, mask_n);
+
+ // update C matrix
+ aligned_label(update_c_label_case[ii - 1]);
+ update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m);
+
+ if (ii < ((1 << unroll_m) - 1))
+ jmp(end_label, T_NEAR);
+ }
+
+ aligned_label(end_label);
+
+ postamble();
+
+ if (!use_vnni) {
+ aligned_label(one_label);
+ for (i = 0; i < size_vec_reg/8; i++)
+ dq(0x0001000100010001);
+ }
+
+ return (T) getCode();
+}
+
+template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t
+jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(int);
+
+template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t
+jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>(int);
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp
new file mode 100644
index 0000000000..9ea23a5f56
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp
@@ -0,0 +1,64 @@
+/*******************************************************************************
+ * Copyright 2019 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+class jit_avx512_core_gemv_s8u8s32_kern : jit_generator {
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern);
+
+ // assumes unroll_{m,n} give power-of-2 unrolling factors
+ static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m
+ const int mask_um = 0xFFFFFFF0;
+ static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n
+ const int mask_un = 0xFFFFFFC0;
+ const int size_vec_reg = 64; // bytes
+
+ void aligned_label(Xbyak::Label &label, int alignment = 16) {
+ align(alignment);
+ L(label);
+ }
+
+ void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int);
+ void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64,
+ Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask);
+ void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm);
+ void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask);
+
+public:
+ jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {}
+
+ // m, n, alpha, a, lda, x, beta, y
+ typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float,
+ const int8_t*, const dim_t, const uint8_t*,
+ const float, int32_t*);
+ typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float,
+ const uint8_t*, const dim_t, const int8_t*,
+ const float, int32_t*);
+
+ template <typename T>
+ T generate(int use_vnni);
+
+};
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp
new file mode 100644
index 0000000000..544cd2ff25
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp
@@ -0,0 +1,819 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_an_kern::jit_avx512_core_u8_copy_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+
+#endif
+
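+// Packs column panels of the u8 A matrix (48/32/16/8/4/2/1 columns, no transpose)
+// into the contiguous layout expected by the s8x8s32 GEMM kernels, interleaving
+// groups of rows with byte/word unpacks (punpcklbw/punpcklwd).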
+inLocalLabel();
+{
+
+Xbyak::Label l170;
+Xbyak::Label l1f0;
+Xbyak::Label l20;
+Xbyak::Label l224;
+Xbyak::Label l234;
+Xbyak::Label l240;
+Xbyak::Label l254;
+Xbyak::Label l32c;
+Xbyak::Label l34;
+Xbyak::Label l388;
+Xbyak::Label l3b0;
+Xbyak::Label l3c0;
+Xbyak::Label l3cc;
+Xbyak::Label l3dc;
+Xbyak::Label l454;
+Xbyak::Label l48c;
+Xbyak::Label l4a8;
+Xbyak::Label l4b8;
+Xbyak::Label l4c4;
+Xbyak::Label l4d8;
+Xbyak::Label l570;
+Xbyak::Label l5c4;
+Xbyak::Label l5f0;
+Xbyak::Label l60c;
+Xbyak::Label l61c;
+Xbyak::Label l628;
+Xbyak::Label l638;
+Xbyak::Label l6b0;
+Xbyak::Label l6f4;
+Xbyak::Label l720;
+Xbyak::Label l73c;
+Xbyak::Label l74c;
+Xbyak::Label l758;
+Xbyak::Label l76c;
+Xbyak::Label l804;
+Xbyak::Label l858;
+Xbyak::Label l88c;
+Xbyak::Label l8a4;
+Xbyak::Label l8b2;
+Xbyak::Label l8bc;
+Xbyak::Label l8cc;
+Xbyak::Label l944;
+Xbyak::Label l98c;
+Xbyak::Label l9b0;
+Xbyak::Label l9c8;
+Xbyak::Label l9d8;
+
+ preamble();
+#ifdef _WIN32
+ auto stacksize = get_size_of_abi_save_regs();
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(LDA, qword[LDA]);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ sub(A, -128);
+ sub(B, -128);
+ cmp(N, 0x30);
+ jl(l234, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ add(A, 0x30);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l170, T_NEAR);
+ align(4);
+
+L(l34);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm2);
+ movdqu(xmm0, xword[A1-0x70]);
+ movdqu(xmm1, xword[A1+LDA*1-0x70]);
+ movdqu(xmm2, xword[A1+LDA*2-0x70]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x70]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B-0x30], xmm1);
+ movdqu(xword[B-0x20], xmm4);
+ movdqu(xword[B-0x10], xmm2);
+ movdqu(xmm0, xword[A1-0x60]);
+ movdqu(xmm1, xword[A1+LDA*1-0x60]);
+ movdqu(xmm2, xword[A1+LDA*2-0x60]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x60]);
+ lea(A1, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B], xmm0);
+ movdqu(xword[B+0x10], xmm1);
+ movdqu(xword[B+0x20], xmm4);
+ movdqu(xword[B+0x30], xmm2);
+ sub(B, -192);
+ dec(I);
+ jg(l34, T_NEAR);
+ align(4);
+
+L(l170);
+ test(M, 0x2);
+ jle(l1f0, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ movdqu(xmm2, xword[A1-0x60]);
+ add(A1, LDA);
+ movdqu(xmm3, xword[A1-0x80]);
+ movdqu(xmm4, xword[A1-0x70]);
+ movdqu(xmm5, xword[A1-0x60]);
+ add(A1, LDA);
+ movdqa(xmm6, xmm0);
+ punpcklbw(xmm0, xmm3);
+ punpckhbw(xmm6, xmm3);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm6);
+ movdqa(xmm6, xmm1);
+ punpcklbw(xmm1, xmm4);
+ punpckhbw(xmm6, xmm4);
+ movdqu(xword[B-0x60], xmm1);
+ movdqu(xword[B-0x50], xmm6);
+ movdqa(xmm6, xmm2);
+ punpcklbw(xmm2, xmm5);
+ punpckhbw(xmm6, xmm5);
+ movdqu(xword[B-0x40], xmm2);
+ movdqu(xword[B-0x30], xmm6);
+ sub(B, -96);
+ align(4);
+
+L(l1f0);
+ test(M, 0x1);
+ jle(l224, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ movdqu(xmm2, xword[A1-0x60]);
+ add(A1, LDA);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm2);
+ sub(B, -48);
+ align(4);
+
+L(l224);
+ sub(N, 0x30);
+ cmp(N, 0x30);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l234);
+ cmp(N, 0x20);
+ jl(l3c0, T_NEAR);
+ align(4);
+
+L(l240);
+ mov(A1, A);
+ add(A, 0x20);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l32c, T_NEAR);
+ align(4);
+
+L(l254);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm2);
+ movdqu(xmm0, xword[A1-0x70]);
+ movdqu(xmm1, xword[A1+LDA*1-0x70]);
+ movdqu(xmm2, xword[A1+LDA*2-0x70]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x70]);
+ lea(A1, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B-0x30], xmm1);
+ movdqu(xword[B-0x20], xmm4);
+ movdqu(xword[B-0x10], xmm2);
+ sub(B, -128);
+ dec(I);
+ jg(l254, T_NEAR);
+ align(4);
+
+L(l32c);
+ test(M, 0x2);
+ jle(l388, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqu(xmm2, xword[A1-0x80]);
+ movdqu(xmm3, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm2);
+ punpckhbw(xmm4, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm4);
+ movdqa(xmm4, xmm1);
+ punpcklbw(xmm1, xmm3);
+ punpckhbw(xmm4, xmm3);
+ movdqu(xword[B-0x60], xmm1);
+ movdqu(xword[B-0x50], xmm4);
+ sub(B, -64);
+ align(4);
+
+L(l388);
+ test(M, 0x1);
+ jle(l3b0, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l3b0);
+ sub(N, 0x20);
+ cmp(N, 0x20);
+ jge(l240, T_NEAR);
+ align(4);
+
+L(l3c0);
+ cmp(N, 0x10);
+ jl(l4b8, T_NEAR);
+ align(4);
+
+L(l3cc);
+ mov(A1, A);
+ add(A, 0x10);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l454, T_NEAR);
+ align(4);
+
+L(l3dc);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm1, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm2, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm3, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm1, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm1, xmm3);
+ movdqa(xmm3, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm3, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm1);
+ punpckhwd(xmm2, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm3);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm2);
+ sub(B, -64);
+ dec(I);
+ jg(l3dc, T_NEAR);
+ align(4);
+
+L(l454);
+ test(M, 0x2);
+ jle(l48c, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm1, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqa(xmm2, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm2, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ align(4);
+
+L(l48c);
+ test(M, 0x1);
+ jle(l4a8, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l4a8);
+ sub(N, 0x10);
+ cmp(N, 0x10);
+ jge(l3cc, T_NEAR);
+ align(4);
+
+L(l4b8);
+ cmp(N, 0x8);
+ jl(l61c, T_NEAR);
+ align(4);
+
+L(l4c4);
+ mov(A1, A);
+ add(A, 0x8);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l570, T_NEAR);
+ align(4);
+
+L(l4d8);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ dec(I);
+ jg(l4d8, T_NEAR);
+ align(4);
+
+L(l570);
+ test(M, 0x4);
+ jle(l5c4, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l5c4);
+ test(M, 0x2);
+ jle(l5f0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l5f0);
+ test(M, 0x1);
+ jle(l60c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l60c);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l4c4, T_NEAR);
+ align(4);
+
+L(l61c);
+ cmp(N, 0x4);
+ jl(l74c, T_NEAR);
+ align(4);
+
+L(l628);
+ mov(A1, A);
+ add(A, 0x4);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l6b0, T_NEAR);
+ align(4);
+
+L(l638);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ dec(I);
+ jg(l638, T_NEAR);
+ align(4);
+
+L(l6b0);
+ test(M, 0x4);
+ jle(l6f4, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l6f4);
+ test(M, 0x2);
+ jle(l720, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l720);
+ test(M, 0x1);
+ jle(l73c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l73c);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l628, T_NEAR);
+ align(4);
+
+L(l74c);
+ cmp(N, 0x2);
+ jl(l8b2, T_NEAR);
+ align(4);
+
+L(l758);
+ mov(A1, A);
+ add(A, 0x2);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l804, T_NEAR);
+ align(4);
+
+L(l76c);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm4, eax, 0x0);
+ punpcklbw(xmm1, xmm2);
+ punpcklbw(xmm3, xmm4);
+ punpcklwd(xmm1, xmm3);
+ punpcklqdq(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(LDA3);
+ jg(l76c, T_NEAR);
+ align(4);
+
+L(l804);
+ test(M, 0x4);
+ jle(l858, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l858);
+ test(M, 0x2);
+ jle(l88c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l88c);
+ test(M, 0x1);
+ jle(l8a4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ mov(word[B-0x80], ax);
+ sub(B, -2);
+ align(4);
+
+L(l8a4);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l758, T_NEAR);
+ align(4);
+
+L(l8b2);
+ cmp(N, 0x1);
+ jl(l9d8, T_NEAR);
+ align(4);
+
+L(l8bc);
+ mov(A1, A);
+ add(A, 0x1);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l944, T_NEAR);
+ align(4);
+
+L(l8cc);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x7);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ dec(LDA3);
+ jg(l8cc, T_NEAR);
+ align(4);
+
+L(l944);
+ test(M, 0x4);
+ jle(l98c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l98c);
+ test(M, 0x2);
+ jle(l9b0, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l9b0);
+ test(M, 0x1);
+ jle(l9c8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l9c8);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l8bc, T_NEAR);
+ align(4);
+
+L(l9d8);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp
new file mode 100644
index 0000000000..1c11fc6cef
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp
@@ -0,0 +1,2209 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_at_kern::jit_avx512_core_u8_copy_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+
+#endif
+
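+// Transposed-A variant of the u8 copy kernel: builds the packed panels with
+// dword/qword unpacks (punpckldq/punpcklqdq), processing 16 columns per main-loop pass.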
+inLocalLabel();
+{
+
+Xbyak::Label l1014;
+Xbyak::Label l1390;
+Xbyak::Label l159c;
+Xbyak::Label l173c;
+Xbyak::Label l18e4;
+Xbyak::Label l1a7c;
+Xbyak::Label l1a8c;
+Xbyak::Label l1a98;
+Xbyak::Label l1ab4;
+Xbyak::Label l1c64;
+Xbyak::Label l1d74;
+Xbyak::Label l1e50;
+Xbyak::Label l1f2c;
+Xbyak::Label l1ffc;
+Xbyak::Label l20;
+Xbyak::Label l200c;
+Xbyak::Label l2018;
+Xbyak::Label l2034;
+Xbyak::Label l2110;
+Xbyak::Label l21a0;
+Xbyak::Label l2210;
+Xbyak::Label l2284;
+Xbyak::Label l22f0;
+Xbyak::Label l2300;
+Xbyak::Label l230c;
+Xbyak::Label l2324;
+Xbyak::Label l2398;
+Xbyak::Label l23e8;
+Xbyak::Label l242c;
+Xbyak::Label l2474;
+Xbyak::Label l24b4;
+Xbyak::Label l24c4;
+Xbyak::Label l24d0;
+Xbyak::Label l24e8;
+Xbyak::Label l2520;
+Xbyak::Label l254c;
+Xbyak::Label l2578;
+Xbyak::Label l25a8;
+Xbyak::Label l25c8;
+Xbyak::Label l25d6;
+Xbyak::Label l25e0;
+Xbyak::Label l25f0;
+Xbyak::Label l260c;
+Xbyak::Label l262c;
+Xbyak::Label l264c;
+Xbyak::Label l2668;
+Xbyak::Label l2680;
+Xbyak::Label l2690;
+Xbyak::Label l44;
+Xbyak::Label l58c;
+Xbyak::Label l8b0;
+Xbyak::Label lb14;
+Xbyak::Label ld84;
+Xbyak::Label lfdc;
+Xbyak::Label lfec;
+Xbyak::Label lff8;
+
+ preamble();
+#ifdef _WIN32
+ auto stacksize = get_size_of_abi_save_regs();
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(N, qword[N]);
+ mov(M, qword[M]);
+ mov(LDA, qword[LDA]);
+ sub(A, -128);
+ sub(B, -128);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ cmp(N, 0x30);
+ jl(lfec, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x5);
+ lea(I, ptr[I+LDA*8]);
+ lea(I, ptr[I+LDA*8]);
+ add(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l58c, T_NEAR);
+ align(4);
+
+L(l44);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B+0x40], xmm1);
+ movdqu(xword[B+0x100], xmm4);
+ movdqu(xword[B+0x1c0], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B+0x50], xmm1);
+ movdqu(xword[B+0x110], xmm4);
+ movdqu(xword[B+0x1d0], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B+0x60], xmm1);
+ movdqu(xword[B+0x120], xmm4);
+ movdqu(xword[B+0x1e0], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B+0x70], xmm1);
+ movdqu(xword[B+0x130], xmm4);
+ movdqu(xword[B+0x1f0], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B+0x80], xmm1);
+ movdqu(xword[B+0x140], xmm4);
+ movdqu(xword[B+0x200], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ movdqu(xword[B+0x90], xmm1);
+ movdqu(xword[B+0x150], xmm4);
+ movdqu(xword[B+0x210], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ movdqu(xword[B+0xa0], xmm1);
+ movdqu(xword[B+0x160], xmm4);
+ movdqu(xword[B+0x220], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ movdqu(xword[B+0xb0], xmm1);
+ movdqu(xword[B+0x170], xmm4);
+ movdqu(xword[B+0x230], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B], xmm0);
+ movdqu(xword[B+0xc0], xmm1);
+ movdqu(xword[B+0x180], xmm4);
+ movdqu(xword[B+0x240], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B+0x10], xmm0);
+ movdqu(xword[B+0xd0], xmm1);
+ movdqu(xword[B+0x190], xmm4);
+ movdqu(xword[B+0x250], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B+0x20], xmm0);
+ movdqu(xword[B+0xe0], xmm1);
+ movdqu(xword[B+0x1a0], xmm4);
+ movdqu(xword[B+0x260], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B+0x30], xmm0);
+ movdqu(xword[B+0xf0], xmm1);
+ movdqu(xword[B+0x1b0], xmm4);
+ movdqu(xword[B+0x270], xmm3);
+ sub(A1, -16);
+ sub(B, -768);
+ dec(I);
+ jg(l44, T_NEAR);
+ align(4);
+
+L(l58c);
+ test(M, 0x8);
+ jle(l8b0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B+0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B+0x50], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B+0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B+0x70], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B+0x80], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x30], xmm0);
+ movdqu(xword[B+0x90], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x20], xmm0);
+ movdqu(xword[B+0xa0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x10], xmm0);
+ movdqu(xword[B+0xb0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B], xmm0);
+ movdqu(xword[B+0xc0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B+0x10], xmm0);
+ movdqu(xword[B+0xd0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B+0x20], xmm0);
+ movdqu(xword[B+0xe0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B+0x30], xmm0);
+ movdqu(xword[B+0xf0], xmm1);
+ sub(A1, -8);
+ sub(B, -384);
+ align(4);
+
+L(l8b0);
+ test(M, 0x4);
+ jle(lb14, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x40], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x30], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x20], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x10], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B+0x10], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B+0x20], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B+0x30], xmm0);
+ sub(A1, -4);
+ sub(B, -192);
+ align(4);
+
+L(lb14);
+ test(M, 0x2);
+ jle(ld84, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x70], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x60], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x50], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x40], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x30], xmm0);
+ sub(A1, -2);
+ sub(B, -96);
+ align(4);
+
+L(ld84);
+ test(M, 0x1);
+ jle(lfdc, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x80], xmm0);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x70], xmm0);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x60], xmm0);
+ sub(B, -48);
+ align(4);
+
+L(lfdc);
+ sub(N, 0x30);
+ cmp(N, 0x30);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(lfec);
+ cmp(N, 0x20);
+ jl(l1a8c, T_NEAR);
+ align(4);
+
+L(lff8);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x5);
+ add(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l1390, T_NEAR);
+ align(4);
+
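+ // Inner loop of the 32-wide panel: read 16 bytes from each of 32 strided
+ // lines (A1, then A2 stepping by LDA*4), transpose every 4-line group at
+ // dword granularity (punpckl/hdq then punpckl/hqdq), and scatter the
+ // results across interleaved 16-byte stores; B advances by 512 bytes
+ // (32 lines x 16 elements) per trip.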
+L(l1014);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B], xmm1);
+ movdqu(xword[B+0x80], xmm4);
+ movdqu(xword[B+0x100], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B+0x10], xmm1);
+ movdqu(xword[B+0x90], xmm4);
+ movdqu(xword[B+0x110], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B+0x20], xmm1);
+ movdqu(xword[B+0xa0], xmm4);
+ movdqu(xword[B+0x120], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B+0x30], xmm1);
+ movdqu(xword[B+0xb0], xmm4);
+ movdqu(xword[B+0x130], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B+0x40], xmm1);
+ movdqu(xword[B+0xc0], xmm4);
+ movdqu(xword[B+0x140], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ movdqu(xword[B+0x50], xmm1);
+ movdqu(xword[B+0xd0], xmm4);
+ movdqu(xword[B+0x150], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ movdqu(xword[B+0x60], xmm1);
+ movdqu(xword[B+0xe0], xmm4);
+ movdqu(xword[B+0x160], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ movdqu(xword[B+0x70], xmm1);
+ movdqu(xword[B+0xf0], xmm4);
+ movdqu(xword[B+0x170], xmm3);
+ sub(A1, -16);
+ sub(B, -512);
+ dec(I);
+ jg(l1014, T_NEAR);
+ align(4);
+
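+ // The M & 8, M & 4, M & 2 and M & 1 tails below repeat the same panel
+ // copy at reduced depth, substituting movq/movd loads or pinsrw/pinsrb
+ // gathers for the full 16-byte transposes.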
+L(l1390);
+ test(M, 0x8);
+ jle(l159c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B+0x10], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B+0x20], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B+0x30], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B+0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x30], xmm0);
+ movdqu(xword[B+0x50], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x20], xmm0);
+ movdqu(xword[B+0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x10], xmm0);
+ movdqu(xword[B+0x70], xmm1);
+ sub(A1, -8);
+ sub(B, -256);
+ align(4);
+
+L(l159c);
+ test(M, 0x4);
+ jle(l173c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x40], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x30], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x20], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x10], xmm0);
+ sub(A1, -4);
+ sub(B, -128);
+ align(4);
+
+L(l173c);
+ test(M, 0x2);
+ jle(l18e4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x70], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x60], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x50], xmm0);
+ sub(A1, -2);
+ sub(B, -64);
+ align(4);
+
+L(l18e4);
+ test(M, 0x1);
+ jle(l1a7c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x80], xmm0);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l1a7c);
+ sub(N, 0x20);
+ cmp(N, 0x20);
+ jge(lff8, T_NEAR);
+ align(4);
+
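+ // N >= 16: the same scheme with 16-wide panels; A advances by 16*LDA
+ // per panel (shl I, 0x4) and each inner-loop trip emits 256 bytes.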
+L(l1a8c);
+ cmp(N, 0x10);
+ jl(l200c, T_NEAR);
+ align(4);
+
+L(l1a98);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x4);
+ add(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l1c64, T_NEAR);
+ align(4);
+
+L(l1ab4);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x40], xmm1);
+ movdqu(xword[B], xmm4);
+ movdqu(xword[B+0x40], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x30], xmm1);
+ movdqu(xword[B+0x10], xmm4);
+ movdqu(xword[B+0x50], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x20], xmm1);
+ movdqu(xword[B+0x20], xmm4);
+ movdqu(xword[B+0x60], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B-0x10], xmm1);
+ movdqu(xword[B+0x30], xmm4);
+ movdqu(xword[B+0x70], xmm3);
+ sub(A1, -16);
+ sub(B, -256);
+ dec(I);
+ jg(l1ab4, T_NEAR);
+ align(4);
+
+L(l1c64);
+ test(M, 0x8);
+ jle(l1d74, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x30], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x20], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B-0x10], xmm1);
+ sub(A1, -8);
+ sub(B, -128);
+ align(4);
+
+L(l1d74);
+ test(M, 0x4);
+ jle(l1e50, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ sub(A1, -4);
+ sub(B, -64);
+ align(4);
+
+L(l1e50);
+ test(M, 0x2);
+ jle(l1f2c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x70], xmm0);
+ sub(A1, -2);
+ sub(B, -32);
+ align(4);
+
+L(l1f2c);
+ test(M, 0x1);
+ jle(l1ffc, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l1ffc);
+ sub(N, 0x10);
+ cmp(N, 0x10);
+ jge(l1a98, T_NEAR);
+ align(4);
+
+L(l200c);
+ cmp(N, 0x8);
+ jl(l2300, T_NEAR);
+ align(4);
+
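+ // N >= 8: A2 tracks the upper half of the panel (A1 + LDA*4) so both
+ // 4-line groups advance through M in lockstep; 128 bytes per trip.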
+L(l2018);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*4]);
+ lea(I, ptr[A1+LDA*8]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l2110, T_NEAR);
+ align(4);
+
+L(l2034);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ sub(A1, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x60], xmm1);
+ movdqu(xword[B-0x40], xmm4);
+ movdqu(xword[B-0x20], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ movdqu(xword[B-0x30], xmm4);
+ movdqu(xword[B-0x10], xmm3);
+ sub(B, -128);
+ dec(I);
+ jg(l2034, T_NEAR);
+ align(4);
+
+L(l2110);
+ test(M, 0x8);
+ jle(l21a0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ sub(A1, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ align(4);
+
+L(l21a0);
+ test(M, 0x4);
+ jle(l2210, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ sub(A1, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l2210);
+ test(M, 0x2);
+ jle(l2284, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l2284);
+ test(M, 0x1);
+ jle(l22f0, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x7);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l22f0);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l2018, T_NEAR);
+ align(4);
+
+L(l2300);
+ cmp(N, 0x4);
+ jl(l24c4, T_NEAR);
+ align(4);
+
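+ // N >= 4: one 4-line group (A1/A1+LDA and A2/A2+LDA) is transposed per
+ // trip, emitting 64 bytes.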
+L(l230c);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*2]);
+ lea(I, ptr[A1+LDA*4]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l2398, T_NEAR);
+ align(4);
+
+L(l2324);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm2, xword[A2-0x80]);
+ movdqu(xmm3, xword[A2+LDA*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm3);
+ sub(B, -64);
+ dec(I);
+ jg(l2324, T_NEAR);
+ align(4);
+
+L(l2398);
+ test(M, 0x8);
+ jle(l23e8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ sub(A1, -8);
+ movq(xmm2, qword[A2-0x80]);
+ movq(xmm3, qword[A2+LDA*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l23e8);
+ test(M, 0x4);
+ jle(l242c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ sub(A1, -4);
+ movd(xmm2, dword[A2-0x80]);
+ movd(xmm3, dword[A2+LDA*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l242c);
+ test(M, 0x2);
+ jle(l2474, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x3);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l2474);
+ test(M, 0x1);
+ jle(l24b4, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l24b4);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l230c, T_NEAR);
+ align(4);
+
+L(l24c4);
+ cmp(N, 0x2);
+ jl(l25d6, T_NEAR);
+ align(4);
+
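+ // N >= 2: the two lines are interleaved dword-by-dword
+ // (punpckldq/punpckhdq); no full transpose is needed.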
+L(l24d0);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*1]);
+ lea(I, ptr[A1+LDA*2]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l2520, T_NEAR);
+ align(4);
+
+L(l24e8);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm1, xword[A2-0x80]);
+ sub(A2, -16);
+ movdqa(xmm2, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm2, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ dec(I);
+ jg(l24e8, T_NEAR);
+ align(4);
+
+L(l2520);
+ test(M, 0x8);
+ jle(l254c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(xmm1, qword[A2-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l254c);
+ test(M, 0x4);
+ jle(l2578, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(xmm1, dword[A2-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l2578);
+ test(M, 0x2);
+ jle(l25a8, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x1);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l25a8);
+ test(M, 0x1);
+ jle(l25c8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A2-0x80]);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l25c8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l24d0, T_NEAR);
+ align(4);
+
+L(l25d6);
+ cmp(N, 0x1);
+ jl(l2690, T_NEAR);
+ align(4);
+
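+ // N == 1: plain strided copy of the last line in 16/8/4/2/1-byte steps.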
+L(l25e0);
+ mov(A1, A);
+ add(A, LDA);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l260c, T_NEAR);
+ align(4);
+
+L(l25f0);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(I);
+ jg(l25f0, T_NEAR);
+ align(4);
+
+L(l260c);
+ test(M, 0x8);
+ jle(l262c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l262c);
+ test(M, 0x4);
+ jle(l264c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l264c);
+ test(M, 0x2);
+ jle(l2668, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ mov(word[B-0x80], ax);
+ sub(A1, -2);
+ sub(B, -2);
+ align(4);
+
+L(l2668);
+ test(M, 0x1);
+ jle(l2680, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l2680);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l25e0, T_NEAR);
+ align(4);
+
+L(l2690);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp
new file mode 100644
index 0000000000..56c36ee14a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp
@@ -0,0 +1,564 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_bn_kern::jit_avx512_core_u8_copy_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l118;
+Xbyak::Label l1a8;
+Xbyak::Label l20;
+Xbyak::Label l218;
+Xbyak::Label l28c;
+Xbyak::Label l2f8;
+Xbyak::Label l308;
+Xbyak::Label l314;
+Xbyak::Label l32c;
+Xbyak::Label l3a0;
+Xbyak::Label l3c;
+Xbyak::Label l3f0;
+Xbyak::Label l434;
+Xbyak::Label l47c;
+Xbyak::Label l4bc;
+Xbyak::Label l4cc;
+Xbyak::Label l4d8;
+Xbyak::Label l4f0;
+Xbyak::Label l528;
+Xbyak::Label l554;
+Xbyak::Label l580;
+Xbyak::Label l5b0;
+Xbyak::Label l5d0;
+Xbyak::Label l5de;
+Xbyak::Label l5e8;
+Xbyak::Label l5f8;
+Xbyak::Label l614;
+Xbyak::Label l634;
+Xbyak::Label l654;
+Xbyak::Label l670;
+Xbyak::Label l688;
+Xbyak::Label l698;
+
+ preamble();
+#ifdef _WIN32
+ auto stacksize = get_size_of_abi_save_regs();
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
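+ // The dimension arguments are passed by reference, so dereference them.
+ // A and B are biased by +128 up front, apparently so that the dense
+ // -0x80..0x7f displacements in the copy loops encode as a single
+ // signed byte.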
+ mov(N, qword[N]);
+ mov(M, qword[M]);
+ mov(LDA, qword[LDA]);
+ sub(A, -128);
+ sub(B, -128);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ cmp(N, 0x8);
+ jl(l308, T_NEAR);
+ align(4);
+
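+ // N >= 8: copy one 8-line panel per outer iteration (A advances by
+ // 8*LDA); the inner loop at l3c moves 16 elements per line, transposing
+ // each 4-line group in 4x4 dword blocks.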
+L(l20);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*4]);
+ lea(I, ptr[A1+LDA*8]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l118, T_NEAR);
+ align(4);
+
+L(l3c);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ sub(A1, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x60], xmm1);
+ movdqu(xword[B-0x40], xmm4);
+ movdqu(xword[B-0x20], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ movdqu(xword[B-0x30], xmm4);
+ movdqu(xword[B-0x10], xmm3);
+ sub(B, -128);
+ dec(I);
+ jg(l3c, T_NEAR);
+ align(4);
+
+L(l118);
+ test(M, 0x8);
+ jle(l1a8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ sub(A1, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ align(4);
+
+L(l1a8);
+ test(M, 0x4);
+ jle(l218, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ sub(A1, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l218);
+ test(M, 0x2);
+ jle(l28c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l28c);
+ test(M, 0x1);
+ jle(l2f8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x7);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l2f8);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l308);
+ cmp(N, 0x4);
+ jl(l4cc, T_NEAR);
+ align(4);
+
+L(l314);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*2]);
+ lea(I, ptr[A1+LDA*4]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l3a0, T_NEAR);
+ align(4);
+
+L(l32c);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm2, xword[A2-0x80]);
+ movdqu(xmm3, xword[A2+LDA*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm3);
+ sub(B, -64);
+ dec(I);
+ jg(l32c, T_NEAR);
+ align(4);
+
+L(l3a0);
+ test(M, 0x8);
+ jle(l3f0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ sub(A1, -8);
+ movq(xmm2, qword[A2-0x80]);
+ movq(xmm3, qword[A2+LDA*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l3f0);
+ test(M, 0x4);
+ jle(l434, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ sub(A1, -4);
+ movd(xmm2, dword[A2-0x80]);
+ movd(xmm3, dword[A2+LDA*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l434);
+ test(M, 0x2);
+ jle(l47c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x3);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l47c);
+ test(M, 0x1);
+ jle(l4bc, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l4bc);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l314, T_NEAR);
+ align(4);
+
+L(l4cc);
+ cmp(N, 0x2);
+ jl(l5de, T_NEAR);
+ align(4);
+
+L(l4d8);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*1]);
+ lea(I, ptr[A1+LDA*2]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l528, T_NEAR);
+ align(4);
+
+L(l4f0);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm1, xword[A2-0x80]);
+ sub(A2, -16);
+ movdqa(xmm2, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm2, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ dec(I);
+ jg(l4f0, T_NEAR);
+ align(4);
+
+L(l528);
+ test(M, 0x8);
+ jle(l554, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(xmm1, qword[A2-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l554);
+ test(M, 0x4);
+ jle(l580, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(xmm1, dword[A2-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l580);
+ test(M, 0x2);
+ jle(l5b0, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x1);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l5b0);
+ test(M, 0x1);
+ jle(l5d0, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A2-0x80]);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l5d0);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l4d8, T_NEAR);
+ align(4);
+
+L(l5de);
+ cmp(N, 0x1);
+ jl(l698, T_NEAR);
+ align(4);
+
+L(l5e8);
+ mov(A1, A);
+ add(A, LDA);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l614, T_NEAR);
+ align(4);
+
+L(l5f8);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(I);
+ jg(l5f8, T_NEAR);
+ align(4);
+
+L(l614);
+ test(M, 0x8);
+ jle(l634, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l634);
+ test(M, 0x4);
+ jle(l654, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l654);
+ test(M, 0x2);
+ jle(l670, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ mov(word[B-0x80], ax);
+ sub(A1, -2);
+ sub(B, -2);
+ align(4);
+
+L(l670);
+ test(M, 0x1);
+ jle(l688, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l688);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l5e8, T_NEAR);
+ align(4);
+
+L(l698);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp
new file mode 100644
index 0000000000..53e99d94de
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp
@@ -0,0 +1,501 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_bt_kern::jit_avx512_core_u8_copy_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l120;
+Xbyak::Label l14c;
+Xbyak::Label l168;
+Xbyak::Label l178;
+Xbyak::Label l184;
+Xbyak::Label l194;
+Xbyak::Label l20;
+Xbyak::Label l20c;
+Xbyak::Label l250;
+Xbyak::Label l27c;
+Xbyak::Label l298;
+Xbyak::Label l2a8;
+Xbyak::Label l2b4;
+Xbyak::Label l2c8;
+Xbyak::Label l34;
+Xbyak::Label l360;
+Xbyak::Label l3b4;
+Xbyak::Label l3e8;
+Xbyak::Label l400;
+Xbyak::Label l40e;
+Xbyak::Label l418;
+Xbyak::Label l428;
+Xbyak::Label l4a0;
+Xbyak::Label l4e8;
+Xbyak::Label l50c;
+Xbyak::Label l524;
+Xbyak::Label l534;
+Xbyak::Label lcc;
+
+ preamble();
+#ifdef _WIN32
+ auto stacksize = get_size_of_abi_save_regs();
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(LDA, qword[LDA]);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ sub(A, -128);
+ sub(B, -128);
+ cmp(N, 0x8);
+ jl(l178, T_NEAR);
+ align(4);
+
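+ // N >= 8: A advances 8 bytes per panel. The inner loop at l34 walks M
+ // eight rows at a time (sar I, 0x3), interleaving the 8-byte rows first
+ // bytewise (punpcklbw) and then wordwise (punpcklwd) before storing.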
+L(l20);
+ mov(A1, A);
+ add(A, 0x8);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(lcc, T_NEAR);
+ align(4);
+
+L(l34);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ dec(I);
+ jg(l34, T_NEAR);
+ align(4);
+
+L(lcc);
+ test(M, 0x4);
+ jle(l120, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l120);
+ test(M, 0x2);
+ jle(l14c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l14c);
+ test(M, 0x1);
+ jle(l168, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l168);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l178);
+ cmp(N, 0x4);
+ jl(l2a8, T_NEAR);
+ align(4);
+
+L(l184);
+ mov(A1, A);
+ add(A, 0x4);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l20c, T_NEAR);
+ align(4);
+
+L(l194);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ dec(I);
+ jg(l194, T_NEAR);
+ align(4);
+
+L(l20c);
+ test(M, 0x4);
+ jle(l250, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l250);
+ test(M, 0x2);
+ jle(l27c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l27c);
+ test(M, 0x1);
+ jle(l298, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l298);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l184, T_NEAR);
+ align(4);
+
+L(l2a8);
+ cmp(N, 0x2);
+ jl(l40e, T_NEAR);
+ align(4);
+
+L(l2b4);
+ mov(A1, A);
+ add(A, 0x2);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l360, T_NEAR);
+ align(4);
+
+L(l2c8);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm4, eax, 0x0);
+ punpcklbw(xmm1, xmm2);
+ punpcklbw(xmm3, xmm4);
+ punpcklwd(xmm1, xmm3);
+ punpcklqdq(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(LDA3);
+ jg(l2c8, T_NEAR);
+ align(4);
+
+L(l360);
+ test(M, 0x4);
+ jle(l3b4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l3b4);
+ test(M, 0x2);
+ jle(l3e8, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l3e8);
+ test(M, 0x1);
+ jle(l400, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ mov(word[B-0x80], ax);
+ sub(B, -2);
+ align(4);
+
+L(l400);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l2b4, T_NEAR);
+ align(4);
+
+L(l40e);
+ cmp(N, 0x1);
+ jl(l534, T_NEAR);
+ align(4);
+
+L(l418);
+ mov(A1, A);
+ add(A, 0x1);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l4a0, T_NEAR);
+ align(4);
+
+L(l428);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x7);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ dec(LDA3);
+ jg(l428, T_NEAR);
+ align(4);
+
+L(l4a0);
+ test(M, 0x4);
+ jle(l4e8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l4e8);
+ test(M, 0x2);
+ jle(l50c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l50c);
+ test(M, 0x1);
+ jle(l524, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l524);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l418, T_NEAR);
+ align(4);
+
+L(l534);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp
new file mode 100644
index 0000000000..49a312fc88
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp
@@ -0,0 +1,1283 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_sum_an_kern::jit_avx512_core_u8_copy_sum_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#define ARG_BIAS 24+stacksize+rsp
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+#define ARG_BIAS 72+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l1024;
+Xbyak::Label l1090;
+Xbyak::Label l10d4;
+Xbyak::Label l10fc;
+Xbyak::Label l111a;
+Xbyak::Label l1124;
+Xbyak::Label l113c;
+Xbyak::Label l11d4;
+Xbyak::Label l1234;
+Xbyak::Label l1278;
+Xbyak::Label l129c;
+Xbyak::Label l12bc;
+Xbyak::Label l20;
+Xbyak::Label l2a0;
+Xbyak::Label l3c0;
+Xbyak::Label l438;
+Xbyak::Label l480;
+Xbyak::Label l48c;
+Xbyak::Label l4c8;
+Xbyak::Label l5c;
+Xbyak::Label l6a8;
+Xbyak::Label l7b4;
+Xbyak::Label l850;
+Xbyak::Label l89c;
+Xbyak::Label l8a8;
+Xbyak::Label l8d0;
+Xbyak::Label l9d0;
+Xbyak::Label la64;
+Xbyak::Label lab8;
+Xbyak::Label lae8;
+Xbyak::Label laf4;
+Xbyak::Label lb14;
+Xbyak::Label lc30;
+Xbyak::Label lcc8;
+Xbyak::Label ld1c;
+Xbyak::Label ld54;
+Xbyak::Label ld78;
+Xbyak::Label ld84;
+Xbyak::Label ld9c;
+Xbyak::Label le58;
+Xbyak::Label lebc;
+Xbyak::Label lef8;
+Xbyak::Label lf1c;
+Xbyak::Label lf3c;
+Xbyak::Label lf48;
+Xbyak::Label lf60;
+
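+ // Unlike the plain copy kernels, stacksize is computed on both ABIs
+ // here: ARG_BIAS is a stack-relative argument under SysV as well as
+ // Windows.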
+ preamble();
+ auto stacksize = get_size_of_abi_save_regs();
+#ifdef _WIN32
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(LDA, qword[LDA]);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ sub(A, -128);
+ sub(B, -128);
+ cmp(N, 0x30);
+ jl(l480, T_NEAR);
+ align(4);
+
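+ // N >= 0x30: 48-column panels. ymm8..ymm15 are cleared as per-column
+ // sum accumulators, presumably the offset-compensation terms that this
+ // "sum" copy variant stores through ARG_BIAS.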
+L(l20);
+ mov(A1, A);
+ add(A, 0x30);
+ vxorps(ymm8, ymm8, ymm8);
+ vxorps(ymm9, ymm9, ymm9);
+ vxorps(ymm10, ymm10, ymm10);
+ vxorps(ymm11, ymm11, ymm11);
+ vxorps(ymm12, ymm12, ymm12);
+ vxorps(ymm13, ymm13, ymm13);
+ vxorps(ymm14, ymm14, ymm14);
+ vxorps(ymm15, ymm15, ymm15);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l2a0, T_NEAR);
+ align(4);
+
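+ // Inner loop: four lines per trip. Each 16-byte chunk is transposed in
+ // dword blocks as in the non-sum kernels, then reduced into the dword
+ // accumulators (vpmovsxbw sign-extends, vphaddw pairs up, vpmovsxwd
+ // widens, vpaddd accumulates) before the transposed data is stored.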
+L(l5c);
+ vmovdqu(xmm0, xword[A1-0x80]);
+ vmovdqu(xmm1, xword[A1+LDA*1-0x80]);
+ vmovdqu(xmm2, xword[A1+LDA*2-0x80]);
+ vmovdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ vpunpcklbw(xmm4, xmm0, xmm1);
+ vpunpckhbw(xmm5, xmm0, xmm1);
+ vpunpcklbw(xmm6, xmm2, xmm3);
+ vpunpckhbw(xmm7, xmm2, xmm3);
+ vpunpcklwd(xmm0, xmm4, xmm6);
+ vpunpckhwd(xmm1, xmm4, xmm6);
+ vpunpcklwd(xmm2, xmm5, xmm7);
+ vpunpckhwd(xmm3, xmm5, xmm7);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vmovdqu(xword[B-0x80], xmm0);
+ vmovdqu(xword[B-0x70], xmm1);
+ vpmovsxbw(ymm5, xmm2);
+ vmovhlps(xmm6, xmm2, xmm2);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vmovdqu(xword[B-0x60], xmm2);
+ vmovdqu(xword[B-0x50], xmm3);
+ vmovdqu(xmm0, xword[A1-0x70]);
+ vmovdqu(xmm1, xword[A1+LDA*1-0x70]);
+ vmovdqu(xmm2, xword[A1+LDA*2-0x70]);
+ vmovdqu(xmm3, xword[A1+LDA3*1-0x70]);
+ vpunpcklbw(xmm4, xmm0, xmm1);
+ vpunpckhbw(xmm5, xmm0, xmm1);
+ vpunpcklbw(xmm6, xmm2, xmm3);
+ vpunpckhbw(xmm7, xmm2, xmm3);
+ vpunpcklwd(xmm0, xmm4, xmm6);
+ vpunpckhwd(xmm1, xmm4, xmm6);
+ vpunpcklwd(xmm2, xmm5, xmm7);
+ vpunpckhwd(xmm3, xmm5, xmm7);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vmovdqu(xword[B-0x40], xmm0);
+ vmovdqu(xword[B-0x30], xmm1);
+ vpmovsxbw(ymm5, xmm2);
+ vmovhlps(xmm6, xmm2, xmm2);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vmovdqu(xword[B-0x20], xmm2);
+ vmovdqu(xword[B-0x10], xmm3);
+ vmovdqu(xmm0, xword[A1-0x60]);
+ vmovdqu(xmm1, xword[A1+LDA*1-0x60]);
+ vmovdqu(xmm2, xword[A1+LDA*2-0x60]);
+ vmovdqu(xmm3, xword[A1+LDA3*1-0x60]);
+ lea(A1, ptr[A1+LDA*4]);
+ vpunpcklbw(xmm4, xmm0, xmm1);
+ vpunpckhbw(xmm5, xmm0, xmm1);
+ vpunpcklbw(xmm6, xmm2, xmm3);
+ vpunpckhbw(xmm7, xmm2, xmm3);
+ vpunpcklwd(xmm0, xmm4, xmm6);
+ vpunpckhwd(xmm1, xmm4, xmm6);
+ vpunpcklwd(xmm2, xmm5, xmm7);
+ vpunpckhwd(xmm3, xmm5, xmm7);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vmovdqu(xword[B], xmm0);
+ vmovdqu(xword[B+0x10], xmm1);
+ vpmovsxbw(ymm5, xmm2);
+ vmovhlps(xmm6, xmm2, xmm2);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ vmovdqu(xword[B+0x20], xmm2);
+ vmovdqu(xword[B+0x30], xmm3);
+ sub(B, -192);
+ dec(I);
+ jg(l5c, T_NEAR);
+ align(4);
+
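+// M % 4 tail, two rows: the same interleave-and-accumulate pattern applied
+// to a single pair of rows.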
+L(l2a0);
+ test(M, 0x2);
+ jle(l3c0, T_NEAR);
+ vmovdqu(xmm0, xword[A1-0x80]);
+ vmovdqu(xmm1, xword[A1-0x70]);
+ vmovdqu(xmm2, xword[A1-0x60]);
+ add(A1, LDA);
+ vmovdqu(xmm6, xword[A1-0x80]);
+ vmovdqu(xmm4, xword[A1-0x70]);
+ vmovdqu(xmm5, xword[A1-0x60]);
+ add(A1, LDA);
+ vpunpcklbw(xmm3, xmm0, xmm6);
+ vpunpckhbw(xmm0, xmm0, xmm6);
+ vpmovsxbw(ymm7, xmm3);
+ vmovhlps(xmm6, xmm3, xmm3);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm8, ymm8, ymm7);
+ vmovdqu(xword[B-0x80], xmm3);
+ vpmovsxbw(ymm7, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm9, ymm9, ymm7);
+ vmovdqu(xword[B-0x70], xmm0);
+ vpunpcklbw(xmm3, xmm1, xmm4);
+ vpunpckhbw(xmm0, xmm1, xmm4);
+ vpmovsxbw(ymm7, xmm3);
+ vmovhlps(xmm6, xmm3, xmm3);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm10, ymm10, ymm7);
+ vmovdqu(xword[B-0x60], xmm3);
+ vpmovsxbw(ymm7, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm11, ymm11, ymm7);
+ vmovdqu(xword[B-0x50], xmm0);
+ vpunpcklbw(xmm3, xmm2, xmm5);
+ vpunpckhbw(xmm0, xmm2, xmm5);
+ vpmovsxbw(ymm7, xmm3);
+ vmovhlps(xmm6, xmm3, xmm3);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm12, ymm12, ymm7);
+ vmovdqu(xword[B-0x40], xmm3);
+ vpmovsxbw(ymm7, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm13, ymm13, ymm7);
+ vmovdqu(xword[B-0x30], xmm0);
+ sub(B, -96);
+ align(4);
+
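+// M % 2 tail, one row: bytes are copied to B unchanged and widened straight
+// into the dword accumulators with vpmovsxbd.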
+L(l3c0);
+ test(M, 0x1);
+ jle(l438, T_NEAR);
+ vmovdqu(xmm0, xword[A1-0x80]);
+ vmovdqu(xmm1, xword[A1-0x70]);
+ vmovdqu(xmm2, xword[A1-0x60]);
+ add(A1, LDA);
+ vpmovsxbd(ymm7, xmm0);
+ vpaddd(ymm8, ymm8, ymm7);
+ vmovhlps(xmm7, xmm0, xmm0);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm9, ymm9, ymm7);
+ vmovdqu(xword[B-0x80], xmm0);
+ vpmovsxbd(ymm7, xmm1);
+ vpaddd(ymm10, ymm10, ymm7);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm11, ymm11, ymm7);
+ vmovdqu(xword[B-0x70], xmm1);
+ vpmovsxbd(ymm7, xmm2);
+ vpaddd(ymm12, ymm12, ymm7);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm13, ymm13, ymm7);
+ vmovdqu(xword[B-0x60], xmm2);
+ sub(B, -48);
+ align(4);
+
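+// Flush the 48 column sums to the buffer behind ARG_BIAS and advance it by
+// 48 dwords (0xc0 bytes) before starting the next panel.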
+L(l438);
+ mov(A1, qword[ARG_BIAS]);
+ vmovdqu(yword[A1], ymm8);
+ vmovdqu(yword[A1+0x20], ymm9);
+ vmovdqu(yword[A1+0x40], ymm10);
+ vmovdqu(yword[A1+0x60], ymm11);
+ vmovdqu(yword[A1+0x80], ymm12);
+ vmovdqu(yword[A1+0xa0], ymm13);
+ add(qword[ARG_BIAS], 0xc0);
+ sub(N, 0x30);
+ cmp(N, 0x30);
+ jge(l20, T_NEAR);
+ vzeroupper();
+ align(4);
+
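+// Narrower panels below reuse the same copy-and-sum scheme with SSE
+// registers, for widths of 32, 16, 8, 4, 2 and finally 1 column, each with
+// its own M tails.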
+L(l480);
+ cmp(N, 0x20);
+ jl(l89c, T_NEAR);
+ align(4);
+
+L(l48c);
+ mov(A1, A);
+ add(A, 0x20);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ pxor(xmm10, xmm10);
+ pxor(xmm11, xmm11);
+ pxor(xmm12, xmm12);
+ pxor(xmm13, xmm13);
+ pxor(xmm14, xmm14);
+ pxor(xmm15, xmm15);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l6a8, T_NEAR);
+ align(4);
+
+L(l4c8);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm4);
+ pmovsxbw(xmm5, xmm2);
+ movhlps(xmm6, xmm2);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm2);
+ movdqu(xmm0, xword[A1-0x70]);
+ movdqu(xmm1, xword[A1+LDA*1-0x70]);
+ movdqu(xmm2, xword[A1+LDA*2-0x70]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x70]);
+ lea(A1, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B-0x30], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B-0x20], xmm4);
+ pmovsxbw(xmm5, xmm2);
+ movhlps(xmm6, xmm2);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B-0x10], xmm2);
+ sub(B, -128);
+ dec(I);
+ jg(l4c8, T_NEAR);
+ align(4);
+
+L(l6a8);
+ test(M, 0x2);
+ jle(l7b4, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqu(xmm2, xword[A1-0x80]);
+ movdqu(xmm3, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm2);
+ punpckhbw(xmm4, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm4);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x70], xmm4);
+ movdqa(xmm4, xmm1);
+ punpcklbw(xmm1, xmm3);
+ punpckhbw(xmm4, xmm3);
+ pmovsxbw(xmm5, xmm1);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm13, xmm6);
+ movdqu(xword[B-0x60], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm15, xmm6);
+ movdqu(xword[B-0x50], xmm4);
+ sub(B, -64);
+ align(4);
+
+L(l7b4);
+ test(M, 0x1);
+ jle(l850, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ add(A1, LDA);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm8, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbd(xmm5, xmm1);
+ paddd(xmm12, xmm5);
+ pshufd(xmm6, xmm1, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm13, xmm6);
+ pshufd(xmm5, xmm1, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ pshufd(xmm6, xmm1, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm15, xmm6);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l850);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ movdqu(xword[A1+0x20], xmm10);
+ movdqu(xword[A1+0x30], xmm11);
+ movdqu(xword[A1+0x40], xmm12);
+ movdqu(xword[A1+0x50], xmm13);
+ movdqu(xword[A1+0x60], xmm14);
+ movdqu(xword[A1+0x70], xmm15);
+ add(qword[ARG_BIAS], 0x80);
+ sub(N, 0x20);
+ cmp(N, 0x20);
+ jge(l48c, T_NEAR);
+ align(4);
+
+L(l89c);
+ cmp(N, 0x10);
+ jl(lae8, T_NEAR);
+ align(4);
+
+L(l8a8);
+ mov(A1, A);
+ add(A, 0x10);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ pxor(xmm10, xmm10);
+ pxor(xmm11, xmm11);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l9d0, T_NEAR);
+ align(4);
+
+L(l8d0);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm1, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm2, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm3, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm1, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm1, xmm3);
+ movdqa(xmm3, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm3, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm1);
+ punpckhwd(xmm2, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm3);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pmovsxbw(xmm5, xmm2);
+ movhlps(xmm6, xmm2);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm2);
+ sub(B, -64);
+ dec(I);
+ jg(l8d0, T_NEAR);
+ align(4);
+
+L(l9d0);
+ test(M, 0x2);
+ jle(la64, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm1, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqa(xmm2, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm2, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pmovsxbw(xmm5, xmm2);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movhlps(xmm6, xmm2);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ align(4);
+
+L(la64);
+ test(M, 0x1);
+ jle(lab8, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm8, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(lab8);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ movdqu(xword[A1+0x20], xmm10);
+ movdqu(xword[A1+0x30], xmm11);
+ add(qword[ARG_BIAS], 0x40);
+ sub(N, 0x10);
+ cmp(N, 0x10);
+ jge(l8a8, T_NEAR);
+ align(4);
+
+L(lae8);
+ cmp(N, 0x8);
+ jl(ld78, T_NEAR);
+ align(4);
+
+L(laf4);
+ mov(A1, A);
+ add(A, 0x8);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(lc30, T_NEAR);
+ align(4);
+
+L(lb14);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ dec(I);
+ jg(lb14, T_NEAR);
+ align(4);
+
+L(lc30);
+ test(M, 0x4);
+ jle(lcc8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(lcc8);
+ test(M, 0x2);
+ jle(ld1c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(ld1c);
+ test(M, 0x1);
+ jle(ld54, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ pmovsxbd(xmm5, xmm0);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm8, xmm5);
+ paddd(xmm9, xmm6);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(ld54);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ add(qword[ARG_BIAS], 0x20);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(laf4, T_NEAR);
+ align(4);
+
+L(ld78);
+ cmp(N, 0x4);
+ jl(lf3c, T_NEAR);
+ align(4);
+
+L(ld84);
+ mov(A1, A);
+ add(A, 0x4);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(le58, T_NEAR);
+ align(4);
+
+L(ld9c);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ dec(I);
+ jg(ld9c, T_NEAR);
+ align(4);
+
+L(le58);
+ test(M, 0x4);
+ jle(lebc, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(lebc);
+ test(M, 0x2);
+ jle(lef8, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(lef8);
+ test(M, 0x1);
+ jle(lf1c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(lf1c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x10);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(ld84, T_NEAR);
+ align(4);
+
+L(lf3c);
+ cmp(N, 0x2);
+ jl(l111a, T_NEAR);
+ align(4);
+
+L(lf48);
+ mov(A1, A);
+ add(A, 0x2);
+ pxor(xmm7, xmm7);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l1024, T_NEAR);
+ align(4);
+
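+// Two-column panel, eight rows per iteration: pinsrw gathers 16-bit column
+// pairs, and the pshufd below regroups the bytes so one horizontal-add
+// chain produces both column sums at once.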
+L(lf60);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm4, eax, 0x0);
+ punpcklbw(xmm1, xmm2);
+ punpcklbw(xmm3, xmm4);
+ punpcklwd(xmm1, xmm3);
+ punpcklqdq(xmm0, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(LDA3);
+ jg(lf60, T_NEAR);
+ align(4);
+
+L(l1024);
+ test(M, 0x4);
+ jle(l1090, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l1090);
+ test(M, 0x2);
+ jle(l10d4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l10d4);
+ test(M, 0x1);
+ jle(l10fc, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(word[B-0x80], ax);
+ sub(B, -2);
+ align(4);
+
+L(l10fc);
+ mov(A1, qword[ARG_BIAS]);
+ movq(qword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(lf48, T_NEAR);
+ align(4);
+
+L(l111a);
+ cmp(N, 0x1);
+ jl(l12bc, T_NEAR);
+ align(4);
+
+L(l1124);
+ mov(A1, A);
+ add(A, 0x1);
+ pxor(xmm7, xmm7);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l11d4, T_NEAR);
+ align(4);
+
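+// Single-column panel: eight bytes are gathered one pinsrb at a time and
+// reduced to a single dword sum.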
+L(l113c);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
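+ // xmm6 is stale at this point; the lanes it contributes are don't-care,
+ // since only the low dword of the final sum is stored at l129c.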
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ dec(LDA3);
+ jg(l113c, T_NEAR);
+ align(4);
+
+L(l11d4);
+ test(M, 0x4);
+ jle(l1234, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l1234);
+ test(M, 0x2);
+ jle(l1278, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l1278);
+ test(M, 0x1);
+ jle(l129c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l129c);
+ mov(A1, qword[ARG_BIAS]);
+ movd(dword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x4);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l1124, T_NEAR);
+ align(4);
+
+L(l12bc);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+#undef ARG_BIAS
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp
new file mode 100644
index 0000000000..a4f4ff09c6
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp
@@ -0,0 +1,3163 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_sum_at_kern::jit_avx512_core_u8_copy_sum_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#define ARG_BIAS 24+stacksize+rsp
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+#define ARG_BIAS 72+stacksize+rsp
+
+#endif
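+
+// Same argument mapping as the non-transposed ("an") copy above; this "at"
+// variant walks A across LDA-strided rows, transposing them into the packed
+// layout while computing the matching compensation sums.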
+
+inLocalLabel();
+{
+
+Xbyak::Label l1750;
+Xbyak::Label l1b6c;
+Xbyak::Label l1e14;
+Xbyak::Label l20;
+Xbyak::Label l2068;
+Xbyak::Label l226c;
+Xbyak::Label l22b8;
+Xbyak::Label l22c4;
+Xbyak::Label l22f4;
+Xbyak::Label l26b4;
+Xbyak::Label l28cc;
+Xbyak::Label l2a2c;
+Xbyak::Label l2b5c;
+Xbyak::Label l2c64;
+Xbyak::Label l2c94;
+Xbyak::Label l2ca0;
+Xbyak::Label l2cc8;
+Xbyak::Label l2eac;
+Xbyak::Label l2fc0;
+Xbyak::Label l3078;
+Xbyak::Label l3118;
+Xbyak::Label l319c;
+Xbyak::Label l31c0;
+Xbyak::Label l31cc;
+Xbyak::Label l31ec;
+Xbyak::Label l32e4;
+Xbyak::Label l3378;
+Xbyak::Label l33dc;
+Xbyak::Label l3434;
+Xbyak::Label l347c;
+Xbyak::Label l349c;
+Xbyak::Label l34a8;
+Xbyak::Label l34c8;
+Xbyak::Label l3558;
+Xbyak::Label l35b0;
+Xbyak::Label l35f4;
+Xbyak::Label l3638;
+Xbyak::Label l366c;
+Xbyak::Label l368a;
+Xbyak::Label l3694;
+Xbyak::Label l36a8;
+Xbyak::Label l36ec;
+Xbyak::Label l3728;
+Xbyak::Label l3760;
+Xbyak::Label l3794;
+Xbyak::Label l37b8;
+Xbyak::Label l37d8;
+Xbyak::Label l5cc;
+Xbyak::Label l6c;
+Xbyak::Label l968;
+Xbyak::Label lc80;
+Xbyak::Label lf1c;
+Xbyak::Label lf64;
+Xbyak::Label lf70;
+Xbyak::Label lfb4;
+
+ preamble();
+ auto stacksize = get_size_of_abi_save_regs();
+#ifdef _WIN32
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(N, qword[N]);
+ mov(M, qword[M]);
+ mov(LDA, qword[LDA]);
+ sub(A, -128);
+ sub(B, -128);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ cmp(N, 0x30);
+ jl(lf64, T_NEAR);
+ align(4);
+
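+// Main path, 48 rows per panel: A advances by 48*LDA per outer iteration,
+// with ymm8-ymm13 again accumulating one signed dword sum per source row.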
+L(l20);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x5);
+ lea(I, ptr[I+LDA*8]);
+ lea(I, ptr[I+LDA*8]);
+ add(A, I);
+ vxorps(ymm8, ymm8, ymm8);
+ vxorps(ymm9, ymm9, ymm9);
+ vxorps(ymm10, ymm10, ymm10);
+ vxorps(ymm11, ymm11, ymm11);
+ vxorps(ymm12, ymm12, ymm12);
+ vxorps(ymm13, ymm13, ymm13);
+ vxorps(ymm14, ymm14, ymm14);
+ vxorps(ymm15, ymm15, ymm15);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l5cc, T_NEAR);
+ align(4);
+
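+// Inner loop, 8 columns of the 48-row panel per iteration: vpunpckldq /
+// vpunpck{l,h}qdq pairs transpose four 8-byte row fragments into 16-byte
+// blocks; the two 4-column halves of each strip are stored 0xc0 bytes
+// apart (e.g. B-0x80 and B+0x40), and B advances by 384 per iteration.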
+L(l6c);
+ vmovq(xmm0, qword[A1-0x80]);
+ vmovq(xmm1, qword[A1+LDA*1-0x80]);
+ vmovq(xmm2, qword[A1+LDA*2-0x80]);
+ vmovq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x80], xmm0);
+ vmovdqu(xword[B+0x40], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B-0x70], xmm2);
+ vmovdqu(xword[B+0x50], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x60], xmm0);
+ vmovdqu(xword[B+0x60], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B-0x50], xmm2);
+ vmovdqu(xword[B+0x70], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x40], xmm0);
+ vmovdqu(xword[B+0x80], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B-0x30], xmm2);
+ vmovdqu(xword[B+0x90], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x20], xmm0);
+ vmovdqu(xword[B+0xa0], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B-0x10], xmm2);
+ vmovdqu(xword[B+0xb0], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B], xmm0);
+ vmovdqu(xword[B+0xc0], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B+0x10], xmm2);
+ vmovdqu(xword[B+0xd0], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B+0x20], xmm0);
+ vmovdqu(xword[B+0xe0], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B+0x30], xmm2);
+ vmovdqu(xword[B+0xf0], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ sub(A1, -8);
+ sub(B, -384);
+ dec(I);
+ jg(l6c, T_NEAR);
+ align(4);
+
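+// M tails: the 4-, 2- and 1-wide remainders below repeat the gather /
+// transpose / sum pattern before lf1c flushes the 48 compensation sums.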
+L(l5cc);
+ test(M, 0x4);
+ jle(l968, T_NEAR);
+ vmovd(xmm0, dword[A1-0x80]);
+ vmovd(xmm1, dword[A1+LDA*1-0x80]);
+ vmovd(xmm2, dword[A1+LDA*2-0x80]);
+ vmovd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B-0x80], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x70], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B-0x60], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x50], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B-0x40], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x30], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B-0x20], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x10], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B+0x10], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B+0x20], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B+0x30], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ sub(A1, -4);
+ sub(B, -192);
+ align(4);
+
+L(l968);
+ test(M, 0x2);
+ jle(lc80, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vmovdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vmovdqu(xword[B-0x70], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vmovdqu(xword[B-0x60], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vmovdqu(xword[B-0x50], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vmovdqu(xword[B-0x40], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ vmovdqu(xword[B-0x30], xmm0);
+ sub(A1, -2);
+ sub(B, -96);
+ align(4);
+
+L(lc80);
+ test(M, 0x1);
+ jle(lf1c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xf);
+ vpmovsxbd(ymm7, xmm0);
+ vpaddd(ymm8, ymm8, ymm7);
+ vmovhlps(xmm7, xmm0, xmm0);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm9, ymm9, ymm7);
+ vmovdqu(xword[B-0x80], xmm0);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xf);
+ vpmovsxbd(ymm7, xmm0);
+ vpaddd(ymm10, ymm10, ymm7);
+ vmovhlps(xmm7, xmm0, xmm0);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm11, ymm11, ymm7);
+ vmovdqu(xword[B-0x70], xmm0);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xf);
+ vpmovsxbd(ymm7, xmm0);
+ vpaddd(ymm12, ymm12, ymm7);
+ vmovhlps(xmm7, xmm0, xmm0);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm13, ymm13, ymm7);
+ vmovdqu(xword[B-0x60], xmm0);
+ sub(B, -48);
+ align(4);
+
+L(lf1c);
+ mov(A1, qword[ARG_BIAS]);
+ vmovdqu(yword[A1], ymm8);
+ vmovdqu(yword[A1+0x20], ymm9);
+ vmovdqu(yword[A1+0x40], ymm10);
+ vmovdqu(yword[A1+0x60], ymm11);
+ vmovdqu(yword[A1+0x80], ymm12);
+ vmovdqu(yword[A1+0xa0], ymm13);
+ add(qword[ARG_BIAS], 0xc0);
+ sub(N, 0x30);
+ cmp(N, 0x30);
+ jge(l20, T_NEAR);
+ vzeroupper();
+ align(4);
+
+L(lf64);
+ cmp(N, 0x20);
+ jl(l22b8, T_NEAR);
+ align(4);
+
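+// 32-row panels: the same transposition with SSE registers, consuming 16
+// M-elements per inner iteration (sar I, 0x4).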
+L(lf70);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x5);
+ add(A, I);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ pxor(xmm10, xmm10);
+ pxor(xmm11, xmm11);
+ pxor(xmm12, xmm12);
+ pxor(xmm13, xmm13);
+ pxor(xmm14, xmm14);
+ pxor(xmm15, xmm15);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l1750, T_NEAR);
+ align(4);
+
+L(lfb4);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B+0x80], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B+0x100], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x10], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x90], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x110], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x20], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0xa0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x120], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x30], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0xb0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x130], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B+0x40], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B+0xc0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B+0x140], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B+0x50], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B+0xd0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B+0x150], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B+0x60], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B+0xe0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B+0x160], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B+0x70], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B+0xf0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B+0x170], xmm3);
+ sub(A1, -16);
+ sub(B, -512);
+ dec(I);
+ jg(lfb4, T_NEAR);
+ align(4);
+
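+// Tails for this 32-column panel: the M&8, M&4, M&2 and M&1 blocks below
+// repeat the main loop's transpose + signed-byte column-sum pattern,
+// accumulating into xmm8..xmm15 (one dword per packed column).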
+L(l1750);
+ test(M, 0x8);
+ jle(l1b6c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x10], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x20], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x30], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B+0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B+0x50], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B+0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B+0x70], xmm1);
+ sub(A1, -8);
+ sub(B, -256);
+ align(4);
+
+L(l1b6c);
+ test(M, 0x4);
+ jle(l1e14, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ sub(A1, -4);
+ sub(B, -128);
+ align(4);
+
+L(l1e14);
+ test(M, 0x2);
+ jle(l2068, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x70], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm13, xmm6);
+ movdqu(xword[B-0x60], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm15, xmm6);
+ movdqu(xword[B-0x50], xmm0);
+ sub(A1, -2);
+ sub(B, -64);
+ align(4);
+
+L(l2068);
+ test(M, 0x1);
+ jle(l226c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm8, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm12, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm13, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm15, xmm6);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l226c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ movdqu(xword[A1+0x20], xmm10);
+ movdqu(xword[A1+0x30], xmm11);
+ movdqu(xword[A1+0x40], xmm12);
+ movdqu(xword[A1+0x50], xmm13);
+ movdqu(xword[A1+0x60], xmm14);
+ movdqu(xword[A1+0x70], xmm15);
+ add(qword[ARG_BIAS], 0x80);
+ sub(N, 0x20);
+ cmp(N, 0x20);
+ jge(lf70, T_NEAR);
+ align(4);
+
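+// Fewer than 32 columns left: drop to the 16-column panel, whose sixteen
+// column sums live in xmm8..xmm11.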
+L(l22b8);
+ cmp(N, 0x10);
+ jl(l2c94, T_NEAR);
+ align(4);
+
+L(l22c4);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x4);
+ add(A, I);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ pxor(xmm10, xmm10);
+ pxor(xmm11, xmm11);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l26b4, T_NEAR);
+ align(4);
+
+L(l22f4);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x40], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B+0x40], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x30], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x10], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x50], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x20], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x20], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x60], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x10], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x30], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x70], xmm3);
+ sub(A1, -16);
+ sub(B, -256);
+ dec(I);
+ jg(l22f4, T_NEAR);
+ align(4);
+
+L(l26b4);
+ test(M, 0x8);
+ jle(l28cc, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x30], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x20], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x10], xmm1);
+ sub(A1, -8);
+ sub(B, -128);
+ align(4);
+
+L(l28cc);
+ test(M, 0x4);
+ jle(l2a2c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ sub(A1, -4);
+ sub(B, -64);
+ align(4);
+
+L(l2a2c);
+ test(M, 0x2);
+ jle(l2b5c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x70], xmm0);
+ sub(A1, -2);
+ sub(B, -32);
+ align(4);
+
+L(l2b5c);
+ test(M, 0x1);
+ jle(l2c64, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0xf);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm8, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l2c64);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ movdqu(xword[A1+0x20], xmm10);
+ movdqu(xword[A1+0x30], xmm11);
+ add(qword[ARG_BIAS], 0x40);
+ sub(N, 0x10);
+ cmp(N, 0x10);
+ jge(l22c4, T_NEAR);
+ align(4);
+
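+// 8-column panel: eight column sums in xmm8/xmm9; the bias pointer advances
+// by 0x20 per panel.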
+L(l2c94);
+ cmp(N, 0x8);
+ jl(l31c0, T_NEAR);
+ align(4);
+
+L(l2ca0);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*4]);
+ lea(I, ptr[A1+LDA*8]);
+ mov(A, I);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l2eac, T_NEAR);
+ align(4);
+
+L(l2cc8);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ sub(A1, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x60], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x40], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x20], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x50], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x30], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x10], xmm3);
+ sub(B, -128);
+ dec(I);
+ jg(l2cc8, T_NEAR);
+ align(4);
+
+L(l2eac);
+ test(M, 0x8);
+ jle(l2fc0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ sub(A1, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ align(4);
+
+L(l2fc0);
+ test(M, 0x4);
+ jle(l3078, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ sub(A1, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l3078);
+ test(M, 0x2);
+ jle(l3118, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l3118);
+ test(M, 0x1);
+ jle(l319c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x7);
+ pmovsxbd(xmm5, xmm0);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm8, xmm5);
+ paddd(xmm9, xmm6);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l319c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ add(qword[ARG_BIAS], 0x20);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l2ca0, T_NEAR);
+ align(4);
+
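+// 4-column panel: a single accumulator (xmm7) holds all four column sums.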
+L(l31c0);
+ cmp(N, 0x4);
+ jl(l349c, T_NEAR);
+ align(4);
+
+L(l31cc);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*2]);
+ lea(I, ptr[A1+LDA*4]);
+ mov(A, I);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l32e4, T_NEAR);
+ align(4);
+
+L(l31ec);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm2, xword[A2-0x80]);
+ movdqu(xmm3, xword[A2+LDA*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x60], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x50], xmm3);
+ sub(B, -64);
+ dec(I);
+ jg(l31ec, T_NEAR);
+ align(4);
+
+L(l32e4);
+ test(M, 0x8);
+ jle(l3378, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ sub(A1, -8);
+ movq(xmm2, qword[A2-0x80]);
+ movq(xmm3, qword[A2+LDA*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l3378);
+ test(M, 0x4);
+ jle(l33dc, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ sub(A1, -4);
+ movd(xmm2, dword[A2-0x80]);
+ movd(xmm3, dword[A2+LDA*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l33dc);
+ test(M, 0x2);
+ jle(l3434, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x3);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l3434);
+ test(M, 0x1);
+ jle(l347c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l347c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x10);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l31cc, T_NEAR);
+ align(4);
+
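+// 2-column panel: xmm7 carries both column sums; only its low qword is
+// stored to the bias vector.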
+L(l349c);
+ cmp(N, 0x2);
+ jl(l368a, T_NEAR);
+ align(4);
+
+L(l34a8);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*1]);
+ lea(I, ptr[A1+LDA*2]);
+ mov(A, I);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l3558, T_NEAR);
+ align(4);
+
+L(l34c8);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm1, xword[A2-0x80]);
+ sub(A2, -16);
+ movdqa(xmm2, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm2, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pshufd(xmm6, xmm2, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ dec(I);
+ jg(l34c8, T_NEAR);
+ align(4);
+
+L(l3558);
+ test(M, 0x8);
+ jle(l35b0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(xmm1, qword[A2-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l35b0);
+ test(M, 0x4);
+ jle(l35f4, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(xmm1, dword[A2-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l35f4);
+ test(M, 0x2);
+ jle(l3638, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l3638);
+ test(M, 0x1);
+ jle(l366c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ align(4);
+
+L(l366c);
+ mov(A1, qword[ARG_BIAS]);
+ movq(qword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l34a8, T_NEAR);
+ align(4);
+
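+// Last single column: horizontally add the whole strip into one dword sum.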
+L(l368a);
+ cmp(N, 0x1);
+ jl(l37d8, T_NEAR);
+ align(4);
+
+L(l3694);
+ mov(A1, A);
+ add(A, LDA);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l36ec, T_NEAR);
+ align(4);
+
+L(l36a8);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(I);
+ jg(l36a8, T_NEAR);
+ align(4);
+
+L(l36ec);
+ test(M, 0x8);
+ jle(l3728, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5); // xmm6 is not loaded in this tail; fold xmm5 into itself (only the low dword of the sum is ever stored)
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l3728);
+ test(M, 0x4);
+ jle(l3760, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l3760);
+ test(M, 0x2);
+ jle(l3794, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ mov(word[B-0x80], ax);
+ sub(A1, -2);
+ sub(B, -2);
+ align(4);
+
+L(l3794);
+ test(M, 0x1);
+ jle(l37b8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l37b8);
+ mov(A1, qword[ARG_BIAS]);
+ movd(dword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x4);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l3694, T_NEAR);
+ align(4);
+
+L(l37d8);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+#undef ARG_BIAS
+}
+
+}
+}
+}
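
The copy kernels in this patch share one contract: pack a panel of the int8
source matrix into the contiguous layout the s8u8s32 GEMM microkernel expects
while accumulating, for each packed column, the sum of its signed bytes into
the offset-compensation ("bias") vector that the integer GEMM consumes later.
A minimal scalar sketch of that contract follows; the helper name is
hypothetical and the packed layout is simplified (the JIT code above walks
the panel 32/16/8/4/2/1 columns and 16/8/4/2/1 rows at a time, transposing
in registers, rather than copying columns whole):

    #include <cstdint>

    // Reference semantics only, not part of the patch: pack one n-column
    // panel of a column-major int8 matrix and add each column's byte sum
    // into bias[0..n-1], mirroring the pmovsxbw/phaddw/paddd chains above.
    static void copy_sum_panel_ref(int m, int n, const int8_t *a, int lda,
            int8_t *b, int32_t *bias) {
        for (int j = 0; j < n; j++) {
            int32_t sum = 0;
            for (int i = 0; i < m; i++) {
                int8_t v = a[i + j * lda];
                *b++ = v;   // packed destination (layout simplified)
                sum += v;   // signed widening accumulate
            }
            bias[j] += sum;
        }
    }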
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp
new file mode 100644
index 0000000000..c7f1393c9d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp
@@ -0,0 +1,821 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_sum_bn_kern::jit_avx512_core_u8_copy_sum_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#define ARG_BIAS 24+stacksize+rsp
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+#define ARG_BIAS 72+stacksize+rsp
+
+#endif
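+// Register roles after the platform mapping above: M/N point at the panel
+// dimensions (dereferenced in the prologue), A/LDA describe the int8 source,
+// B is the packed destination, and ARG_BIAS is the stack slot holding the
+// running column-sum pointer.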
+
+inLocalLabel();
+{
+
+Xbyak::Label l20;
+Xbyak::Label l22c;
+Xbyak::Label l340;
+Xbyak::Label l3f8;
+Xbyak::Label l48;
+Xbyak::Label l498;
+Xbyak::Label l51c;
+Xbyak::Label l540;
+Xbyak::Label l54c;
+Xbyak::Label l56c;
+Xbyak::Label l664;
+Xbyak::Label l6f8;
+Xbyak::Label l75c;
+Xbyak::Label l7b4;
+Xbyak::Label l7fc;
+Xbyak::Label l81c;
+Xbyak::Label l828;
+Xbyak::Label l848;
+Xbyak::Label l8d8;
+Xbyak::Label l930;
+Xbyak::Label l974;
+Xbyak::Label l9b8;
+Xbyak::Label l9ec;
+Xbyak::Label la0a;
+Xbyak::Label la14;
+Xbyak::Label la28;
+Xbyak::Label la6c;
+Xbyak::Label laa8;
+Xbyak::Label lae0;
+Xbyak::Label lb14;
+Xbyak::Label lb38;
+Xbyak::Label lb58;
+
+ preamble();
+ auto stacksize = get_size_of_abi_save_regs();
+#ifdef _WIN32
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(N, qword[N]);
+ mov(M, qword[M]);
+ mov(LDA, qword[LDA]);
+ sub(A, -128);
+ sub(B, -128);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ cmp(N, 0x8);
+ jl(l540, T_NEAR);
+ align(4);
+
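+// Main loop: consume eight source columns per pass while N >= 8; sums
+// accumulate in xmm8/xmm9 and are flushed to the bias vector at l51c.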
+L(l20);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*4]);
+ lea(I, ptr[A1+LDA*8]);
+ mov(A, I);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l22c, T_NEAR);
+ align(4);
+
+L(l48);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ sub(A1, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x60], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x40], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x20], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x50], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x30], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x10], xmm3);
+ sub(B, -128);
+ dec(I);
+ jg(l48, T_NEAR);
+ align(4);
+
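+// Row tails for the 8-column panel: M&8, M&4, M&2 and M&1 handled below.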
+L(l22c);
+ test(M, 0x8);
+ jle(l340, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ sub(A1, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ align(4);
+
+L(l340);
+ test(M, 0x4);
+ jle(l3f8, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ sub(A1, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l3f8);
+ test(M, 0x2);
+ jle(l498, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l498);
+ test(M, 0x1);
+ jle(l51c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x7);
+ pmovsxbd(xmm5, xmm0);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm8, xmm5);
+ paddd(xmm9, xmm6);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l51c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ add(qword[ARG_BIAS], 0x20);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l540);
+ cmp(N, 0x4);
+ jl(l81c, T_NEAR);
+ align(4);
+
+L(l54c);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*2]);
+ lea(I, ptr[A1+LDA*4]);
+ mov(A, I);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l664, T_NEAR);
+ align(4);
+
+L(l56c);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm2, xword[A2-0x80]);
+ movdqu(xmm3, xword[A2+LDA*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x60], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x50], xmm3);
+ sub(B, -64);
+ dec(I);
+ jg(l56c, T_NEAR);
+ align(4);
+
+L(l664);
+ test(M, 0x8);
+ jle(l6f8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ sub(A1, -8);
+ movq(xmm2, qword[A2-0x80]);
+ movq(xmm3, qword[A2+LDA*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l6f8);
+ test(M, 0x4);
+ jle(l75c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ sub(A1, -4);
+ movd(xmm2, dword[A2-0x80]);
+ movd(xmm3, dword[A2+LDA*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l75c);
+ test(M, 0x2);
+ jle(l7b4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x3);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l7b4);
+ test(M, 0x1);
+ jle(l7fc, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l7fc);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x10);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l54c, T_NEAR);
+ align(4);
+
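+ // Remaining column blocks of width 2: xmm7 now carries two int32 column
+ // sums, flushed as a single movq to the bias array.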
+L(l81c);
+ cmp(N, 0x2);
+ jl(la0a, T_NEAR);
+ align(4);
+
+L(l828);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*1]);
+ lea(I, ptr[A1+LDA*2]);
+ mov(A, I);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l8d8, T_NEAR);
+ align(4);
+
+L(l848);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm1, xword[A2-0x80]);
+ sub(A2, -16);
+ movdqa(xmm2, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm2, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pshufd(xmm6, xmm2, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ dec(I);
+ jg(l848, T_NEAR);
+ align(4);
+
+L(l8d8);
+ test(M, 0x8);
+ jle(l930, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(xmm1, qword[A2-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l930);
+ test(M, 0x4);
+ jle(l974, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(xmm1, dword[A2-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l974);
+ test(M, 0x2);
+ jle(l9b8, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l9b8);
+ test(M, 0x1);
+ jle(l9ec, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ align(4);
+
+L(l9ec);
+ mov(A1, qword[ARG_BIAS]);
+ movq(qword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l828, T_NEAR);
+ align(4);
+
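+ // Last column, if any: a single int32 column sum is kept in xmm7 and
+ // stored with movd.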
+L(la0a);
+ cmp(N, 0x1);
+ jl(lb58, T_NEAR);
+ align(4);
+
+L(la14);
+ mov(A1, A);
+ add(A, LDA);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(la6c, T_NEAR);
+ align(4);
+
+L(la28);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(I);
+ jg(la28, T_NEAR);
+ align(4);
+
+L(la6c);
+ test(M, 0x8);
+ jle(laa8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ pmovsxbw(xmm5, xmm0);
+	phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(laa8);
+ test(M, 0x4);
+ jle(lae0, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(lae0);
+ test(M, 0x2);
+ jle(lb14, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ mov(word[B-0x80], ax);
+ sub(A1, -2);
+ sub(B, -2);
+ align(4);
+
+L(lb14);
+ test(M, 0x1);
+ jle(lb38, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(lb38);
+ mov(A1, qword[ARG_BIAS]);
+ movd(dword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x4);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(la14, T_NEAR);
+ align(4);
+
+L(lb58);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+#undef ARG_BIAS
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp
new file mode 100644
index 0000000000..afe4f1713e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp
@@ -0,0 +1,647 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
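+// Generated Xbyak copy kernel: it packs panels of an 8-bit matrix into the
+// contiguous buffer B while the pmovsxbw/phaddw/pmovsxwd/paddd chains
+// accumulate per-column int32 sums of the packed panel into the array
+// addressed through ARG_BIAS; the s8u8s32 GEMM consumes those sums as its
+// offset compensation. Column-block widths of 8, 4, 2 and 1 are handled by
+// the successive N sections below.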
+jit_avx512_core_u8_copy_sum_bt_kern::jit_avx512_core_u8_copy_sum_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#define ARG_BIAS 24+stacksize+rsp
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+#define ARG_BIAS 72+stacksize+rsp
+
+#endif
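+
+// On System V the six kernel arguments arrive in registers; on Windows the
+// alpha and B pointers are reloaded from the stack (ARG_ALPHA/ARG_B) right
+// after the preamble, and the bias pointer always lives on the stack at
+// ARG_BIAS.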
+
+inLocalLabel();
+{
+
+Xbyak::Label l15c;
+Xbyak::Label l1f4;
+Xbyak::Label l20;
+Xbyak::Label l248;
+Xbyak::Label l280;
+Xbyak::Label l2a4;
+Xbyak::Label l2b0;
+Xbyak::Label l2c8;
+Xbyak::Label l384;
+Xbyak::Label l3e8;
+Xbyak::Label l40;
+Xbyak::Label l424;
+Xbyak::Label l448;
+Xbyak::Label l468;
+Xbyak::Label l474;
+Xbyak::Label l48c;
+Xbyak::Label l550;
+Xbyak::Label l5bc;
+Xbyak::Label l600;
+Xbyak::Label l628;
+Xbyak::Label l646;
+Xbyak::Label l650;
+Xbyak::Label l668;
+Xbyak::Label l700;
+Xbyak::Label l760;
+Xbyak::Label l7a4;
+Xbyak::Label l7c8;
+Xbyak::Label l7e8;
+
+ preamble();
+ auto stacksize = get_size_of_abi_save_regs();
+#ifdef _WIN32
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(LDA, qword[LDA]);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ sub(A, -128);
+ sub(B, -128);
+ cmp(N, 0x8);
+ jl(l2a4, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ add(A, 0x8);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l15c, T_NEAR);
+ align(4);
+
+L(l40);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ dec(I);
+ jg(l40, T_NEAR);
+ align(4);
+
+L(l15c);
+ test(M, 0x4);
+ jle(l1f4, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l1f4);
+ test(M, 0x2);
+ jle(l248, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l248);
+ test(M, 0x1);
+ jle(l280, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ pmovsxbd(xmm5, xmm0);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm8, xmm5);
+ paddd(xmm9, xmm6);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l280);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ add(qword[ARG_BIAS], 0x20);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l2a4);
+ cmp(N, 0x4);
+ jl(l468, T_NEAR);
+ align(4);
+
+L(l2b0);
+ mov(A1, A);
+ add(A, 0x4);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l384, T_NEAR);
+ align(4);
+
+L(l2c8);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ dec(I);
+ jg(l2c8, T_NEAR);
+ align(4);
+
+L(l384);
+ test(M, 0x4);
+ jle(l3e8, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l3e8);
+ test(M, 0x2);
+ jle(l424, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l424);
+ test(M, 0x1);
+ jle(l448, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l448);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x10);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l2b0, T_NEAR);
+ align(4);
+
+L(l468);
+ cmp(N, 0x2);
+ jl(l646, T_NEAR);
+ align(4);
+
+L(l474);
+ mov(A1, A);
+ add(A, 0x2);
+ pxor(xmm7, xmm7);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l550, T_NEAR);
+ align(4);
+
+L(l48c);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm4, eax, 0x0);
+ punpcklbw(xmm1, xmm2);
+ punpcklbw(xmm3, xmm4);
+ punpcklwd(xmm1, xmm3);
+ punpcklqdq(xmm0, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(LDA3);
+ jg(l48c, T_NEAR);
+ align(4);
+
+L(l550);
+ test(M, 0x4);
+ jle(l5bc, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l5bc);
+ test(M, 0x2);
+ jle(l600, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l600);
+ test(M, 0x1);
+ jle(l628, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(word[B-0x80], ax);
+ sub(B, -2);
+ align(4);
+
+L(l628);
+ mov(A1, qword[ARG_BIAS]);
+ movq(qword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l474, T_NEAR);
+ align(4);
+
+L(l646);
+ cmp(N, 0x1);
+ jl(l7e8, T_NEAR);
+ align(4);
+
+L(l650);
+ mov(A1, A);
+ add(A, 0x1);
+ pxor(xmm7, xmm7);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l700, T_NEAR);
+ align(4);
+
+L(l668);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+	phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ dec(LDA3);
+ jg(l668, T_NEAR);
+ align(4);
+
+L(l700);
+ test(M, 0x4);
+ jle(l760, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l760);
+ test(M, 0x2);
+ jle(l7a4, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l7a4);
+ test(M, 0x1);
+ jle(l7c8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+	pinsrb(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l7c8);
+ mov(A1, qword[ARG_BIAS]);
+ movd(dword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x4);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l650, T_NEAR);
+ align(4);
+
+L(l7e8);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+#undef ARG_BIAS
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp
new file mode 100644
index 0000000000..4fc11afcbc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp
@@ -0,0 +1,116 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <cstdint>
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "../f32/ref_gemm_f32.hpp"
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <typename b_dt>
+mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co) {
+
+ if (*M == 0 || *N == 0 || *K == 0)
+ return mkldnn_success;
+
+ bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
+ bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
+ bool AisN = (*transa == 'N' || *transa == 'n');
+ bool BisN = (*transb == 'N' || *transb == 'n');
+
+ int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC;
+ size_t sizeA = AisN ? lda * k : lda * m;
+ size_t sizeB = BisN ? ldb * n : ldb * k;
+ size_t sizeC = ldc * n;
+
+ double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K);
+ double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K);
+ double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K);
+
+ if (utils::any_null(dA, dB, dC)) {
+ free(dA);
+ free(dB);
+ free(dC);
+ return mkldnn_out_of_memory;
+ }
+
+ auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; };
+ auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; };
+
+ auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; };
+ auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; };
+
+ const int a_rows = AisN ? m : k;
+ const int a_cols = AisN ? k : m;
+ mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) {
+ da_setter(i, j,
+ static_cast<double>(ia_accessor(i, j)) + static_cast<double>(ao[0]));
+ });
+
+ const int b_rows = BisN ? k : n;
+ const int b_cols = BisN ? n : k;
+ mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) {
+ db_setter(i, j,
+ static_cast<double>(ib_accessor(i, j)) + static_cast<double>(bo[0]));
+ });
+ double one = 1.0, zero = 0.0;
+ ref_gemm<double>(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero,
+ dC, LDC, nullptr);
+
+ auto i2d = [=] (int32_t v) { return static_cast<double>(v); };
+ auto f2d = [=] (float v) { return static_cast<double>(v); };
+
+ mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) {
+ double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]);
+ double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc]))
+ + f2d(*alpha) * dC[i + j * ldc] + coffset;
+ C[i + j * ldc] = math::out_round<int32_t>(math::saturate<int32_t>(val));
+ });
+
+ free(dA);
+ free(dB);
+ free(dC);
+ return mkldnn_success;
+}
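+
+/*
+ * The reference path above promotes A and B to double (folding in ao/bo),
+ * runs ref_gemm<double>, then scales, rounds and saturates into the int32 C.
+ * A minimal usage sketch (hypothetical values, illustration only):
+ *
+ *   const int m = 2, n = 2, k = 2, lda = 2, ldb = 2, ldc = 2;
+ *   const float alpha = 1.f, beta = 0.f;
+ *   const int8_t ao = 0, bo = 0;
+ *   const int32_t co = 0;
+ *   int8_t A[] = {1, 2, 3, 4};   // column-major 2x2
+ *   uint8_t B[] = {5, 6, 7, 8};  // column-major 2x2
+ *   int32_t C[4];
+ *   ref_gemm_s8x8s32<uint8_t>("N", "N", "F", &m, &n, &k, &alpha, A, &lda,
+ *           &ao, B, &ldb, &bo, &beta, C, &ldc, &co);
+ */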
+
+template mkldnn_status_t ref_gemm_s8x8s32<uint8_t>(
+ const char *transa, const char *transb, const char *offsetc,
+ const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const uint8_t *B, const int *LDB, const int8_t *bo,
+ const float *beta, int32_t *C, const int *LDC, const int32_t *co);
+
+template mkldnn_status_t ref_gemm_s8x8s32<int8_t>(
+ const char *transa, const char *transb, const char *offsetc,
+ const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const int8_t *B, const int *LDB, const int8_t *bo,
+ const float *beta, int32_t *C, const int *LDC, const int32_t *co);
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp
new file mode 100644
index 0000000000..6c0370ae99
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp
@@ -0,0 +1,38 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef REF_GEMM_S8X8S32_HPP
+#define REF_GEMM_S8X8S32_HPP
+
+#include <stdint.h>
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <typename b_dt>
+mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co);
+
+}
+}
+}
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp
new file mode 100644
index 0000000000..de1035f3b2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp
@@ -0,0 +1,180 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "common.hpp"
+#include "nstl.hpp"
+#include "math_utils.hpp"
+
+#include "../gemm.hpp"
+#include "jit_avx512_core_gemm_s8u8s32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
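+// Seeds the compensation vector from the C offset: 'F' (fixed) broadcasts
+// the single offset, 'C' (column) copies one offset per row of C, and any
+// other mode zero-initializes the vector.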
+void compensation_init(const char *offsetC, int32_t *compensation, int len,
+ const int32_t *oc) {
+ bool OCisC = (*offsetC == 'C' || *offsetC == 'c');
+ bool OCisF = (*offsetC == 'F' || *offsetC == 'f');
+
+ if (OCisF && (*oc) != 0) {
+ for (int i = 0; i < len; i++)
+ compensation[i] = *oc;
+ } else if (OCisC) {
+ for (int i = 0; i < len; i++)
+ compensation[i] = oc[i];
+ } else {
+ parallel_nd(len, [=](int i) { compensation[i] = 0; });
+ }
+}
+
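+// Accumulates -128 * alpha * (row sums of op(A)) into the compensation
+// vector (column sums of the stored A when transa is set). The
+// non-transposed path blocks over k so each panel of A stays L2-resident,
+// and uses fetch_and_add since several panels update the same compensation
+// entry concurrently.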
+void compensation_compute(bool transa, int m, int k, float alpha,
+ const int8_t *a, int lda, int32_t *compensation) {
+ if (!transa) {
+ const int L2_cache_size = get_cache_size(2, true);
+ const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1);
+ const int npanels = k / blocking_factor;
+ const bool has_tile = k % blocking_factor > 0;
+
+ parallel_nd(npanels, m, [&](int j, int i) {
+ int32_t val = 0;
+ for (int jb = 0; jb < blocking_factor; jb++) {
+ val += a[(i + (ptrdiff_t)j * blocking_factor * lda)
+ + (ptrdiff_t)jb * lda];
+ }
+ if (alpha != 1.0f) {
+ val = math::out_round<int32_t>(math::saturate<int32_t>(
+ (double)val * alpha * -128.0));
+ } else {
+ val *= -128;
+ }
+ fetch_and_add(&compensation[i], val);
+ });
+
+ if (has_tile) {
+ parallel_nd(m, [=](int i) {
+ int32_t val = 0;
+ for (int j = npanels * blocking_factor; j < k; j++) {
+ val += a[i + (ptrdiff_t)j * lda];
+ }
+ if (alpha != 1.0f) {
+ val = math::out_round<int32_t>(math::saturate<int32_t>(
+ (double)val * alpha * -128.0));
+ } else {
+ val *= -128;
+ }
+ fetch_and_add(&compensation[i], val);
+ });
+ }
+ } else {
+ parallel_nd(m, [=](int i) {
+ int32_t val = 0;
+ for (int j = 0; j < k; j++) {
+ val += a[j + (ptrdiff_t)i * lda];
+ }
+ if (alpha != 1.0f) {
+ val = math::out_round<int32_t>(math::saturate<int32_t>(
+ (double)val * alpha * -128.0));
+ } else {
+ val *= -128;
+ }
+ compensation[i] += val;
+ });
+ }
+}
+
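+// Copies B into a uint8 buffer, shifting every int8 element by +128 so the
+// values land in [0, 255] and B can be fed to gemm_s8u8s32 as its u8 operand.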
+void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8,
+ const int8_t *b_s8, int ldb_s8) {
+ const int b_cols = transb ? k : n;
+
+ parallel_nd(b_cols, [=](int j) {
+ const int b_rows = transb ? n : k;
+
+ uint8_t *pb_u8 = b_u8 + j * ldb_u8;
+ const int8_t *pb_s8 = b_s8 + j * ldb_s8;
+
+ for (int i = 0; i < b_rows; i++) {
+ (*pb_u8) = (*pb_s8) + 128;
+ pb_u8++;
+ pb_s8++;
+ }
+ });
+}
+
+/**
+ * The gemm_s8s8s32 operation is defined as follows:
+ * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation
+ *
+ * where
+ * - compensation is a vector of length m holding the computed compensation;
+ *   it also absorbs C_offset when applicable. The compensation is applied
+ *   inside gemm_s8u8s32 as a C_offset.
+ * - B_shift is a k-by-n matrix in which every element is equal to 128.
+ *
+ * What the compensation is:
+ * To prepare the matrix B for the gemm_s8u8s32 call, the B_shift is applied:
+ * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset =
+ * alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset
+ * compensation = -alpha * op(A) * B_shift
+ * Since every element of B_shift is equal to 128:
+ * - if op(A) = A: the compensation holds the sum of the elements in each
+ *   row, scaled by -128 * alpha
+ * - if op(A) = A**T: the compensation holds the sum of the elements in each
+ *   column, scaled by -128 * alpha
+ *
+ * The rest of the parameters are described in mkldnn.h
+ */
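+/*
+ * A small worked example of the identity above (illustrative values only):
+ * with alpha = 1 and a row of op(A) equal to [2, -3], shifting B adds
+ * alpha * (2 + (-3)) * 128 = -128 to every element of that row of C, and
+ * the corresponding compensation entry is -128 * (2 + (-3)) = 128, which
+ * cancels the shift exactly.
+ */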
+mkldnn_status_t simple_gemm_s8s8s32(
+ const char *transA, const char *transB, const char *offsetC,
+ const int *m, const int *n, const int *k,
+ const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
+ const int8_t *b, const int *ldb, const int8_t *ob,
+ const float *beta, int32_t *c, const int *ldc, const int32_t *oc) {
+ if (*oa != 0 || *ob != 0) return mkldnn_unimplemented;
+
+ int M = *m, N = *n, K = *k;
+ bool transa = (*transA == 'T' || *transA == 't');
+ bool transb = (*transB == 'T' || *transB == 't');
+ int ld = transb ? N : K;
+
+ uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64);
+ int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64);
+
+ if (utils::any_null(b_u8, compensation)) {
+ free(b_u8);
+ free(compensation);
+ return mkldnn_out_of_memory;
+ }
+
+ compensation_init(offsetC, compensation, M, oc);
+ compensation_compute(transa, M, K, *alpha, a, *lda, compensation);
+ copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb);
+
+ gemm_s8x8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8,
+ &ld, ob, beta, c, ldc, compensation);
+
+ if ((*offsetC == 'R' || *offsetC == 'r'))
+ parallel_nd(M, N,
+ [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; });
+
+ free(b_u8);
+ free(compensation);
+
+ return mkldnn_success;
+}
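+
+/*
+ * A minimal usage sketch (hypothetical values, illustration only; oa and ob
+ * must point to zero, otherwise mkldnn_unimplemented is returned):
+ *
+ *   const int m = 2, n = 2, k = 2, lda = 2, ldb = 2, ldc = 2;
+ *   const float alpha = 1.f, beta = 0.f;
+ *   const int8_t oa = 0, ob = 0;
+ *   const int32_t oc = 0;
+ *   int8_t A[] = {1, -2, 3, -4};  // column-major 2x2
+ *   int8_t B[] = {5, 6, -7, -8};  // column-major 2x2
+ *   int32_t C[4];
+ *   simple_gemm_s8s8s32("N", "N", "F", &m, &n, &k, &alpha, A, &lda, &oa,
+ *           B, &ldb, &ob, &beta, C, &ldc, &oc);
+ */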
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp
new file mode 100644
index 0000000000..03a3d2f7e0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp
@@ -0,0 +1,37 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SIMPLE_GEMM_S8S8S32_HPP
+#define SIMPLE_GEMM_S8S8S32_HPP
+
+#include <stdint.h>
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t simple_gemm_s8s8s32(
+ const char *transA, const char *transB, const char *offsetC,
+ const int *m, const int *n, const int *k,
+ const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
+ const int8_t *b, const int *ldb, const int8_t *ob,
+ const float *beta, int32_t *c, const int *ldc, const int32_t *oc);
+}
+}
+}
+
+#endif // SIMPLE_GEMM_S8S8S32_HPP