Diffstat (limited to 'thirdparty/oidn/mkl-dnn/src')
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp | 104
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp | 240
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp | 550
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/concat.cpp | 86
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp | 211
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/convolution.cpp | 200
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp | 56
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp | 348
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp | 188
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp | 293
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp | 84
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp | 161
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/engine.cpp | 75
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/engine.hpp | 119
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp | 106
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp | 56
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp | 321
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/lrn.cpp | 91
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp | 170
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp | 280
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory.cpp | 238
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory.hpp | 63
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp | 212
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp | 400
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp | 295
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp | 131
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp | 365
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp | 115
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp | 277
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp | 77
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/nstl.hpp | 193
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/pooling.cpp | 114
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp | 238
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive.cpp | 103
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive.hpp | 76
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp | 290
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp | 183
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp | 78
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp | 174
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp | 90
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp | 68
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp | 89
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp | 79
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/query.cpp | 59
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/reorder.cpp | 68
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp | 85
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/rnn.cpp | 400
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp | 280
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp | 112
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp | 36
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp | 72
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp | 121
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/softmax.cpp | 68
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp | 161
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/stream.cpp | 46
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/stream.hpp | 44
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/sum.cpp | 79
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp | 143
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp | 200
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp | 348
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/utils.cpp | 135
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/utils.hpp | 370
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/verbose.cpp | 665
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/verbose.hpp | 62
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp | 46
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp | 112
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp | 60
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp | 40
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp | 140
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp | 43
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp | 51
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp | 41
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp | 74
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp | 46
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp | 45
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp | 324
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp | 70
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp | 84
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp | 151
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp | 42
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp | 277
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp | 89
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp | 40
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp | 83
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp | 544
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp | 334
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp | 262
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp | 48
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp | 41
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp | 45
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp | 48
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp | 39
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp | 372
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp | 72
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp | 2131
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp | 36
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp | 2705
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp | 37
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp | 346
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp | 36
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp | 280
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp | 58
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp | 86
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp | 206
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp | 28
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp | 1409
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp | 38
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp | 539
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp | 101
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp | 290
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp | 411
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp | 64
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp | 819
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp | 2209
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp | 564
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp | 501
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp | 1283
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp | 3163
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp | 821
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp | 647
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp | 116
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp | 38
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp | 180
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp | 37
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp | 307
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp | 250
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp | 771
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp | 66
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp | 156
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp | 157
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp | 740
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp | 266
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp | 453
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp | 166
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp | 674
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp | 110
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp | 545
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp | 344
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp | 1501
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp | 225
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp | 410
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp | 302
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp | 1255
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp | 108
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp | 816
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp | 344
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp | 4539
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp | 423
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp | 1163
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp | 179
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp | 1526
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp | 302
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp | 1215
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp | 318
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp | 853
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp | 96
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp | 1103
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp | 144
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp | 1020
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp | 386
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp | 2596
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp | 291
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp | 1284
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp | 128
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp | 820
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp | 131
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp | 292
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp | 159
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp | 140
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp | 1182
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp | 239
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp | 423
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp | 115
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp | 1034
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp | 237
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp | 773
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp | 481
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp | 677
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp | 104
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp | 134
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp | 96
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp | 497
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp | 93
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp | 136
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp | 103
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp | 1192
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp | 145
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp | 327
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp | 1407
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp | 100
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp | 1302
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp | 253
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp | 427
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp | 266
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp | 1142
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp | 193
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp | 949
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp | 89
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp | 305
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp | 103
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp | 1487
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp | 183
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp | 699
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp | 192
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp | 264
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp | 182
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp | 1006
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp | 127
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp | 313
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp | 115
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp | 32
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD | 27
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md | 1
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h | 595
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h | 94
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c | 293
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h | 673
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp | 317
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp | 147
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp | 382
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp | 160
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp | 392
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp | 210
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp | 288
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp | 169
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp | 265
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp | 127
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp | 97
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp | 395
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp | 194
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp | 199
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp | 502
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp | 297
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp | 168
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp | 285
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp | 159
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp | 252
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp | 136
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp | 381
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp | 119
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp | 153
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp | 111
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp | 264
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp | 186
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp | 101
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp | 90
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp | 180
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp | 170
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp | 143
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp | 113
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp | 191
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp | 401
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp | 788
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp | 328
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp | 380
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp | 426
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp | 225
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp | 126
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp | 155
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp | 98
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp | 1022
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp | 91
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp | 74
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp | 376
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT | 47
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h | 2658
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h | 303
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h | 2017
-rw-r--r--  thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h | 772
269 files changed, 101594 insertions, 0 deletions
diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp
new file mode 100644
index 0000000000..1a51d8562b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp
@@ -0,0 +1,104 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc,
+ prop_kind_t prop_kind, const memory_desc_t *data_desc,
+ const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) {
+ bool args_ok = true
+ && !any_null(bnrm_desc, data_desc)
+ && one_of(prop_kind, forward_training, forward_inference,
+ backward_data, backward)
+ && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr);
+ if (!args_ok) return invalid_arguments;
+
+ auto bd = batch_normalization_desc_t();
+ bd.primitive_kind = primitive_kind::batch_normalization;
+ bd.prop_kind = prop_kind;
+
+ bd.data_desc = *data_desc;
+ bd.diff_data_desc = zero_md();
+ if ( one_of(bd.prop_kind,backward_data, backward) )
+ bd.diff_data_desc = *diff_data_desc;
+
+ dims_t scaleshift_dims = { 2, data_desc->dims[1] };
+ mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2,
+ scaleshift_dims, data_type::f32, mkldnn_nc);
+ bd.diff_data_scaleshift_desc = zero_md();
+ if (bd.prop_kind == backward) {
+ bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc;
+ }
+
+ dims_t stats_dims = { data_desc->dims[1] };
+ mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims,
+ data_type::f32, mkldnn_x);
+ bd.variance_desc = bd.mean_desc;
+ bd.batch_norm_epsilon = epsilon;
+
+ unsigned bnorm_flags =
+ mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;
+ if ((~bnorm_flags & flags) != 0) return invalid_arguments;
+
+ bd.flags = flags;
+
+ bool consistency = true
+ && utils::one_of(bd.data_desc.ndims, 2, 4, 5);
+ if (bd.prop_kind == backward_data)
+ consistency = consistency
+ && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5)
+ && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims,
+ bd.diff_data_desc.ndims);
+ if (!consistency) return invalid_arguments;
+
+ *bnrm_desc = bd;
+ return success;
+}
+}
+
+status_t mkldnn_batch_normalization_forward_desc_init(
+ batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
+ const memory_desc_t *data_desc, float epsilon, unsigned flags) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr,
+ epsilon, flags);
+}
+
+status_t mkldnn_batch_normalization_backward_desc_init(
+ batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
+ const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
+ float epsilon, unsigned flags) {
+ if (!one_of(prop_kind, backward, backward_data))
+ return invalid_arguments;
+ return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc,
+ epsilon, flags);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
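For orientation only (not part of the imported sources): a minimal sketch of how the C entry point defined above is typically driven. The tensor shape, epsilon, and flag choice are illustrative assumptions; error handling is reduced to a status return.

#include "mkldnn.h"

// Describe a 4-D f32 NCHW activation tensor (shape chosen arbitrarily), then
// build a forward-training batch-normalization descriptor from it.
static mkldnn_status_t example_bn_fwd_desc(
        mkldnn_batch_normalization_desc_t *bn_desc) {
    mkldnn_memory_desc_t data_md;
    mkldnn_dims_t dims = {32, 16, 7, 7}; // N, C, H, W
    mkldnn_status_t s = mkldnn_memory_desc_init_by_tag(
            &data_md, 4, dims, mkldnn_f32, mkldnn_nchw);
    if (s != mkldnn_success) return s;

    // forward_desc_init() rejects backward prop kinds and then delegates to
    // bnrm_desc_init() above, which validates the flags mask and fills in the
    // scale/shift and mean/variance descriptors.
    return mkldnn_batch_normalization_forward_desc_init(bn_desc,
            mkldnn_forward_training, &data_md, 1e-5f, mkldnn_use_scaleshift);
}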
diff --git a/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp
new file mode 100644
index 0000000000..f61410b33c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp
@@ -0,0 +1,240 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef BATCH_NORMALIZATION_PD_HPP
+#define BATCH_NORMALIZATION_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct batch_normalization_fwd_pd_t;
+
+struct batch_normalization_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::batch_normalization;
+
+ batch_normalization_pd_t(engine_t *engine,
+ const batch_normalization_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const batch_normalization_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , data_md_(desc_.data_desc)
+ , stat_md_(desc_.mean_desc)
+ , scaleshift_md_(desc_.data_scaleshift_desc)
+ , ws_md_()
+ {}
+
+ const batch_normalization_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::batch_normalization_d:
+ *(const batch_normalization_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common batch_normalization aux functions */
+
+ dim_t MB() const { return data_desc().dims[0]; }
+ dim_t C() const { return data_desc().dims[1]; }
+ dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+ dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+ dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+ int ndims() const { return desc_.data_desc.ndims; }
+
+ bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; }
+ bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; }
+ bool use_global_stats() const
+ { return desc_.flags & mkldnn_use_global_stats; }
+ bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; }
+ bool with_relu_post_op() const {
+ const auto &p = this->attr()->post_ops_;
+ return p.len_ == 1 && p.entry_[0].is_relu(true, true);
+ }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+ bool is_bwd() const { return !this->is_fwd(); }
+ bool is_training() const
+ { return desc_.prop_kind == prop_kind::forward_training; }
+
+ bool has_zero_dim_memory() const
+ { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
+
+protected:
+ batch_normalization_desc_t desc_;
+ const batch_normalization_fwd_pd_t *hint_fwd_pd_;
+
+ memory_desc_t data_md_;
+ memory_desc_t stat_md_;
+ memory_desc_t scaleshift_md_;
+
+ memory_desc_t ws_md_;
+
+ void init_default_ws(size_t bits_per_element) {
+ const auto data_mdw = memory_desc_wrapper(data_md_);
+
+ const dim_t data_nelems = data_mdw.nelems(true);
+ const dim_t bits_per_byte = 8;
+ const dims_t ws_sz = { (dim_t)utils::div_up(
+ data_nelems * bits_per_element, bits_per_byte) };
+ mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8,
+ format_tag::x);
+ }
+
+private:
+ const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t {
+ typedef batch_normalization_fwd_pd_t base_class;
+ typedef batch_normalization_fwd_pd_t hint_class;
+
+ batch_normalization_fwd_pd_t(engine_t *engine,
+ const batch_normalization_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const batch_normalization_fwd_pd_t *hint_fwd_pd)
+ : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input;
+ if (arg == MKLDNN_ARG_DST) return arg_usage_t::output;
+
+ if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) {
+ if (stats_is_src()) return arg_usage_t::input;
+ if (!stats_is_src() && is_training()) return arg_usage_t::output;
+ return arg_usage_t::unused;
+ }
+
+ if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override {
+ if (index == 0) return &data_md_;
+ if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_;
+ return nullptr;
+ }
+
+ virtual const memory_desc_t *dst_md(int index = 0) const override {
+ if (index == 0) return &data_md_;
+ if (!stats_is_src() && is_training() && (index == 1 || index == 2))
+ return &stat_md_;
+ return nullptr;
+ }
+
+ virtual const memory_desc_t *weights_md(int index = 0) const override
+ { return index == 0 ? &scaleshift_md_ : nullptr; }
+
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; }
+
+ const memory_desc_t *stat_md() const
+ { return stats_is_src() ? src_md(1) : dst_md(1); }
+
+ virtual int n_inputs() const override
+ { return 1 + 2 * stats_is_src() + use_scaleshift(); }
+ virtual int n_outputs() const override
+ { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); }
+};
+
+struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t {
+ typedef batch_normalization_bwd_pd_t base_class;
+ typedef batch_normalization_fwd_pd_t hint_class;
+
+ batch_normalization_bwd_pd_t(engine_t *engine,
+ const batch_normalization_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const batch_normalization_fwd_pd_t *hint_fwd_pd)
+ : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_data_md_(desc_.diff_data_desc)
+ , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN,
+ MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+
+ virtual const memory_desc_t *weights_md(int index = 0) const override
+ { return index == 0 ? &scaleshift_md_ : nullptr; }
+ virtual const memory_desc_t *diff_weights_md(int index = 0) const override
+ { return index == 0 ? &diff_scaleshift_md_ : nullptr; }
+
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; }
+
+ const memory_desc_t *stat_md() const { return src_md(1); }
+
+ virtual int n_inputs() const override
+ { return 4 + use_scaleshift() + fuse_bn_relu(); }
+ virtual int n_outputs() const override
+ { return 1 + (desc_.prop_kind == prop_kind::backward); }
+
+protected:
+ memory_desc_t diff_data_md_;
+ memory_desc_t diff_scaleshift_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
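One detail worth noting from init_default_ws() above: the workspace used by the fused BN+ReLU path stores bits_per_element bits for every data element and is described as a flat u8 tensor, so its size is div_up(nelems * bits_per_element, 8) bytes. Below is a self-contained sketch of that sizing rule; the example shape and the default bits_per_element = 1 are assumptions for illustration.

#include <cstddef>

// Bytes needed for a batch-normalization workspace that stores
// `bits_per_element` bits per data element, rounded up to whole bytes
// (the same arithmetic as batch_normalization_pd_t::init_default_ws()).
static inline std::size_t bn_ws_bytes(std::size_t nelems,
        std::size_t bits_per_element = 1) {
    const std::size_t bits_per_byte = 8;
    return (nelems * bits_per_element + bits_per_byte - 1) / bits_per_byte;
}

// Example: a 32x16x7x7 tensor has 25088 elements, so bn_ws_bytes(25088) == 3136.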
diff --git a/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
new file mode 100644
index 0000000000..3d43a0fbee
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp
@@ -0,0 +1,550 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef TYPE_MAPPING_HPP
+#define TYPE_MAPPING_HPP
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+
+// TODO: autogenerate this
+
+using dim_t = mkldnn_dim_t;
+using dims_t = mkldnn_dims_t;
+using stride_t = mkldnn_dim_t;
+using strides_t = mkldnn_strides_t;
+
+using status_t = mkldnn_status_t;
+namespace status {
+ const status_t success = mkldnn_success;
+ const status_t out_of_memory = mkldnn_out_of_memory;
+ const status_t try_again = mkldnn_try_again;
+ const status_t invalid_arguments = mkldnn_invalid_arguments;
+ const status_t not_ready = mkldnn_not_ready;
+ const status_t unimplemented = mkldnn_unimplemented;
+ const status_t iterator_ends = mkldnn_iterator_ends;
+ const status_t runtime_error = mkldnn_runtime_error;
+ const status_t not_required = mkldnn_not_required;
+}
+
+using prop_kind_t = mkldnn_prop_kind_t;
+namespace prop_kind {
+ const prop_kind_t undef = mkldnn_prop_kind_undef;
+ const prop_kind_t forward_training = mkldnn_forward_training;
+ const prop_kind_t forward_inference = mkldnn_forward_inference;
+ const prop_kind_t forward_scoring = mkldnn_forward_scoring;
+ const prop_kind_t forward = mkldnn_forward;
+ const prop_kind_t backward = mkldnn_backward;
+ const prop_kind_t backward_data = mkldnn_backward_data;
+ const prop_kind_t backward_weights = mkldnn_backward_weights;
+ const prop_kind_t backward_bias = mkldnn_backward_bias;
+}
+
+using alg_kind_t = mkldnn_alg_kind_t;
+namespace alg_kind {
+ const alg_kind_t undef = mkldnn_alg_kind_undef;
+ const alg_kind_t convolution_auto = mkldnn_convolution_auto;
+ const alg_kind_t convolution_direct = mkldnn_convolution_direct;
+ const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
+ const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct;
+ const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd;
+ const alg_kind_t eltwise_relu = mkldnn_eltwise_relu;
+ const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh;
+ const alg_kind_t eltwise_elu = mkldnn_eltwise_elu;
+ const alg_kind_t eltwise_square = mkldnn_eltwise_square;
+ const alg_kind_t eltwise_abs = mkldnn_eltwise_abs;
+ const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt;
+ const alg_kind_t eltwise_linear = mkldnn_eltwise_linear;
+ const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu;
+ const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu;
+ const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic;
+ const alg_kind_t pooling_max = mkldnn_pooling_max;
+ const alg_kind_t pooling_avg = mkldnn_pooling_avg;
+ const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding;
+ const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding;
+ const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels;
+ const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel;
+ const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn;
+ const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm;
+ const alg_kind_t vanilla_gru = mkldnn_vanilla_gru;
+ const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset;
+}
+
+using data_type_t = mkldnn_data_type_t;
+namespace data_type {
+ const data_type_t undef = mkldnn_data_type_undef;
+ const data_type_t f32 = mkldnn_f32;
+ const data_type_t s32 = mkldnn_s32;
+ const data_type_t s8 = mkldnn_s8;
+ const data_type_t u8 = mkldnn_u8;
+}
+
+using scratchpad_mode_t = mkldnn_scratchpad_mode_t;
+namespace scratchpad_mode {
+ const scratchpad_mode_t library = mkldnn_scratchpad_mode_library;
+ const scratchpad_mode_t user = mkldnn_scratchpad_mode_user;
+}
+
+using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t;
+namespace rnn_packed_format {
+ const rnn_packed_format_t undef = mkldnn_packed_format_undef;
+ const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p;
+ const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p;
+}
+
+using format_kind_t = mkldnn_format_kind_t;
+namespace format_kind {
+ const format_kind_t undef = mkldnn_format_kind_undef;
+ const format_kind_t any = mkldnn_format_kind_any;
+ const format_kind_t blocked = mkldnn_blocked;
+ const format_kind_t wino = mkldnn_format_kind_wino;
+ const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed;
+}
+
+using format_tag_t = mkldnn_format_tag_t;
+namespace format_tag {
+ const format_tag_t undef = mkldnn_format_tag_undef;
+ const format_tag_t any = mkldnn_format_tag_any;
+ const format_tag_t a = mkldnn_a;
+ const format_tag_t ab = mkldnn_ab;
+ const format_tag_t abc = mkldnn_abc;
+ const format_tag_t abcd = mkldnn_abcd;
+ const format_tag_t abcde = mkldnn_abcde;
+ const format_tag_t abcdef = mkldnn_abcdef;
+ const format_tag_t abdec = mkldnn_abdec;
+ const format_tag_t acb = mkldnn_acb;
+ const format_tag_t acbde = mkldnn_acbde;
+ const format_tag_t acdb = mkldnn_acdb;
+ const format_tag_t acdeb = mkldnn_acdeb;
+ const format_tag_t ba = mkldnn_ba;
+ const format_tag_t bac = mkldnn_bac;
+ const format_tag_t bacd = mkldnn_bacd;
+ const format_tag_t bcda = mkldnn_bcda;
+ const format_tag_t cba = mkldnn_cba;
+ const format_tag_t cdba = mkldnn_cdba;
+ const format_tag_t cdeba = mkldnn_cdeba;
+ const format_tag_t decab = mkldnn_decab;
+ const format_tag_t Abc16a = mkldnn_Abc16a;
+ const format_tag_t ABc16a16b = mkldnn_ABc16a16b;
+ const format_tag_t aBc16b = mkldnn_aBc16b;
+ const format_tag_t ABc16b16a = mkldnn_ABc16b16a;
+ const format_tag_t Abc4a = mkldnn_Abc4a;
+ const format_tag_t aBc4b = mkldnn_aBc4b;
+ const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b;
+ const format_tag_t ABc4b4a = mkldnn_ABc4b4a;
+ const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a;
+ const format_tag_t ABc8a8b = mkldnn_ABc8a8b;
+ const format_tag_t aBc8b = mkldnn_aBc8b;
+ const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b;
+ const format_tag_t ABc8b8a = mkldnn_ABc8b8a;
+ const format_tag_t Abcd16a = mkldnn_Abcd16a;
+ const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b;
+ const format_tag_t aBcd16b = mkldnn_aBcd16b;
+ const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a;
+ const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c;
+ const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b;
+ const format_tag_t Abcd4a = mkldnn_Abcd4a;
+ const format_tag_t aBcd4b = mkldnn_aBcd4b;
+ const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b;
+ const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a;
+ const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c;
+ const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b;
+ const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a;
+ const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b;
+ const format_tag_t aBcd8b = mkldnn_aBcd8b;
+ const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b;
+ const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b;
+ const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a;
+ const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c;
+ const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c;
+ const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b;
+ const format_tag_t Abcde16a = mkldnn_Abcde16a;
+ const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b;
+ const format_tag_t aBcde16b = mkldnn_aBcde16b;
+ const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a;
+ const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c;
+ const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b;
+ const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c;
+ const format_tag_t Abcde4a = mkldnn_Abcde4a;
+ const format_tag_t aBcde4b = mkldnn_aBcde4b;
+ const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a;
+ const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c;
+ const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c;
+ const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b;
+ const format_tag_t Abcde8a = mkldnn_Abcde8a;
+ const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b;
+ const format_tag_t aBcde8b = mkldnn_aBcde8b;
+ const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b;
+ const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b;
+ const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a;
+ const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c;
+ const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c;
+ const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b;
+ const format_tag_t aBcdef16b = mkldnn_aBcdef16b;
+ const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c;
+ const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b;
+ const format_tag_t aBcdef4b = mkldnn_aBcdef4b;
+ const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b;
+ const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c;
+ const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c;
+ const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b;
+ const format_tag_t aBdc16b = mkldnn_aBdc16b;
+ const format_tag_t aBdc4b = mkldnn_aBdc4b;
+ const format_tag_t aBdc8b = mkldnn_aBdc8b;
+ const format_tag_t aBdec16b = mkldnn_aBdec16b;
+ const format_tag_t aBdec4b = mkldnn_aBdec4b;
+ const format_tag_t aBdec8b = mkldnn_aBdec8b;
+ const format_tag_t aBdefc16b = mkldnn_aBdefc16b;
+ const format_tag_t aBdefc4b = mkldnn_aBdefc4b;
+ const format_tag_t aBdefc8b = mkldnn_aBdefc8b;
+ const format_tag_t Acb16a = mkldnn_Acb16a;
+ const format_tag_t Acb4a = mkldnn_Acb4a;
+ const format_tag_t Acb8a = mkldnn_Acb8a;
+ const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c;
+ const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c;
+ const format_tag_t Acdb16a = mkldnn_Acdb16a;
+ const format_tag_t Acdb4a = mkldnn_Acdb4a;
+ const format_tag_t Acdb8a = mkldnn_Acdb8a;
+ const format_tag_t Acdeb16a = mkldnn_Acdeb16a;
+ const format_tag_t Acdeb4a = mkldnn_Acdeb4a;
+ const format_tag_t Acdeb8a = mkldnn_Acdeb8a;
+ const format_tag_t BAc16a16b = mkldnn_BAc16a16b;
+ const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b;
+ const format_tag_t last = mkldnn_format_tag_last;
+
+ const format_tag_t x = mkldnn_x;
+ const format_tag_t nc = mkldnn_nc;
+ const format_tag_t cn = mkldnn_cn;
+ const format_tag_t ncw = mkldnn_ncw;
+ const format_tag_t nwc = mkldnn_nwc;
+ const format_tag_t nchw = mkldnn_nchw;
+ const format_tag_t nhwc = mkldnn_nhwc;
+ const format_tag_t chwn = mkldnn_chwn;
+ const format_tag_t ncdhw = mkldnn_ncdhw;
+ const format_tag_t ndhwc = mkldnn_ndhwc;
+ const format_tag_t oi = mkldnn_oi;
+ const format_tag_t io = mkldnn_io;
+ const format_tag_t oiw = mkldnn_oiw;
+ const format_tag_t wio = mkldnn_wio;
+ const format_tag_t oihw = mkldnn_oihw;
+ const format_tag_t hwio = mkldnn_hwio;
+ const format_tag_t ihwo = mkldnn_ihwo;
+ const format_tag_t iohw = mkldnn_iohw;
+ const format_tag_t oidhw = mkldnn_oidhw;
+ const format_tag_t dhwio = mkldnn_dhwio;
+ const format_tag_t goiw = mkldnn_goiw;
+ const format_tag_t goihw = mkldnn_goihw;
+ const format_tag_t hwigo = mkldnn_hwigo;
+ const format_tag_t giohw = mkldnn_giohw;
+ const format_tag_t goidhw = mkldnn_goidhw;
+ const format_tag_t tnc = mkldnn_tnc;
+ const format_tag_t ntc = mkldnn_ntc;
+ const format_tag_t ldsnc = mkldnn_ldsnc;
+ const format_tag_t ldigo = mkldnn_ldigo;
+ const format_tag_t ldgoi = mkldnn_ldgoi;
+ const format_tag_t ldgo = mkldnn_ldgo;
+ const format_tag_t nCdhw16c = mkldnn_nCdhw16c;
+ const format_tag_t nCdhw4c = mkldnn_nCdhw4c;
+ const format_tag_t nCdhw8c = mkldnn_nCdhw8c;
+ const format_tag_t nChw16c = mkldnn_nChw16c;
+ const format_tag_t nChw4c = mkldnn_nChw4c;
+ const format_tag_t nChw8c = mkldnn_nChw8c;
+ const format_tag_t nCw16c = mkldnn_nCw16c;
+ const format_tag_t nCw4c = mkldnn_nCw4c;
+ const format_tag_t nCw8c = mkldnn_nCw8c;
+ const format_tag_t IOw16o16i = mkldnn_IOw16o16i;
+ const format_tag_t OIw16i16o = mkldnn_OIw16i16o;
+ const format_tag_t OIw16o16i = mkldnn_OIw16o16i;
+ const format_tag_t Oiw16o = mkldnn_Oiw16o;
+ const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i;
+ const format_tag_t OIw4i4o = mkldnn_OIw4i4o;
+ const format_tag_t Oiw4o = mkldnn_Oiw4o;
+ const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i;
+ const format_tag_t OIw8i8o = mkldnn_OIw8i8o;
+ const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o;
+ const format_tag_t OIw8o8i = mkldnn_OIw8o8i;
+ const format_tag_t Owi16o = mkldnn_Owi16o;
+ const format_tag_t Owi4o = mkldnn_Owi4o;
+ const format_tag_t Owi8o = mkldnn_Owi8o;
+ const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i;
+ const format_tag_t Ohwi16o = mkldnn_Ohwi16o;
+ const format_tag_t Ohwi4o = mkldnn_Ohwi4o;
+ const format_tag_t Ohwi8o = mkldnn_Ohwi8o;
+ const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o;
+ const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i;
+ const format_tag_t Oihw16o = mkldnn_Oihw16o;
+ const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i;
+ const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o;
+ const format_tag_t Oihw4o = mkldnn_Oihw4o;
+ const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i;
+ const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o;
+ const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o;
+ const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i;
+ const format_tag_t Odhwi16o = mkldnn_Odhwi16o;
+ const format_tag_t Odhwi4o = mkldnn_Odhwi4o;
+ const format_tag_t Odhwi8o = mkldnn_Odhwi8o;
+ const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o;
+ const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i;
+ const format_tag_t Oidhw16o = mkldnn_Oidhw16o;
+ const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o;
+ const format_tag_t Oidhw4o = mkldnn_Oidhw4o;
+ const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i;
+ const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o;
+ const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i;
+ const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i;
+ const format_tag_t Goiw16g = mkldnn_Goiw16g;
+ const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o;
+ const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i;
+ const format_tag_t gOiw16o = mkldnn_gOiw16o;
+ const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i;
+ const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o;
+ const format_tag_t gOiw4o = mkldnn_gOiw4o;
+ const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i;
+ const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o;
+ const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o;
+ const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i;
+ const format_tag_t gOwi16o = mkldnn_gOwi16o;
+ const format_tag_t gOwi4o = mkldnn_gOwi4o;
+ const format_tag_t gOwi8o = mkldnn_gOwi8o;
+ const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i;
+ const format_tag_t gOhwi16o = mkldnn_gOhwi16o;
+ const format_tag_t gOhwi4o = mkldnn_gOhwi4o;
+ const format_tag_t gOhwi8o = mkldnn_gOhwi8o;
+ const format_tag_t Goihw16g = mkldnn_Goihw16g;
+ const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o;
+ const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i;
+ const format_tag_t gOihw16o = mkldnn_gOihw16o;
+ const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i;
+ const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i;
+ const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o;
+ const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i;
+ const format_tag_t gOihw4o = mkldnn_gOihw4o;
+ const format_tag_t Goihw8g = mkldnn_Goihw8g;
+ const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i;
+ const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o;
+ const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o;
+ const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i;
+ const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o;
+ const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o;
+ const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o;
+ const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o;
+ const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i;
+ const format_tag_t gOidhw16o = mkldnn_gOidhw16o;
+ const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o;
+ const format_tag_t gOidhw4o = mkldnn_gOidhw4o;
+ const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i;
+ const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o;
+ const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i;
+}
+
+using memory_extra_flags_t = mkldnn_memory_extra_flags_t;
+namespace memory_extra_flags {
+ const memory_extra_flags_t none = mkldnn_memory_extra_flag_none;
+ const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8;
+ const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust;
+}
+
+using padding_kind_t = mkldnn_padding_kind_t;
+namespace padding_kind {
+ const padding_kind_t padding_zero = mkldnn_padding_zero;
+}
+
+using engine_kind_t = mkldnn_engine_kind_t;
+namespace engine_kind {
+ const engine_kind_t any_engine = mkldnn_any_engine;
+ const engine_kind_t cpu = mkldnn_cpu;
+}
+
+using primitive_kind_t = mkldnn_primitive_kind_t;
+namespace primitive_kind {
+ const primitive_kind_t undefined = mkldnn_undefined_primitive;
+ const primitive_kind_t reorder = mkldnn_reorder;
+ const primitive_kind_t concat = mkldnn_concat;
+ const primitive_kind_t sum = mkldnn_sum;
+ const primitive_kind_t convolution = mkldnn_convolution;
+ const primitive_kind_t deconvolution = mkldnn_deconvolution;
+ const primitive_kind_t shuffle = mkldnn_shuffle;
+ const primitive_kind_t eltwise = mkldnn_eltwise;
+ const primitive_kind_t softmax = mkldnn_softmax;
+ const primitive_kind_t pooling = mkldnn_pooling;
+ const primitive_kind_t lrn = mkldnn_lrn;
+ const primitive_kind_t batch_normalization = mkldnn_batch_normalization;
+ const primitive_kind_t inner_product = mkldnn_inner_product;
+ const primitive_kind_t rnn = mkldnn_rnn;
+}
+
+using query_t = mkldnn_query_t;
+namespace query {
+ const query_t undef = mkldnn_query_undef;
+
+ const query_t engine = mkldnn_query_engine;
+ const query_t primitive_kind = mkldnn_query_primitive_kind;
+
+ const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32;
+ const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32;
+
+ const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64;
+ const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64;
+
+ const query_t scratchpad_engine = mkldnn_query_scratchpad_engine;
+
+ const query_t impl_info_str = mkldnn_query_impl_info_str;
+
+ const query_t some_d = mkldnn_query_some_d;
+ const query_t op_d = mkldnn_query_op_d;
+ const query_t convolution_d = mkldnn_query_convolution_d;
+ const query_t deconvolution_d = mkldnn_query_deconvolution_d;
+ const query_t shuffle_d = mkldnn_query_shuffle_d;
+ const query_t eltwise_d = mkldnn_query_eltwise_d;
+ const query_t softmax_d = mkldnn_query_softmax_d;
+ const query_t pooling_d = mkldnn_query_pooling_d;
+ const query_t lrn_d = mkldnn_query_lrn_d;
+ const query_t batch_normalization_d = mkldnn_query_batch_normalization_d;
+ const query_t inner_product_d = mkldnn_query_inner_product_d;
+ const query_t rnn_d = mkldnn_query_rnn_d;
+
+ const query_t some_md = mkldnn_query_some_md;
+ const query_t src_md = mkldnn_query_src_md;
+ const query_t diff_src_md = mkldnn_query_diff_src_md;
+ const query_t weights_md = mkldnn_query_weights_md;
+ const query_t diff_weights_md = mkldnn_query_diff_weights_md;
+ const query_t dst_md = mkldnn_query_dst_md;
+ const query_t diff_dst_md = mkldnn_query_diff_dst_md;
+
+ const query_t workspace_md = mkldnn_query_workspace_md;
+ const query_t scratchpad_md = mkldnn_query_scratchpad_md;
+}
+
+using blocking_desc_t = mkldnn_blocking_desc_t;
+using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t;
+using wino_desc_t = mkldnn_wino_desc_t;
+using memory_extra_desc_t = mkldnn_memory_extra_desc_t;
+using memory_desc_t = mkldnn_memory_desc_t;
+using convolution_desc_t = mkldnn_convolution_desc_t;
+using deconvolution_desc_t = mkldnn_deconvolution_desc_t;
+using shuffle_desc_t = mkldnn_shuffle_desc_t;
+using pooling_desc_t = mkldnn_pooling_desc_t;
+using eltwise_desc_t = mkldnn_eltwise_desc_t;
+using softmax_desc_t = mkldnn_softmax_desc_t;
+using lrn_desc_t = mkldnn_lrn_desc_t;
+using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t;
+using inner_product_desc_t = mkldnn_inner_product_desc_t;
+
+using rnn_direction_t = mkldnn_rnn_direction_t;
+using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t;
+using rnn_desc_t = mkldnn_rnn_desc_t;
+
+/* C op_desc_t, which eventually are just (void*) */
+using c_op_desc_t = mkldnn_op_desc_t;
+using const_c_op_desc_t = const_mkldnn_op_desc_t;
+
+struct op_desc_t {
+ union {
+ primitive_kind_t kind;
+ convolution_desc_t convolution;
+ deconvolution_desc_t deconvolution;
+ shuffle_desc_t shuffle;
+ pooling_desc_t pooling;
+ eltwise_desc_t eltwise;
+ softmax_desc_t softmax;
+ lrn_desc_t lrn;
+ batch_normalization_desc_t batch_normalization;
+ inner_product_desc_t inner_product;
+ rnn_desc_t rnn;
+ };
+
+ op_desc_t(const primitive_kind_t &_): kind(_) {}
+
+# define DECL_CTOR_AND_CONVERTERS(c_type, name) \
+ op_desc_t(const c_type &_): name(_) {} \
+ static op_desc_t *convert_from_c(c_type *_) \
+ { return reinterpret_cast<op_desc_t*>(_); } \
+ static const op_desc_t *convert_from_c(const c_type *_) \
+ { return reinterpret_cast<const op_desc_t*>(_); }
+
+ DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution);
+ DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle);
+ DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling);
+ DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise);
+ DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax);
+ DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn);
+ DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization);
+ DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product);
+ DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn);
+
+# undef DECL_CTOR_AND_CONVERTERS
+};
+
+using engine_t = mkldnn_engine;
+using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator;
+using primitive_desc_t = mkldnn_primitive_desc;
+using primitive_attr_t = mkldnn_primitive_attr;
+using post_ops_t = mkldnn_post_ops;
+using memory_t = mkldnn_memory;
+using primitive_t = mkldnn_primitive;
+
+using primitive_arg_index_t = int;
+
+using stream_flags_t = mkldnn_stream_flags_t;
+namespace stream_flags {
+ const stream_flags_t default_flags = mkldnn_stream_default_flags;
+}
+using stream_t = mkldnn_stream;
+
+/* forward declaration of the internal primitive_desc types */
+struct batch_normalization_bwd_pd_t;
+struct batch_normalization_fwd_pd_t;
+struct batch_normalization_pd_t;
+struct concat_pd_t;
+struct convolution_bwd_data_pd_t;
+struct convolution_bwd_weights_pd_t;
+struct convolution_fwd_pd_t;
+struct convolution_pd_t;
+struct deconvolution_bwd_data_pd_t;
+struct deconvolution_bwd_weights_pd_t;
+struct deconvolution_fwd_pd_t;
+struct deconvolution_pd_t;
+struct eltwise_bwd_pd_t;
+struct eltwise_fwd_pd_t;
+struct eltwise_pd_t;
+struct inner_product_bwd_data_pd_t;
+struct inner_product_bwd_weights_pd_t;
+struct inner_product_fwd_pd_t;
+struct inner_product_pd_t;
+struct lrn_bwd_pd_t;
+struct lrn_fwd_pd_t;
+struct lrn_pd_t;
+struct pooling_bwd_pd_t;
+struct pooling_fwd_pd_t;
+struct pooling_pd_t;
+struct reorder_pd_t;
+struct rnn_bwd_pd_t;
+struct rnn_fwd_pd_t;
+struct rnn_pd_t;
+struct shuffle_pd_t;
+struct softmax_bwd_pd_t;
+struct softmax_fwd_pd_t;
+struct softmax_pd_t;
+struct sum_pd_t;
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
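Most of this header is a mechanical aliasing of the public mkldnn_* C types and enum values into the mkldnn::impl namespace; the one non-trivial piece is the op_desc_t union with its convert_from_c() helpers. A minimal sketch of how that conversion is used once a C-level descriptor crosses into the implementation (the helper name is hypothetical; reading `kind` through the union works because every *_desc_t starts with a primitive_kind field):

#include "c_types_map.hpp"

using namespace mkldnn::impl;

// Hypothetical helper: view a C convolution descriptor through the tagged
// op_desc_t union and read back which primitive it describes.
static primitive_kind_t peek_kind(const convolution_desc_t *c_desc) {
    const op_desc_t *od = op_desc_t::convert_from_c(c_desc);
    return od->kind; // primitive_kind::convolution for a convolution descriptor
}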
diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat.cpp b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp
new file mode 100644
index 0000000000..ed4c35c6e9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/concat.cpp
@@ -0,0 +1,86 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "concat_pd.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd,
+ const memory_desc_t *dst_md, int n, int concat_dim,
+ const memory_desc_t *src_mds,
+ const primitive_attr_t *attr,
+ engine_t *engine) {
+ bool args_ok = !any_null(concat_pd, src_mds) && n > 0;
+ if (!args_ok) return invalid_arguments;
+
+ const primitive_attr_t dummy_attr;
+ if (attr == NULL)
+ attr = &dummy_attr;
+
+ const int ndims = src_mds[0].ndims;
+ const dims_t &dims = src_mds[0].dims;
+ const data_type_t dt = src_mds[0].data_type;
+
+ int concat_dim_sz = dims[concat_dim];
+ for (int i = 1; i < n; ++i) {
+ if (src_mds[i].ndims != ndims) return invalid_arguments;
+ for (int d = 0; d < ndims; ++d) {
+ if (d == concat_dim) continue;
+ if (src_mds[i].dims[d] != dims[d])
+ return invalid_arguments;
+ }
+ if (src_mds[i].data_type != dt) return invalid_arguments;
+ concat_dim_sz += src_mds[i].dims[concat_dim];
+ }
+
+ memory_desc_t dummy_dst_md;
+ if (dst_md) {
+ if (dst_md->ndims != ndims) return invalid_arguments;
+ for (int d = 0; d < ndims; ++d) {
+ if (dst_md->dims[d] !=
+ (d == concat_dim ? concat_dim_sz : dims[d]))
+ return invalid_arguments;
+ }
+ } else {
+ dummy_dst_md = src_mds[0];
+ dummy_dst_md.dims[concat_dim] = concat_dim_sz;
+ dummy_dst_md.format_kind = format_kind::any;
+ dst_md = &dummy_dst_md;
+ }
+
+ auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd);
+
+ for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
+ if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds)
+ == success) {
+ (*c_pd)->init_info();
+ (*c_pd)->init_scratchpad_md();
+ return success;
+ }
+ }
+ return unimplemented;
+}
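A minimal caller-side sketch of the checks above (hypothetical: src_mds and engine are assumed to already exist; none of this is in the patch). Passing dst_md = NULL lets the function derive the destination shape by summing the sources along concat_dim and leaving the format as format_kind::any:

    // Two 4-D f32 sources, dims {1, 16, 32, 32} and {1, 24, 32, 32}.
    primitive_desc_t *concat_pd = nullptr;
    status_t st = mkldnn_concat_primitive_desc_create(
            &concat_pd, /*dst_md=*/nullptr, /*n=*/2, /*concat_dim=*/1,
            src_mds, /*attr=*/nullptr, engine);
    // On success the derived dst has dims {1, 40, 32, 32}. Mismatched
    // ndims, data types, or non-concat dims return invalid_arguments;
    // a layout no implementation handles returns unimplemented.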
diff --git a/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
new file mode 100644
index 0000000000..29311927e2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp
@@ -0,0 +1,211 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CONCAT_PD_HPP
+#define CONCAT_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct concat_pd_t: public primitive_desc_t {
+ concat_pd_t(engine_t *engine, const primitive_attr_t *attr,
+ const memory_desc_t *dst_md, int n, int concat_dim,
+ const memory_desc_t *src_mds)
+ : primitive_desc_t(engine, attr, primitive_kind::concat)
+ , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md)
+ {
+ src_mds_.reserve(n_);
+ for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
+ }
+
+ concat_pd_t(const concat_pd_t &rhs) = default;
+
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg >= MKLDNN_ARG_MULTIPLE_SRC
+ && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index < n_inputs() ? &src_mds_[index] : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return n_; }
+ virtual int n_outputs() const override { return 1; }
+
+ int concat_dim() const { return concat_dim_; }
+
+ const memory_desc_t *src_image_md(int index = 0) const
+ { return index < n_inputs() ? &src_image_mds_[index] : nullptr; }
+
+protected:
+ int n_, concat_dim_;
+ memory_desc_t dst_md_;
+ nstl::vector<memory_desc_t> src_mds_;
+
+    /* Contains the images of the sources inside the dst memory (if possible).
+     * Lives here to simplify some implementations; an implementation may
+     * use this auxiliary array iff init() returned success. */
+ nstl::vector<memory_desc_t> src_image_mds_;
+
+protected:
+ /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */
+ status_t init() {
+ bool ok = true
+ && set_default_params() == status::success
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ for (int i = 0; i < n_; ++i) {
+ const memory_desc_wrapper i_d(&src_mds_[i]);
+ if (!i_d.is_blocking_desc() || i_d.is_additional_buffer())
+ return status::unimplemented;
+ }
+
+ const int ndims = dst_md_.ndims;
+ int current_concat_dim_offset = 0;
+ for (int i = 0; i < n_; ++i) {
+ const int dim = src_mds_[i].dims[concat_dim_];
+ dims_t dims, offsets = {};
+ utils::array_copy(dims, dst_md_.dims, ndims);
+ dims[concat_dim_] = dim;
+ offsets[concat_dim_] = current_concat_dim_offset;
+
+ memory_desc_t src_img_d;
+ status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
+ &dst_md_, dims, offsets);
+ if (status != status::success) return status;
+ src_image_mds_.push_back(src_img_d);
+ current_concat_dim_offset += dim;
+ }
+
+ return status::success;
+ }
+
+ status_t set_default_params() {
+ if (dst_md_.format_kind != format_kind::any)
+ return status::success;
+
+ const int ndims = dst_md_.ndims;
+
+        /* Deliberately simple heuristics (not the same as before):
+         *  - Pick the first non-plain (blocked) source format;
+         *  - If all formats are plain, or a blocked format cannot be created
+         *    for the output, pick the format of the first plain input;
+         *  - If that fails as well, fall back to the plain layout (abcd...).
+         */
+ status_t status = status::unimplemented;
+ for (int i = 0; i < n_; ++i) {
+ const memory_desc_wrapper src_d(src_mds_[i]);
+ if (src_d.is_blocking_desc() && !src_d.is_plain()) {
+ status = memory_desc_init_by_blocking_desc(dst_md_,
+ src_d.blocking_desc());
+ if (status == status::success) break;
+ }
+ }
+
+ if (status == status::success) {
+ /* check if we can create a sub-memory for the dst */
+ bool desired_format_ok = true;
+ int current_concat_dim_offset = 0;
+ for (int i = 0; i < n_; ++i) {
+ const int dim = src_mds_[i].dims[concat_dim_];
+ dims_t dims, offsets = {};
+ utils::array_copy(dims, dst_md_.dims, ndims);
+ dims[concat_dim_] = dim;
+ offsets[concat_dim_] = current_concat_dim_offset;
+
+ memory_desc_t src_img_d;
+ status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
+ &dst_md_, dims, offsets);
+ if (status != status::success) {
+ desired_format_ok = false;
+ break;
+ }
+ current_concat_dim_offset += dim;
+ }
+
+ if (!desired_format_ok)
+ status = status::unimplemented;
+ }
+
+ /* if no success so far, try using the format of the first plain input */
+ if (status != status::success) {
+ for (int i = 0; i < n_; ++i) {
+ const memory_desc_wrapper src_d(src_mds_[i]);
+ if (src_d.is_blocking_desc() && src_d.is_plain()) {
+ status = memory_desc_init_by_blocking_desc(dst_md_,
+ memory_desc_wrapper(src_mds_[0]).blocking_desc());
+ if (status == status::success) return status;
+ }
+ }
+ }
+
+ /* the last line of defense: use plain abcd... format */
+ if (status != status::success)
+ status = memory_desc_init_by_strides(dst_md_, nullptr);
+
+ return status;
+ }
+};
+
+#define DECLARE_CONCAT_PD_t(impl_name, ...) \
+ static status_t create(concat_pd_t **concat_pd, \
+ engine_t *engine, const primitive_attr_t *attr, \
+ const memory_desc_t *dst_md, int n, int concat_dim, \
+ const memory_desc_t *src_mds) { \
+ using namespace status; \
+ auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \
+ if (_pd == nullptr) return out_of_memory; \
+ if (_pd->init() != success) { delete _pd; return unimplemented; } \
+ return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd); \
+ } \
+ virtual status_t create_primitive(primitive_t **p) const override { \
+ double ms = get_msec(); \
+ auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
+ ms = get_msec() - ms; \
+ if (mkldnn_verbose()->level >= 2) { \
+ printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
+ fflush(0); \
+ } \
+ return ret; \
+ } \
+ virtual pd_t *clone() const override { return new pd_t(*this); } \
+ virtual const char *name() const override { return impl_name; } \
+
+#define DECLARE_CONCAT_PD_T(impl_name, ...) \
+ DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__)
+
+}
+}
+
+#endif
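To make the init() bookkeeping above concrete, a worked instance (values assumed): with a destination of dims {1, 40, 32, 32} and concat_dim_ == 1 over sources of 16 and 24 channels, the loop creates one sub-memory image per source inside the destination:

    src_image_mds_[0]: dims {1, 16, 32, 32}, offsets {0,  0, 0, 0}
    src_image_mds_[1]: dims {1, 24, 32, 32}, offsets {0, 16, 0, 0}

current_concat_dim_offset advances by each source's size along the concat dimension, so every image addresses a disjoint slice of dst and an implementation can simply copy each source into its image.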
diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
new file mode 100644
index 0000000000..0c5c02bcd1
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/convolution.cpp
@@ -0,0 +1,200 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace mkldnn {
+namespace impl {
+status_t conv_desc_init(convolution_desc_t *conv_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t dilates,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ bool args_ok = true
+ && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides,
+ padding_l)
+ && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd)
+ && one_of(padding_kind, padding_kind::padding_zero);
+ if (!args_ok) return invalid_arguments;
+
+ if (padding_r == nullptr) padding_r = padding_l;
+
+ auto cd = convolution_desc_t();
+ cd.primitive_kind = primitive_kind::convolution;
+ cd.prop_kind = prop_kind;
+ cd.alg_kind = alg_kind;
+
+ cd.diff_src_desc = cd.src_desc = zero_md();
+ cd.diff_dst_desc = cd.dst_desc = zero_md();
+ cd.diff_weights_desc = cd.weights_desc = zero_md();
+ cd.diff_bias_desc = cd.bias_desc = zero_md();
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+ const bool with_bias =
+ bias_desc && bias_desc->format_kind != format_kind::undef;
+ const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
+
+ (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc;
+ (is_fwd ? cd.dst_desc : cd.diff_dst_desc) = *dst_desc;
+ (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) =
+ *weights_desc;
+ if (with_bias)
+ (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) =
+ *bias_desc;
+
+ int sp_dims = src_desc->ndims - 2;
+ utils::array_copy(cd.strides, strides, sp_dims);
+ utils::array_copy(cd.padding[0], padding_l, sp_dims);
+ utils::array_copy(cd.padding[1], padding_r, sp_dims);
+ if (dilates)
+ utils::array_copy(cd.dilates, dilates, sp_dims);
+ else
+ utils::array_set(cd.dilates, 0, sp_dims);
+
+ cd.padding_kind = padding_kind;
+ cd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
+ weights_desc->data_type, dst_desc->data_type, prop_kind);
+
+ const int g = with_groups ? weights_desc->dims[0] : 1;
+ const int bias_dim = prop_kind == backward_data
+ ? src_desc->dims[1]
+ : dst_desc->dims[1];
+
+ bool consistency = true
+ && memory_desc_wrapper(weights_desc).nelems()
+ && src_desc->ndims == dst_desc->ndims
+ && utils::one_of(src_desc->ndims, 3, 4, 5)
+ && utils::one_of(weights_desc->ndims, src_desc->ndims,
+ src_desc->ndims + 1)
+ && (with_bias ? bias_desc->ndims == 1 : true)
+ && (with_bias ? bias_desc->dims[0] == bias_dim : true)
+ && src_desc->dims[0] == dst_desc->dims[0]
+ && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
+ && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
+ for (int i = 2; i < src_desc->ndims; ++i)
+ {
+ int src = src_desc->dims[i];
+ int ker = weights_desc->dims[with_groups + i];
+ int dil = cd.dilates[i - 2];
+ int pad_l = padding_l[i - 2];
+ int pad_r = padding_r[i - 2];
+ int str = strides[i - 2];
+ int dst = dst_desc->dims[i];
+ int ker_range = 1 + (ker - 1) * (dil + 1);
+
+ if (str < 1) return invalid_arguments;
+ consistency = consistency
+ && dil >= 0
+ && pad_l >= 0
+ && pad_r + str > 0
+ && (src - ker_range + pad_l + pad_r) / str + 1 == dst;
+ }
+ if (!consistency) return invalid_arguments;
+
+ *conv_desc = cd;
+ return success;
+}
+}
+}
+
+status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
+ weights_desc, bias_desc, dst_desc, strides, nullptr,
+ padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_convolution_forward_desc_init(
+ convolution_desc_t *conv_desc, prop_kind_t prop_kind,
+ alg_kind_t alg_kind, const memory_desc_t *src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
+ weights_desc, bias_desc, dst_desc, strides, dilates,
+ padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_convolution_backward_data_desc_init(
+ convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
+ weights_desc, nullptr, diff_dst_desc, strides, nullptr,
+ padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_convolution_backward_data_desc_init(
+ convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
+ weights_desc, nullptr, diff_dst_desc, strides, dilates,
+ padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_convolution_backward_weights_desc_init(
+ convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
+ diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
+ nullptr, padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_convolution_backward_weights_desc_init(
+ convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
+ diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
+ dilates, padding_l, padding_r, padding_kind);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
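The per-dimension consistency check in conv_desc_init is the standard convolution output-size relation, (src - ker_range + pad_l + pad_r) / str + 1 == dst with ker_range = 1 + (ker - 1) * (dil + 1). A quick worked instance (values assumed): src = 32, ker = 3, dil = 0, pad_l = pad_r = 1, str = 1 gives ker_range = 3 and (32 - 3 + 1 + 1) / 1 + 1 = 32, so a 3x3, stride-1, pad-1 convolution preserves the spatial size; any descriptor that fails the equality is rejected as invalid_arguments.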
diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp
new file mode 100644
index 0000000000..9604e0acf5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp
@@ -0,0 +1,56 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "utils.hpp"
+
+#include "convolution_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+using namespace prop_kind;
+
+memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) {
+ return desc->prop_kind == backward_data
+ ? &desc->diff_src_desc : &desc->src_desc;
+}
+
+memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) {
+ return desc->prop_kind == backward_weights
+ ? &desc->diff_weights_desc : &desc->weights_desc;
+}
+
+memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) {
+ return desc->prop_kind == backward_weights
+ ? &desc->diff_bias_desc : &desc->bias_desc;
+}
+
+memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) {
+ return utils::one_of(desc->prop_kind, forward_inference, forward_training)
+ ? &desc->dst_desc : &desc->diff_dst_desc;
+}
+
+const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); }
+const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); }
+const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); }
+const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); }
+
+}
+}
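A short usage sketch of these accessors (cd is an assumed, already-initialized convolution_desc_t): each one returns the descriptor that plays the source / weights / bias / destination role for the given propagation kind, so callers such as convolution_pd_t never branch on prop_kind themselves.

    const memory_desc_t *src = conv_prop_invariant_src_d(&cd);
    // forward_* or backward_weights -> &cd.src_desc
    // backward_data                 -> &cd.diff_src_desc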
diff --git a/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp
new file mode 100644
index 0000000000..b10c36db49
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp
@@ -0,0 +1,348 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CONVOLUTION_PD_HPP
+#define CONVOLUTION_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+status_t conv_desc_init(convolution_desc_t *conv_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t dilates,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind);
+
+memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc);
+
+struct convolution_fwd_pd_t;
+
+struct convolution_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::convolution;
+
+ convolution_pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ {}
+
+ const convolution_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case pkind_traits<base_pkind>::query_d:
+ *(const convolution_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common conv aux functions */
+
+ dim_t MB() const { return _src_md()->dims[0]; }
+
+ dim_t IC() const { return _src_md()->dims[1]; }
+ dim_t OC() const { return _dst_md()->dims[1]; }
+ dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; }
+
+ dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; }
+ dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; }
+ dim_t IW() const { return _src_md()->dims[ndims() - 1]; }
+
+ dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; }
+ dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; }
+ dim_t OW() const { return _dst_md()->dims[ndims() - 1]; }
+
+ dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; }
+ dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; }
+ dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; }
+
+ dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
+ dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
+ dim_t KSW() const { return desc_.strides[ndims() - 3]; }
+
+ dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
+ dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
+ dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
+
+ dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
+ dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
+ dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
+ dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
+ dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
+ dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
+
+ int ndims() const { return _src_md()->ndims; }
+
+ bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); }
+ bool with_groups() const { return _wei_md()->ndims == ndims() + 1; }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+ bool has_zero_dim_memory() const {
+ const auto s_d = memory_desc_wrapper(*_src_md());
+ const auto d_d = memory_desc_wrapper(*_dst_md());
+ return s_d.has_zero_dim() || d_d.has_zero_dim();
+ }
+
+protected:
+ convolution_desc_t desc_;
+ const convolution_fwd_pd_t *hint_fwd_pd_;
+
+ bool set_default_formats_common_template(
+ memory_desc_t &src_md, format_tag_t src_tag,
+ memory_desc_t &wei_md, format_tag_t wei_tag,
+ memory_desc_t &dst_md, format_tag_t dst_tag,
+ memory_desc_t &bia_md) {
+ using namespace format_tag;
+
+# define IS_OK(f) \
+ do { if ((f) != status::success) return false; } while(0)
+ if (src_md.format_kind == format_kind::any
+ && !utils::one_of(src_tag, any, undef))
+ IS_OK(memory_desc_init_by_tag(src_md, src_tag));
+ if (dst_md.format_kind == format_kind::any
+ && !utils::one_of(dst_tag, any, undef))
+ IS_OK(memory_desc_init_by_tag(dst_md, dst_tag));
+ if (wei_md.format_kind == format_kind::any
+ && !utils::one_of(wei_tag, any, undef))
+ IS_OK(memory_desc_init_by_tag(wei_md, wei_tag));
+ if (with_bias() && bia_md.format_kind == format_kind::any)
+ IS_OK(memory_desc_init_by_tag(bia_md, x));
+# undef IS_OK
+
+ return true;
+ }
+
+ bool set_default_alg_kind(alg_kind_t alg_kind) {
+ assert(utils::one_of(alg_kind, alg_kind::convolution_direct,
+ alg_kind::convolution_winograd));
+ if (desc_.alg_kind == alg_kind::convolution_auto)
+ desc_.alg_kind = alg_kind;
+ return desc_.alg_kind == alg_kind;
+ }
+
+ bool expect_data_types(data_type_t src_dt, data_type_t wei_dt,
+ data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const {
+ bool ok = true
+ && (src_dt == data_type::undef || _src_md()->data_type == src_dt)
+ && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt)
+ && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt)
+ && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt);
+ if (with_bias() && bia_dt != data_type::undef)
+ ok = ok && _bia_md()->data_type == bia_dt;
+ return ok;
+ }
+
+private:
+ const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); }
+ const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); }
+ const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); }
+ const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); }
+};
+
+struct convolution_fwd_pd_t: public convolution_pd_t {
+ typedef convolution_fwd_pd_t base_class;
+ typedef convolution_fwd_pd_t hint_class;
+
+ convolution_fwd_pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , weights_md_(desc_.weights_desc)
+ , bias_md_(desc_.bias_desc)
+ , dst_md_(desc_.dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_BIAS && with_bias())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_md_;
+ if (index == 1 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2 + with_bias(); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t dst_md_;
+
+ bool set_default_formats_common(format_tag_t src_tag,
+ format_tag_t wei_tag, format_tag_t dst_tag) {
+ return set_default_formats_common_template(src_md_, src_tag,
+ weights_md_, wei_tag, dst_md_, dst_tag, bias_md_);
+ }
+};
+
+struct convolution_bwd_data_pd_t: public convolution_pd_t {
+ typedef convolution_bwd_data_pd_t base_class;
+ typedef convolution_fwd_pd_t hint_class;
+
+ convolution_bwd_data_pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_md_(desc_.diff_src_desc)
+ , weights_md_(desc_.weights_desc)
+ , bias_md_(desc_.bias_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_md_;
+ if (index == 1 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2 + with_bias(); }
+ virtual int n_outputs() const override { return 1; }
+
+ virtual bool support_bias() const { return false; }
+
+protected:
+ memory_desc_t diff_src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t diff_dst_md_;
+
+ bool set_default_formats_common(format_tag_t diff_src_tag,
+ format_tag_t wei_tag, format_tag_t diff_dst_tag) {
+ return set_default_formats_common_template(diff_src_md_, diff_src_tag,
+ weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_);
+ }
+};
+
+struct convolution_bwd_weights_pd_t: public convolution_pd_t {
+ typedef convolution_bwd_weights_pd_t base_class;
+ typedef convolution_fwd_pd_t hint_class;
+
+ convolution_bwd_weights_pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , diff_weights_md_(desc_.diff_weights_desc)
+ , diff_bias_md_(desc_.diff_bias_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
+ if (index == 0) return &diff_weights_md_;
+ if (index == 1 && with_bias()) return &diff_bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1 + with_bias(); }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t diff_weights_md_;
+ memory_desc_t diff_bias_md_;
+ memory_desc_t diff_dst_md_;
+
+ bool set_default_formats_common(format_tag_t src_tag,
+ format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) {
+ return set_default_formats_common_template(src_md_, src_tag,
+ diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag,
+ diff_bias_md_);
+ }
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
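A sketch of how a forward implementation's init() is expected to use set_default_formats_common (the plain nchw / oihw tags are assumptions for an ungrouped 4-D case; real implementations typically pick blocked tags):

    // Inside a hypothetical convolution_fwd_pd_t subclass:
    bool ok = set_default_formats_common(
            format_tag::nchw,   // used when src_md_ is format_kind::any
            format_tag::oihw,   // used when weights_md_ is format_kind::any
            format_tag::nchw);  // used when dst_md_ is format_kind::any
    if (!ok) return status::unimplemented;
    // A bias left as `any` falls back to the 1-D tag x automatically.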
diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
new file mode 100644
index 0000000000..98063c1c37
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp
@@ -0,0 +1,188 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t deconv_desc_init(deconvolution_desc_t *deconv_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t dilates, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ bool args_ok = true
+ && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides,
+ padding_l)
+ && one_of(alg_kind, deconvolution_direct, deconvolution_winograd)
+ && one_of(padding_kind, padding_kind::padding_zero);
+ if (!args_ok)
+ return invalid_arguments;
+
+ if (padding_r == nullptr)
+ padding_r = padding_l;
+
+ auto dd = deconvolution_desc_t();
+ dd.primitive_kind = primitive_kind::deconvolution;
+ dd.prop_kind = prop_kind;
+ dd.alg_kind = alg_kind;
+
+ dd.diff_src_desc = dd.src_desc = zero_md();
+ dd.diff_dst_desc = dd.dst_desc = zero_md();
+ dd.diff_weights_desc = dd.weights_desc = zero_md();
+ dd.diff_bias_desc = dd.bias_desc = zero_md();
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+ const bool with_bias
+ = bias_desc && bias_desc->format_kind != format_kind::undef;
+ const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
+
+ (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc;
+ (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc;
+ (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc)
+ = *weights_desc;
+ if (with_bias)
+ (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc)
+ = *bias_desc;
+
+ int sp_dims = src_desc->ndims - 2;
+ utils::array_copy(dd.strides, strides, sp_dims);
+ utils::array_copy(dd.padding[0], padding_l, sp_dims);
+ utils::array_copy(dd.padding[1], padding_r, sp_dims);
+ if (dilates)
+ utils::array_copy(dd.dilates, dilates, sp_dims);
+ else
+ utils::array_set(dd.dilates, 0, sp_dims);
+
+ dd.padding_kind = padding_kind;
+ dd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
+ weights_desc->data_type, dst_desc->data_type, prop_kind);
+
+ const int g = with_groups ? weights_desc->dims[0] : 1;
+ bool consistency = true
+ && src_desc->ndims == dst_desc->ndims
+ && utils::one_of(src_desc->ndims, 3, 4, 5)
+ && utils::one_of(weights_desc->ndims, src_desc->ndims,
+ src_desc->ndims + 1)
+ && (with_bias ? bias_desc->ndims == 1 : true)
+ && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
+ && src_desc->dims[0] == dst_desc->dims[0]
+ && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
+ && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
+ for (int i = 2; i < src_desc->ndims; ++i) {
+ int src = src_desc->dims[i];
+ int ker = weights_desc->dims[with_groups + i];
+ int dil = dd.dilates[i - 2];
+ int pad = padding_l[i - 2] + padding_r[i - 2];
+ int str = strides[i - 2];
+ int dst = dst_desc->dims[i];
+ int ker_range = 1 + (ker - 1) * (dil + 1);
+
+ consistency
+ = consistency && (dst - ker_range + pad) / str + 1 == src;
+ }
+ if (!consistency)
+ return invalid_arguments;
+
+ *deconv_desc = dd;
+ return success;
+}
+}
+
+status_t mkldnn_deconvolution_forward_desc_init(
+ deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
+ alg_kind_t alg_kind, const memory_desc_t *src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_desc, const dims_t strides,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
+ weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l,
+ padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_deconvolution_forward_desc_init(
+ deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
+ alg_kind_t alg_kind, const memory_desc_t *src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
+ weights_desc, bias_desc, dst_desc, strides, dilates, padding_l,
+ padding_r, padding_kind);
+}
+
+status_t mkldnn_deconvolution_backward_data_desc_init(
+ deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
+ weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l,
+ padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_deconvolution_backward_data_desc_init(
+ deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
+            weights_desc, nullptr, diff_dst_desc, strides, dilates, padding_l,
+ padding_r, padding_kind);
+}
+
+status_t mkldnn_deconvolution_backward_weights_desc_init(
+ deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
+ const dims_t strides, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
+ diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr,
+ padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_deconvolution_backward_weights_desc_init(
+ deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
+ const dims_t strides, const dims_t dilates, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
+ diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates,
+ padding_l, padding_r, padding_kind);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
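The spatial check in deconv_desc_init is the convolution relation with the roles of src and dst swapped: (dst - ker_range + pad) / str + 1 == src. A worked instance (values assumed): dst = 32, ker = 4, dil = 0, pad_l = pad_r = 1, str = 2 gives ker_range = 4 and (32 - 4 + 2) / 2 + 1 = 16 = src, i.e. a 4x4, stride-2, pad-1 deconvolution upsamples a 16x16 map to 32x32.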
diff --git a/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp
new file mode 100644
index 0000000000..539e44bd9b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp
@@ -0,0 +1,293 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef DECONVOLUTION_PD_HPP
+#define DECONVOLUTION_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "convolution_pd.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct deconvolution_fwd_pd_t;
+
+struct deconvolution_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::deconvolution;
+
+ deconvolution_pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ {}
+
+ const deconvolution_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case pkind_traits<base_pkind>::query_d:
+ *(const deconvolution_desc_t **)result = desc();
+ break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */
+
+ dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; }
+
+ dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; }
+ dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; }
+ dim_t G() const
+ { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; }
+
+ dim_t ID() const {
+ return ndims() >= 5
+ ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t IH() const {
+ return ndims() >= 4
+ ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t IW() const {
+ return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1];
+ }
+
+ dim_t OD() const {
+ return ndims() >= 5
+ ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t OH() const {
+ return ndims() >= 4
+ ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t OW() const {
+ return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1];
+ }
+
+ dim_t KD() const {
+ const int w_ndims = ndims() + with_groups();
+ return ndims() >= 5
+ ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1;
+ }
+ dim_t KH() const {
+ const int w_ndims = ndims() + with_groups();
+ return ndims() >= 4
+ ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1;
+ }
+ dim_t KW() const {
+ const int w_ndims = ndims() + with_groups();
+ return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1];
+ }
+
+ dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
+ dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
+ dim_t KSW() const { return desc_.strides[ndims() - 3]; }
+
+ dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
+ dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
+ dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
+
+ dim_t padFront() const
+ { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
+ dim_t padBack() const
+ { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
+ dim_t padT() const
+ { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
+ dim_t padB() const
+ { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
+ dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
+ dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
+
+ bool with_bias() const {
+ return
+ !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero();
+ }
+
+ bool with_groups() const
+ { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; }
+
+ int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+ bool has_zero_dim_memory() const {
+ const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_));
+ const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_));
+ return s_d.has_zero_dim() || d_d.has_zero_dim();
+ }
+
+protected:
+ deconvolution_desc_t desc_;
+ const deconvolution_fwd_pd_t *hint_fwd_pd_;
+};
+
+struct deconvolution_fwd_pd_t: public deconvolution_pd_t {
+ typedef deconvolution_fwd_pd_t base_class;
+ typedef deconvolution_fwd_pd_t hint_class;
+
+ deconvolution_fwd_pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , weights_md_(desc_.weights_desc)
+ , bias_md_(desc_.bias_desc)
+ , dst_md_(desc_.dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_BIAS && with_bias())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_md_;
+ if (index == 1 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2 + with_bias(); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t dst_md_;
+};
+
+struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t {
+ typedef deconvolution_bwd_data_pd_t base_class;
+ typedef deconvolution_fwd_pd_t hint_class;
+
+ deconvolution_bwd_data_pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_md_(desc_.diff_src_desc)
+ , weights_md_(desc_.weights_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *weights_md(int index = 0) const override
+ { return index == 0 ? &weights_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t diff_src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t diff_dst_md_;
+};
+
+struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t {
+ typedef deconvolution_bwd_weights_pd_t base_class;
+ typedef deconvolution_fwd_pd_t hint_class;
+
+ deconvolution_bwd_weights_pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , diff_weights_md_(desc_.diff_weights_desc)
+ , diff_bias_md_(desc_.diff_bias_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
+ if (index == 0) return &diff_weights_md_;
+ if (index == 1 && with_bias()) return &diff_bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1 + with_bias(); }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t diff_weights_md_;
+ memory_desc_t diff_bias_md_;
+ memory_desc_t diff_dst_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
new file mode 100644
index 0000000000..f1708fca52
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp
@@ -0,0 +1,84 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
+ alg_kind_t alg_kind, const memory_desc_t *data_desc,
+ const memory_desc_t *diff_data_desc, float alpha, float beta) {
+ bool args_ok = true
+ && !any_null(eltwise_desc, data_desc)
+ && one_of(prop_kind, forward_training, forward_inference,
+ backward_data)
+ && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
+ eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+ eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)
+ && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
+ if (!args_ok) return invalid_arguments;
+
+ auto ed = eltwise_desc_t();
+ ed.primitive_kind = primitive_kind::eltwise;
+ ed.prop_kind = prop_kind;
+ ed.alg_kind = alg_kind;
+
+ ed.data_desc = *data_desc;
+ ed.diff_data_desc =
+ (ed.prop_kind == backward_data) ? *diff_data_desc : zero_md();
+
+ ed.alpha = alpha;
+ ed.beta = beta;
+
+ bool consistency = true
+ && IMPLICATION(ed.prop_kind == backward_data,
+ array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims,
+ ed.diff_data_desc.ndims));
+ if (!consistency) return invalid_arguments;
+
+ *eltwise_desc = ed;
+ return success;
+}
+}
+
+status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *data_desc, float alpha, float beta) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc,
+ nullptr, alpha, beta);
+}
+
+status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc,
+ alg_kind_t alg_kind, const memory_desc_t *diff_data_desc,
+ const memory_desc_t *data_desc, float alpha, float beta) {
+ return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc,
+ diff_data_desc, alpha, beta);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
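A caller-side sketch of the forward initializer in its C API spelling (data_md is an assumed, already-initialized memory descriptor; for eltwise_relu, alpha is the negative slope and beta is ignored):

    mkldnn_eltwise_desc_t ed;
    mkldnn_status_t st = mkldnn_eltwise_forward_desc_init(&ed,
            mkldnn_forward_inference, mkldnn_eltwise_relu,
            &data_md, /*alpha=*/0.1f, /*beta=*/0.f);
    // The backward variant, mkldnn_eltwise_backward_desc_init, takes the
    // diff data descriptor in addition to the forward data descriptor.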
diff --git a/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
new file mode 100644
index 0000000000..9fd260fcee
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp
@@ -0,0 +1,161 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef ELTWISE_PD_HPP
+#define ELTWISE_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct eltwise_fwd_pd_t;
+
+struct eltwise_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::eltwise;
+
+ eltwise_pd_t(mkldnn::impl::engine_t *engine,
+ const eltwise_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const eltwise_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , data_md_(desc_.data_desc)
+ {}
+
+ const eltwise_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::eltwise_d:
+ *(const eltwise_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common eltwise aux functions */
+
+ dim_t MB() const { return data_desc().dims[0]; }
+ dim_t C() const { return data_desc().dims[1]; }
+ dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+ dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+ dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+ int ndims() const { return data_desc().ndims; }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+ bool has_zero_dim_memory() const
+ { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
+
+protected:
+ eltwise_desc_t desc_;
+ const eltwise_fwd_pd_t *hint_fwd_pd_;
+
+ memory_desc_t data_md_;
+
+private:
+ const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+struct eltwise_fwd_pd_t: public eltwise_pd_t {
+ typedef eltwise_fwd_pd_t base_class;
+ typedef eltwise_fwd_pd_t hint_class;
+
+ eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine,
+ const eltwise_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const eltwise_fwd_pd_t *hint_fwd_pd)
+ : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override { return 1; }
+
+ bool is_zero_preserved() const
+ { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); }
+};
+
+struct eltwise_bwd_pd_t: public eltwise_pd_t {
+ typedef eltwise_bwd_pd_t base_class;
+ typedef eltwise_fwd_pd_t hint_class;
+
+ eltwise_bwd_pd_t(engine_t *engine,
+ const eltwise_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const eltwise_fwd_pd_t *hint_fwd_pd)
+ : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_data_md_(desc_.diff_data_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1; }
+
+ bool is_zero_preserved() const { return true; }
+
+protected:
+ memory_desc_t diff_data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.cpp b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp
new file mode 100644
index 0000000000..3b3e25456d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/engine.cpp
@@ -0,0 +1,75 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+#include "engine.hpp"
+#include "nstl.hpp"
+
+#include "c_types_map.hpp"
+#include "../cpu/cpu_engine.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+engine_factory_t *engine_factories[] = {
+ &cpu::engine_factory,
+ nullptr,
+};
+
+static inline engine_factory_t *get_engine_factory(engine_kind_t kind) {
+ for (engine_factory_t **ef = engine_factories; *ef; ef++)
+ if ((*ef)->kind() == kind)
+ return *ef;
+ return nullptr;
+}
+
+}
+}
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+size_t mkldnn_engine_get_count(engine_kind_t kind) {
+ engine_factory_t *ef = get_engine_factory(kind);
+ return ef != nullptr ? ef->count() : 0;
+}
+
+status_t mkldnn_engine_create(engine_t **engine,
+ engine_kind_t kind, size_t index) {
+ if (engine == nullptr)
+ return invalid_arguments;
+
+ engine_factory_t *ef = get_engine_factory(kind);
+ if (ef == nullptr || index >= ef->count())
+ return invalid_arguments;
+
+ return ef->engine_create(engine, index);
+}
+
+status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) {
+ if (engine == nullptr)
+ return invalid_arguments;
+ *kind = engine->kind();
+ return success;
+}
+
+status_t mkldnn_engine_destroy(engine_t *engine) {
+ /* TODO: engine->dec_ref_count(); */
+ delete engine;
+ return success;
+}
+
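+// Illustrative C-API usage sketch (not part of the library; error handling
+// omitted, and the usual public type/enum spellings mkldnn_engine_t and
+// mkldnn_cpu are assumed):
+//
+//     mkldnn_engine_t engine;
+//     if (mkldnn_engine_get_count(mkldnn_cpu) > 0)
+//         mkldnn_engine_create(&engine, mkldnn_cpu, /*index = */ 0);
+//     /* ... create and execute primitives ... */
+//     mkldnn_engine_destroy(engine);
+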
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/engine.hpp b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp
new file mode 100644
index 0000000000..8ac8a29de5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/engine.hpp
@@ -0,0 +1,119 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef ENGINE_HPP
+#define ENGINE_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive.hpp"
+#include "utils.hpp"
+
+/** \brief An abstraction of an execution unit with shared resources
+ *
+ * Responsibilities:
+ * - Provide engine specific memory allocation
+ * - Provide engine specific primitive_desc_t creators
+ */
+struct mkldnn_engine: public mkldnn::impl::c_compatible {
+ mkldnn_engine(mkldnn::impl::engine_kind_t kind)
+ : kind_(kind)
+ {}
+ virtual ~mkldnn_engine() {}
+
+ /** get kind of the current engine */
+ virtual mkldnn::impl::engine_kind_t kind() const { return kind_; }
+
+    /** creates a memory object for the given memory descriptor and handle */
+ virtual mkldnn::impl::status_t memory_create(
+ mkldnn::impl::memory_t **memory,
+ const mkldnn::impl::memory_desc_t *md,
+ void *handle) = 0;
+
+ /** implementation section (typedefs) */
+
+ // TODO: remove engine?
+ typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)(
+ mkldnn::impl::reorder_pd_t **reorder_pd,
+ mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_attr_t *attr,
+ mkldnn::impl::engine_t *src_engine,
+ const mkldnn::impl::memory_desc_t *src_md,
+ mkldnn::impl::engine_t *dst_engine,
+ const mkldnn::impl::memory_desc_t *dst_md);
+
+ typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)(
+ mkldnn::impl::concat_pd_t **concat_pd,
+ mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_attr_t *attr,
+ const mkldnn::impl::memory_desc_t *dst_md,
+ int n, int concat_dim,
+ const mkldnn::impl::memory_desc_t *src_mds);
+
+ typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)(
+ mkldnn::impl::sum_pd_t **sum_pd,
+ mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_attr_t *attr,
+ const mkldnn::impl::memory_desc_t *dst_md,
+ int n, const float *scales,
+ const mkldnn::impl::memory_desc_t *src_mds);
+
+ typedef mkldnn::impl::status_t (*primitive_desc_create_f)(
+ mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *,
+ const mkldnn::impl::primitive_attr_t *attr,
+ mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *);
+
+ /* implementation section */
+
+    /** returns the list of reorder implementations; the engine guarantees
+     * that the list is NULL-terminated */
+ virtual const reorder_primitive_desc_create_f*
+ get_reorder_implementation_list() const = 0;
+
+    /** returns the list of concat implementations; the engine guarantees
+     * that the list is NULL-terminated */
+ virtual const concat_primitive_desc_create_f*
+ get_concat_implementation_list() const = 0;
+
+    /** returns the list of sum implementations; the engine guarantees
+     * that the list is NULL-terminated */
+ virtual const sum_primitive_desc_create_f*
+ get_sum_implementation_list() const = 0;
+
+    /** returns the list of implementations; the engine guarantees that the
+     * list is NULL-terminated */
+ virtual const primitive_desc_create_f* get_implementation_list() const = 0;
+
+protected:
+ mkldnn::impl::engine_kind_t kind_;
+};
+
+namespace mkldnn {
+namespace impl {
+
+struct engine_factory_t: public c_compatible {
+ virtual size_t count() const = 0;
+ virtual engine_kind_t kind() const = 0;
+ virtual status_t engine_create(engine_t **engine, size_t index) const = 0;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
new file mode 100644
index 0000000000..5a9f58cb1e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp
@@ -0,0 +1,106 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+ const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) {
+ bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc);
+ if (!args_ok) return invalid_arguments;
+
+ auto id = inner_product_desc_t();
+ id.primitive_kind = primitive_kind::inner_product;
+ id.prop_kind = prop_kind;
+
+ id.diff_src_desc = id.src_desc = zero_md();
+ id.diff_dst_desc = id.dst_desc = zero_md();
+ id.diff_weights_desc = id.weights_desc = zero_md();
+ id.diff_bias_desc = id.bias_desc = zero_md();
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+ const bool with_bias =
+ bias_desc && bias_desc->format_kind != format_kind::undef;
+
+ (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc;
+ (is_fwd ? id.dst_desc : id.diff_dst_desc) = *dst_desc;
+ (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) =
+ *weights_desc;
+ if (with_bias)
+ (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) =
+ *bias_desc;
+
+ id.accum_data_type = types::default_accum_data_type(src_desc->data_type,
+ weights_desc->data_type, dst_desc->data_type, prop_kind);
+
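+    // Expected shapes: src is (MB, IC, [spatial...]), weights is
+    // (OC, IC, [spatial...]) with the same ndims as src, dst is (MB, OC), and
+    // the optional bias is a 1D tensor of OC elements.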
+ bool consistency = true
+ && memory_desc_wrapper(weights_desc).nelems()
+ && one_of(src_desc->ndims, 2, 3, 4, 5)
+ && dst_desc->ndims == 2
+ && weights_desc->ndims == src_desc->ndims
+ && (with_bias ? bias_desc->ndims == 1 : true)
+ && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
+ && src_desc->dims[0] == dst_desc->dims[0]
+ && array_cmp(&src_desc->dims[1], &weights_desc->dims[1],
+ src_desc->ndims - 1)
+ && dst_desc->dims[1] == weights_desc->dims[0];
+ if (!consistency) return invalid_arguments;
+
+ *ip_desc = id;
+ return success;
+}
+}
+
+status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc,
+ prop_kind_t prop_kind, const memory_desc_t *src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_desc) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc,
+ dst_desc);
+}
+
+status_t mkldnn_inner_product_backward_data_desc_init(
+ inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc,
+ const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc)
+{
+ return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc,
+ nullptr, diff_dst_desc);
+}
+
+status_t mkldnn_inner_product_backward_weights_desc_init(
+ inner_product_desc_t *ip_desc, const memory_desc_t *src_desc,
+ const memory_desc_t *diff_weights_desc,
+ const memory_desc_t *diff_bias_desc,
+ const memory_desc_t *diff_dst_desc) {
+ return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc,
+ diff_bias_desc, diff_dst_desc);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp
new file mode 100644
index 0000000000..091cf0f5d6
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp
@@ -0,0 +1,56 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "utils.hpp"
+
+#include "inner_product_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+using namespace prop_kind;
+
+memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) {
+ return desc->prop_kind == backward_data
+ ? &desc->diff_src_desc : &desc->src_desc;
+}
+
+memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) {
+ return desc->prop_kind == backward_weights
+ ? &desc->diff_weights_desc : &desc->weights_desc;
+}
+
+memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) {
+ return desc->prop_kind == backward_weights
+ ? &desc->diff_bias_desc : &desc->bias_desc;
+}
+
+memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) {
+ return utils::one_of(desc->prop_kind, forward_inference, forward_training)
+ ? &desc->dst_desc : &desc->diff_dst_desc;
+}
+
+const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); }
+const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); }
+const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); }
+const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); }
+
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp
new file mode 100644
index 0000000000..c426de632c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp
@@ -0,0 +1,321 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef INNER_PRODUCT_PD_HPP
+#define INNER_PRODUCT_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc);
+memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc);
+memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc);
+memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc);
+
+struct inner_product_fwd_pd_t;
+
+struct inner_product_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::inner_product;
+
+ inner_product_pd_t(engine_t *engine,
+ const inner_product_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const inner_product_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ {}
+
+ const inner_product_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::inner_product_d:
+ *(const inner_product_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common inner_product aux functions */
+
+ dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; }
+ dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; }
+ dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; }
+
+ dim_t ID() const {
+ return ndims() >= 5
+ ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t IH() const {
+ return ndims() >= 4
+ ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t IW() const {
+ return ndims() >= 3
+ ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1;
+ }
+
+ dim_t OD() const {
+ return ndims() >= 5
+ ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t OH() const {
+ return ndims() >= 4
+ ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t OW() const {
+ return ndims() >= 3
+ ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1;
+ }
+
+ dim_t KD() const {
+ return ndims() >= 5
+ ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1;
+ }
+ dim_t KH() const {
+ return ndims() >= 4
+ ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1;
+ }
+ dim_t KW() const {
+ return ndims() >= 3
+ ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1;
+ }
+
+ dim_t IC_total() const {
+ return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1],
+ ndims() - 1);
+ }
+
+ dim_t IC_total_padded() const {
+ auto src_d = desc()->prop_kind == prop_kind::backward_data
+ ? memory_desc_wrapper(diff_src_md())
+ : memory_desc_wrapper(src_md());
+ assert(src_d.is_blocking_desc());
+ if (!src_d.is_blocking_desc()) return -1;
+ return utils::array_product(src_d.padded_dims() + 1, ndims() - 1);
+ }
+
+ int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; }
+
+ bool with_bias() const
+ { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); }
+
+ bool has_zero_dim_memory() const {
+ const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_));
+ const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_));
+ return s_d.has_zero_dim() || d_d.has_zero_dim();
+ }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+protected:
+ inner_product_desc_t desc_;
+ const inner_product_fwd_pd_t *hint_fwd_pd_;
+
+ status_t template_set_default_params(memory_desc_t &src_md,
+ memory_desc_t &weights_md, memory_desc_t &dst_md,
+ memory_desc_t *bias_md) {
+ using namespace format_tag;
+ if (src_md.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(src_md,
+ utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw)));
+ }
+ if (dst_md.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_md, nc));
+ if (weights_md.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(weights_md,
+ utils::pick(ndims() - 2, oi, oiw, oihw, oidhw)));
+ }
+ if (bias_md && bias_md->format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(*bias_md, x));
+ return status::success;
+ }
+};
+
+struct inner_product_fwd_pd_t: public inner_product_pd_t {
+ typedef inner_product_fwd_pd_t base_class;
+ typedef inner_product_fwd_pd_t hint_class;
+
+ inner_product_fwd_pd_t(engine_t *engine,
+ const inner_product_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const inner_product_fwd_pd_t *hint_fwd_pd)
+ : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , weights_md_(desc_.weights_desc)
+ , bias_md_(desc_.bias_desc)
+ , dst_md_(desc_.dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_BIAS && with_bias())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_md_;
+ if (index == 1 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2 + with_bias(); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t dst_md_;
+
+ status_t set_default_params() {
+ return template_set_default_params(src_md_, weights_md_, dst_md_,
+ &bias_md_);
+ }
+};
+
+struct inner_product_bwd_data_pd_t: public inner_product_pd_t {
+ typedef inner_product_bwd_data_pd_t base_class;
+ typedef inner_product_fwd_pd_t hint_class;
+
+ inner_product_bwd_data_pd_t(engine_t *engine,
+ const inner_product_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const inner_product_fwd_pd_t *hint_fwd_pd)
+ : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_md_(desc_.diff_src_desc)
+ , weights_md_(desc_.weights_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *weights_md(int index = 0) const override
+ { return index == 0 ? &weights_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t diff_src_md_;
+ memory_desc_t weights_md_;
+ memory_desc_t diff_dst_md_;
+
+ status_t set_default_params() {
+ return template_set_default_params(diff_src_md_, weights_md_,
+ diff_dst_md_, nullptr);
+ }
+};
+
+struct inner_product_bwd_weights_pd_t: public inner_product_pd_t {
+ typedef inner_product_bwd_weights_pd_t base_class;
+ typedef inner_product_fwd_pd_t hint_class;
+
+ inner_product_bwd_weights_pd_t(engine_t *engine,
+ const inner_product_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const inner_product_fwd_pd_t *hint_fwd_pd)
+ : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , diff_weights_md_(desc_.diff_weights_desc)
+ , diff_bias_md_(desc_.diff_bias_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
+ if (index == 0) return &diff_weights_md_;
+ if (index == 1 && with_bias()) return &diff_bias_md_;
+ return nullptr;
+ }
+
+ virtual int n_inputs() const override { return 2; }
+ virtual int n_outputs() const override { return 1 + with_bias(); }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t diff_weights_md_;
+ memory_desc_t diff_bias_md_;
+ memory_desc_t diff_dst_md_;
+
+ status_t set_default_params() {
+ return template_set_default_params(src_md_, diff_weights_md_,
+ diff_dst_md_, &diff_bias_md_);
+ }
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
new file mode 100644
index 0000000000..fcf18b556f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/lrn.cpp
@@ -0,0 +1,91 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t lrn_desc_init(lrn_desc_t *lrn_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc,
+ dim_t local_size, float alpha, float beta, float k) {
+ bool args_ok = true
+ && !any_null(lrn_desc, data_desc)
+ && one_of(alg_kind, lrn_within_channel, lrn_across_channels)
+ && one_of(prop_kind, forward_training, forward_inference, backward_data)
+ && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
+ if (!args_ok) return invalid_arguments;
+
+ auto ld = lrn_desc_t();
+ ld.primitive_kind = primitive_kind::lrn;
+ ld.prop_kind = prop_kind;
+ ld.alg_kind = alg_kind;
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+
+ ld.data_desc = *data_desc;
+ if (!is_fwd)
+ ld.diff_data_desc = *diff_data_desc;
+ else
+ ld.diff_data_desc = zero_md();
+ ld.local_size = local_size;
+ ld.lrn_alpha = alpha;
+ ld.lrn_beta = beta;
+ ld.lrn_k = k;
+
+ bool consistency = true
+ && ld.data_desc.ndims == 4;
+ if (ld.prop_kind == backward_data)
+ consistency = consistency
+ && ld.diff_data_desc.ndims == 4
+ && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4);
+ if (!consistency) return invalid_arguments;
+
+ *lrn_desc = ld;
+ return success;
+}
+}
+
+status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *data_desc, dim_t local_size, float alpha,
+ float beta, float k) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr,
+ local_size, alpha, beta, k);
+}
+
+status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc,
+ alg_kind_t alg_kind, const memory_desc_t *data_desc,
+ const memory_desc_t *diff_data_desc, dim_t local_size, float alpha,
+ float beta, float k) {
+ return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc,
+ diff_data_desc, local_size, alpha, beta, k);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
new file mode 100644
index 0000000000..90886e9656
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp
@@ -0,0 +1,170 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef LRN_PD_HPP
+#define LRN_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct lrn_fwd_pd_t;
+
+struct lrn_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::lrn;
+
+ lrn_pd_t(engine_t *engine,
+ const lrn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const lrn_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , data_md_(desc_.data_desc)
+ , ws_md_()
+ {}
+
+ const lrn_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::lrn_d:
+ *(const lrn_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common lrn aux functions */
+
+ dim_t MB() const { return data_desc().dims[0]; }
+ dim_t C() const { return data_desc().dims[1]; }
+ dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+ dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+ dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+ int ndims() const { return data_desc().ndims; }
+
+ bool has_zero_dim_memory() const
+ { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+protected:
+ lrn_desc_t desc_;
+ const lrn_fwd_pd_t *hint_fwd_pd_;
+
+ memory_desc_t data_md_;
+ memory_desc_t ws_md_;
+
+private:
+ const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+struct lrn_fwd_pd_t: public lrn_pd_t {
+ typedef lrn_fwd_pd_t base_class;
+ typedef lrn_fwd_pd_t hint_class;
+
+ lrn_fwd_pd_t(engine_t *engine,
+ const lrn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const lrn_fwd_pd_t *hint_fwd_pd)
+ : lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override
+ { return 1 + (workspace_md() != nullptr); }
+};
+
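+// The backward primitive consumes as an input the workspace (if any) that the
+// corresponding forward primitive produced as an output.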
+struct lrn_bwd_pd_t: public lrn_pd_t {
+ typedef lrn_bwd_pd_t base_class;
+ typedef lrn_fwd_pd_t hint_class;
+
+ lrn_bwd_pd_t(engine_t *engine,
+ const lrn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const lrn_fwd_pd_t *hint_fwd_pd)
+ : lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_data_md_(desc_.diff_data_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::input;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+ virtual int n_inputs() const override
+ { return 2 + (workspace_md() != nullptr); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t diff_data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
new file mode 100644
index 0000000000..3fddc0bd45
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp
@@ -0,0 +1,280 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MATH_UTILS_HPP
+#define MATH_UTILS_HPP
+
+#include <stdint.h>
+#include <math.h>
+
+#include "utils.hpp"
+#include "nstl.hpp"
+#include "mkldnn_traits.hpp"
+
+#if defined(MKLDNN_X86_64)
+#include "immintrin.h"
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace math {
+
+/** rounds @p f to an integer according to the mxcsr register */
+inline int mxcsr_round(float f) {
+#if defined(MKLDNN_X86_64)
+ return _mm_cvtss_si32(_mm_load_ss(&f));
+#else
+    return (int)nearbyintf(f); // optimistically assumes round-to-nearest is in effect
+#endif
+}
+
+template <typename data_t, typename acc_t>
+inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
+ typename utils::remove_reference<data_t>::type>::type
+saturate(const acc_t &x) {
+ return (typename utils::remove_reference<data_t>::type)x;
+}
+
+template <typename data_t, typename acc_t>
+inline typename utils::enable_if<nstl::is_integral<data_t>::value,
+ typename utils::remove_reference<data_t>::type>::type
+saturate(const acc_t &x) {
+ acc_t v = x;
+ if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
+ v = (acc_t)nstl::numeric_limits<data_t>::lowest();
+ if (v > (acc_t)nstl::numeric_limits<data_t>::max())
+ v = (acc_t)nstl::numeric_limits<data_t>::max();
+ return (typename utils::remove_reference<data_t>::type)v;
+}
+
+template <typename data_t>
+double saturate(const double &x) {
+ double v = x;
+ if (v < (double)nstl::numeric_limits<data_t>::lowest())
+ v = (double)nstl::numeric_limits<data_t>::lowest();
+ if (v > (double)nstl::numeric_limits<data_t>::max())
+ v = (double)nstl::numeric_limits<data_t>::max();
+ return v;
+}
+
+template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
+ return x <= 127u ? x : 127;
+}
+
+template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
+ return x >= 0 ? x : 0;
+}
+
+template <typename out_t>
+typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
+out_round(float v) { return (out_t)mxcsr_round(v); }
+
+template <typename out_t>
+typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
+out_round(double v) { return (out_t)mxcsr_round((float)v); }
+
+template <typename out_t>
+typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
+out_round(float v) { return v; }
+
+inline int gcd(int a, int b) {
+ a = impl::nstl::abs(a);
+ b = impl::nstl::abs(b);
+ if (a < b) { int x = a; a = b; b = x; }
+
+ if (b == 0) return a;
+
+ int r;
+ while ((r = a % b) != 0) { a = b; b = r; }
+
+ return b;
+}
+
+template <typename T>
+inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; }
+
+/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */
+inline int ilog2q(size_t v) {
+ if (v == 0)
+ return -1;
+
+ int p = 0;
+# define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0)
+ CP(32); CP(16); CP(8); CP(4); CP(2); CP(1);
+# undef CP
+ return p;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U one_m_square(T x) {
+ return (U)(1 - x) * (1 + x);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U x_m_square(T x) {
+ return (U)(1 - x) * x;
+}
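+
+// The two helpers above compute 1 - x^2 and x - x^2 in factored form (slightly
+// more accurate for |x| close to 1); they come in handy for derivatives such
+// as tanh'(t) = 1 - tanh(t)^2 and sigmoid'(s) = sigmoid(s) * (1 - sigmoid(s)).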
+
+/* activation */
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U relu_fwd(T s, A alpha) {
+ return s > 0 ? s : (U)(s * alpha);
+}
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U relu_bwd(T dd, T s, A alpha) {
+ return s > 0 ? dd : (U)(dd * alpha);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U tanh_fwd(T s) {
+ const float e = tanhf((float) s);
+ return (U)e;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U tanh_bwd(T dd, T s) {
+ const float e = tanh_fwd<float>((float) s);
+ return (U)(dd * (1 - e) * (1 + e));
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U elu_fwd(T s, A alpha) {
+ return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
+}
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U elu_bwd(T dd, T s, A alpha) {
+ return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U square_fwd(T s) {
+ return s * s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U square_bwd(T dd, T s) {
+ return dd * 2 * s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U abs_fwd(T s) {
+ return s > 0 ? s : -s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U abs_bwd(T dd, T s) {
+ return s > 0 ? dd : s < 0 ? -dd : 0;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U sqrt_fwd(T s) {
+ return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U sqrt_bwd(T dd, T s) {
+ return s > 0
+ ? (U)(dd / (2 * ::sqrtf((float)(s))))
+ : 0;
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U linear_fwd(T s, A alpha, A beta) {
+ return (U)(alpha * s + beta);
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U linear_bwd(T dd, T s, A alpha, A beta) {
+ (void) s;
+ (void) beta;
+ return (U)(dd * alpha);
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U bounded_relu_fwd(T s, A alpha) {
+ s = s > 0 ? s : 0;
+ return s > alpha ? (U)(alpha) : s;
+}
+
+template <typename T, typename A,
+ typename U = typename utils::remove_reference<T>::type>
+inline U bounded_relu_bwd(T dd, T s, A alpha) {
+ return dd * (0 < s && s < alpha ? 1 : 0);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U soft_relu_fwd(T s) {
+ float max_logf = 8.872284e+01; //::logf(FLT_MAX)
+ return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U soft_relu_bwd(T dd, T s) {
+ return (U)(dd / (1 + ::expf((float)(-s))));
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U logistic_fwd(T s) {
+ U v = (U)(::expf((float) -s));
+ return 1 / (1 + v);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U logistic_bwd(T dd, T s) {
+ U v = logistic_fwd<T, U>(s);
+ return dd * v * (1 - v);
+}
+
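+// An eltwise op "preserves zero" when f(0) == 0, which lets implementations
+// apply it directly to blocked/padded memory without corrupting the zero
+// padding; the jit approximations of elu/tanh are excluded presumably because
+// they may not evaluate to exactly zero at zero.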
+inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
+ using namespace alg_kind;
+ using namespace utils;
+ const bool preserves_zero = true
+ && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic)
+ && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh));
+ return preserves_zero;
+}
+
+inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
+{
+ if (!bias)
+ return 0.0f;
+
+#define CASE(dt) \
+ case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
+
+ switch (data_type) {
+ CASE(data_type::s8);
+ CASE(data_type::u8);
+ CASE(data_type::s32);
+ CASE(data_type::f32);
+ default: assert(!"unimplemented");
+ }
+ return 0; // never happens (should probably be a NaN)
+#undef CASE
+}
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp
new file mode 100644
index 0000000000..cea849c96e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory.cpp
@@ -0,0 +1,238 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::data_type;
+
+namespace {
+bool memory_desc_sanity_check(int ndims, const dims_t dims,
+ data_type_t data_type, format_kind_t format_kind) {
+ if (ndims == 0) return true;
+
+ bool ok = true
+ && dims != nullptr
+ && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS
+ && one_of(data_type, f32, s32, s8, u8)
+ && format_kind != format_kind::undef;
+ if (!ok) return false;
+ for (int d = 0; d < ndims; ++d)
+ if (dims[d] < 0) return false;
+
+ return true;
+}
+
+bool memory_desc_sanity_check(const memory_desc_t *md) {
+ if (md == nullptr) return false;
+ return memory_desc_sanity_check(md->ndims, md->dims, md->data_type,
+ format_kind::any);
+}
+}
+
+status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims,
+ const dims_t dims, data_type_t data_type, format_tag_t tag) {
+ if (any_null(memory_desc)) return invalid_arguments;
+ if (ndims == 0 || tag == format_tag::undef) {
+ *memory_desc = types::zero_md();
+ return success;
+ }
+
+ format_kind_t format_kind = types::format_tag_to_kind(tag);
+
+ /* memory_desc != 0 */
+ bool args_ok = !any_null(memory_desc)
+ && memory_desc_sanity_check(ndims, dims, data_type, format_kind);
+ if (!args_ok) return invalid_arguments;
+
+ auto md = memory_desc_t();
+ md.ndims = ndims;
+ array_copy(md.dims, dims, ndims);
+ md.data_type = data_type;
+ array_copy(md.padded_dims, dims, ndims);
+ md.format_kind = format_kind;
+
+ status_t status = success;
+ if (tag == format_tag::undef) {
+ status = invalid_arguments;
+ } else if (tag == format_tag::any) {
+ // nop
+ } else if (format_kind == format_kind::blocked) {
+ status = memory_desc_wrapper::compute_blocking(md, tag);
+ } else {
+ assert(!"unreachable");
+ status = invalid_arguments;
+ }
+
+ if (status == success)
+ *memory_desc = md;
+
+ return status;
+}
+
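+// Illustrative sketch (not part of the library; the usual public C-API
+// spellings mkldnn_f32 / mkldnn_nchw are assumed): describing a dense
+// 16x3x227x227 f32 tensor in NCHW order would look like
+//
+//     mkldnn_dims_t dims = {16, 3, 227, 227};
+//     mkldnn_memory_desc_t md;
+//     mkldnn_memory_desc_init_by_tag(&md, 4, dims, mkldnn_f32, mkldnn_nchw);
+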
+status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc,
+ int ndims, const dims_t dims, data_type_t data_type,
+ const dims_t strides) {
+ if (any_null(memory_desc)) return invalid_arguments;
+ if (ndims == 0) {
+ *memory_desc = types::zero_md();
+ return success;
+ }
+
+ /* memory_desc != 0 */
+ bool args_ok = !any_null(memory_desc)
+ && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any);
+ if (!args_ok) return invalid_arguments;
+
+ auto md = memory_desc_t();
+ md.ndims = ndims;
+ array_copy(md.dims, dims, ndims);
+ md.data_type = data_type;
+ array_copy(md.padded_dims, dims, ndims);
+ md.format_kind = format_kind::blocked;
+
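+    /* when no strides are given, fall back to a dense, row-major layout: the
+     * innermost (last) dim gets stride 1 and every preceding dim gets the
+     * product of the padded sizes of the dims that follow it */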
+ dims_t default_strides = {0};
+ if (strides == nullptr) {
+ default_strides[md.ndims - 1] = 1;
+ for (int d = md.ndims - 2; d >= 0; --d)
+ default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1];
+ strides = default_strides;
+ } else {
+ /* TODO: add sanity check for the provided strides */
+ }
+
+ array_copy(md.format_desc.blocking.strides, strides, md.ndims);
+
+ *memory_desc = md;
+
+ return status::success;
+}
+
+status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md,
+ const memory_desc_t *parent_md, const dims_t dims,
+ const dims_t offsets) {
+ if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md))
+ return invalid_arguments;
+
+ const memory_desc_wrapper src_d(parent_md);
+
+ for (int d = 0; d < src_d.ndims(); ++d) {
+ if (dims[d] < 0 || offsets[d] < 0
+ || (offsets[d] + dims[d] > src_d.dims()[d]))
+ return invalid_arguments;
+ }
+
+ if (src_d.format_kind() != format_kind::blocked)
+ return unimplemented;
+
+ dims_t blocks;
+ src_d.compute_blocks(blocks);
+
+ memory_desc_t dst_d = *parent_md;
+ auto &dst_d_blk = dst_d.format_desc.blocking;
+
+ /* TODO: put this into memory_desc_wrapper */
+ for (int d = 0; d < src_d.ndims(); ++d) {
+ /* very limited functionality for now */
+ const bool ok = true
+ && offsets[d] % blocks[d] == 0 /* [r1] */
+ && src_d.padded_offsets()[d] == 0
+ && (false
+ || dims[d] % blocks[d] == 0
+ || dims[d] < blocks[d]);
+ if (!ok)
+ return unimplemented;
+
+ const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d];
+
+ dst_d.dims[d] = dims[d];
+ dst_d.padded_dims[d] = is_right_border
+ ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d];
+ dst_d.padded_offsets[d] = src_d.padded_offsets()[d];
+ dst_d.offset0 += /* [r1] */
+ offsets[d] / blocks[d] * dst_d_blk.strides[d];
+ }
+
+ *md = dst_d;
+
+ return success;
+}
+
+int mkldnn_memory_desc_equal(const memory_desc_t *lhs,
+ const memory_desc_t *rhs) {
+ if (lhs == rhs) return 1;
+ if (any_null(lhs, rhs)) return 0;
+ return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs);
+}
+
+size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) {
+ if (md == nullptr) return 0;
+ return memory_desc_wrapper(*md).size();
+}
+
+status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md,
+ engine_t *engine, void *handle) {
+ if (any_null(memory, engine)) return invalid_arguments;
+ memory_desc_t z_md = types::zero_md();
+ return engine->memory_create(memory, md ? md : &z_md, handle);
+}
+
+status_t mkldnn_memory_get_memory_desc(const memory_t *memory,
+ const memory_desc_t **md) {
+ if (any_null(memory, md)) return invalid_arguments;
+ *md = memory->md();
+ return success;
+}
+
+status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) {
+ if (any_null(memory, engine)) return invalid_arguments;
+ *engine = memory->engine();
+ return success;
+}
+
+status_t mkldnn_memory_get_data_handle(const memory_t *memory,
+ void **handle) {
+ if (any_null(handle))
+ return invalid_arguments;
+ if (memory == nullptr) {
+ *handle = nullptr;
+ return success;
+ }
+ return memory->get_data_handle(handle);
+}
+
+status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) {
+ if (any_null(memory)) return invalid_arguments;
+ return memory->set_data_handle(handle);
+}
+
+status_t mkldnn_memory_destroy(memory_t *memory) {
+ delete memory;
+ return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp
new file mode 100644
index 0000000000..03dfee01ff
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory.hpp
@@ -0,0 +1,63 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MEMORY_HPP
+#define MEMORY_HPP
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+
+struct mkldnn_memory: public mkldnn::impl::c_compatible {
+ mkldnn_memory(mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::memory_desc_t *md)
+ : engine_(engine), md_(*md) {}
+ virtual ~mkldnn_memory() {}
+
+ /** allocates/initializes memory */
+ virtual mkldnn::impl::status_t init() = 0;
+
+ /** returns memory's engine */
+ mkldnn::impl::engine_t *engine() const { return engine_; }
+ /** returns memory's description */
+ const mkldnn::impl::memory_desc_t *md() const { return &md_; }
+
+ /** returns data handle */
+ virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0;
+
+ /** sets data handle */
+ virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0;
+
+ /** zeros padding */
+ virtual mkldnn::impl::status_t zero_pad() const
+ { return mkldnn::impl::status::success; }
+
+protected:
+ mkldnn::impl::engine_t *engine_;
+ const mkldnn::impl::memory_desc_t md_;
+
+private:
+ mkldnn_memory() = delete;
+ mkldnn_memory(const mkldnn_memory &) = delete;
+ mkldnn_memory(mkldnn_memory &&) = delete;
+ mkldnn_memory &operator=(const mkldnn_memory &) = delete;
+ mkldnn_memory &operator=(mkldnn_memory &&) = delete;
+};
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp
new file mode 100644
index 0000000000..8a99be33f3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp
@@ -0,0 +1,212 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include <initializer_list>
+
+#include "c_types_map.hpp"
+#include "memory_desc_wrapper.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+status_t fill_blocked(memory_desc_t &md,
+ std::initializer_list<int> perm,
+ std::initializer_list<int> inner_blks,
+ std::initializer_list<int> inner_idxs) {
+ const bool ok = true
+ && perm.size() == (size_t)md.ndims
+ && inner_blks.size() == inner_idxs.size();
+ if (!ok) return status::invalid_arguments;
+
+ md.offset0 = 0;
+
+ blocking_desc_t &blk = md.format_desc.blocking;
+
+ dim_t block_size = 1;
+ dims_t blocks = {0};
+ utils::array_set(blocks, 1, md.ndims);
+
+ blk.inner_nblks = (int)inner_blks.size();
+
+ int iblk = 0;
+ for (const auto &b: inner_idxs)
+ blk.inner_idxs[iblk++] = b;
+
+ iblk = 0;
+ for (const auto &b: inner_blks) {
+ int dim = blk.inner_idxs[iblk];
+ block_size *= b;
+ blocks[dim] *= b;
+ blk.inner_blks[iblk++] = b;
+ }
+
+ utils::array_set(md.padded_offsets, 0, md.ndims);
+ for (int d = 0; d < md.ndims; ++d)
+ md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
+
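+    /* compute the outer-block strides: `stride` starts as the total size in
+     * units of the inner block and is divided by each dim's outer-block count
+     * while walking `perm` from the outermost to the innermost dimension, so
+     * the last entry of `perm` ends up with the smallest stride (the inner
+     * block size) */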
+ dim_t stride = block_size;
+    // if only we could use C++14, the initializer_list would have rbegin()/rend()...
+ for (int d = 0; d < md.ndims; ++d)
+ stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d];
+
+ for (const auto &d: perm) {
+ if (md.padded_dims[d] == 0) {
+ blk.strides[d] = 1;
+ continue;
+ }
+ stride /= md.padded_dims[d] / blocks[d];
+ blk.strides[d] = stride;
+ }
+
+ assert(stride == block_size);
+
+ return status::success;
+}
+
+status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc,
+ format_tag_t tag)
+{
+ using namespace format_tag;
+
+ if (memory_desc.ndims == 0) return status::invalid_arguments;
+
+# define C(tag, ... /* perm, inner_blks, inner_idxs */) \
+ case tag: return fill_blocked(memory_desc, __VA_ARGS__)
+
+ switch (tag) {
+ C(a, {0}, {}, {});
+ C(ab, {0, 1}, {}, {});
+ C(abc, {0, 1, 2}, {}, {});
+ C(abcd, {0, 1, 2, 3}, {}, {});
+ C(abcde, {0, 1, 2, 3, 4}, {}, {});
+ C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {});
+ C(abdec, {0, 1, 3, 4, 2}, {}, {});
+ C(acb, {0, 2, 1}, {}, {});
+ C(acbde, {0, 2, 1, 3, 4}, {}, {});
+ C(acdb, {0, 2, 3, 1}, {}, {});
+ C(acdeb, {0, 2, 3, 4, 1}, {}, {});
+ C(ba, {1, 0}, {}, {});
+ C(bac, {1, 0, 2}, {}, {});
+ C(bacd, {1, 0, 2, 3}, {}, {});
+ C(bcda, {1, 2, 3, 0}, {}, {});
+ C(cba, {2, 1, 0}, {}, {});
+ C(cdba, {2, 3, 1, 0}, {}, {});
+ C(cdeba, {2, 3, 4, 1, 0}, {}, {});
+ C(decab, {3, 4, 2, 0, 1}, {}, {});
+
+ C(Abc4a, {0, 1, 2}, {4}, {0});
+ C(aBc4b, {0, 1, 2}, {4}, {1});
+ C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1});
+ C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0});
+ C(Abcd4a, {0, 1, 2, 3}, {4}, {0});
+ C(aBcd4b, {0, 1, 2, 3}, {4}, {1});
+ C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0});
+ C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2});
+    C(aBCd4c4b, {0, 1, 2, 3}, {4, 4}, {2, 1});
+ C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0});
+ C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1});
+ C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0});
+ C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
+ C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1});
+ C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1});
+ C(aBdc4b, {0, 1, 3, 2}, {4}, {1});
+ C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1});
+ C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1});
+ C(Acb4a, {0, 2, 1}, {4}, {0});
+ C(Acdb4a, {0, 2, 3, 1}, {4}, {0});
+ C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0});
+
+ C(Abc16a, {0, 1, 2}, {16}, {0});
+ C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1});
+ C(aBc16b, {0, 1, 2}, {16}, {1});
+ C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0});
+ C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0});
+ C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1});
+ C(aBc8b, {0, 1, 2}, {8}, {1});
+ C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1});
+ C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0});
+ C(Abcd16a, {0, 1, 2, 3}, {16}, {0});
+ C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1});
+ C(aBcd16b, {0, 1, 2, 3}, {16}, {1});
+ C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0});
+ C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2});
+ C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1});
+ C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1});
+ C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0});
+ C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1});
+ C(aBcd8b, {0, 1, 2, 3}, {8}, {1});
+ C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1});
+ C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1});
+ C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0});
+ C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2});
+ C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2});
+ C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1});
+ C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0});
+ C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1});
+ C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1});
+ C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0});
+ C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2});
+ C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1});
+ C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2});
+ C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2});
+ C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2});
+ C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0});
+ C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1});
+ C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1});
+ C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1});
+ C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1});
+ C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0});
+ C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2});
+ C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2});
+ C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1});
+ C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1});
+ C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2});
+ C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1});
+ C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2});
+ C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2});
+ C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1});
+ C(aBdc16b, {0, 1, 3, 2}, {16}, {1});
+ C(aBdc8b, {0, 1, 3, 2}, {8}, {1});
+ C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1});
+ C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1});
+ C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1});
+ C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1});
+ C(Acb16a, {0, 2, 1}, {16}, {0});
+ C(Acb8a, {0, 2, 1}, {8}, {0});
+ C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2});
+ C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2});
+ C(Acdb16a, {0, 2, 3, 1}, {16}, {0});
+ C(Acdb8a, {0, 2, 3, 1}, {8}, {0});
+ C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0});
+ C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0});
+ C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1});
+ C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1});
+ default: break;
+ }
+
+#undef C
+
+ return status::invalid_arguments;
+}
+
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp
new file mode 100644
index 0000000000..1758f9078a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp
@@ -0,0 +1,400 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MEMORY_DESC_WRAPPER_HPP
+#define MEMORY_DESC_WRAPPER_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/** a thin wrapper class over \struct memory_desc_t that allows easy
+ * manipulation of the underlying C structure, which is taken by reference */
+struct memory_desc_wrapper: public c_compatible {
+ const memory_desc_t *md_;
+
+ /** constructor which takes a reference to a constant underlying C memory
+ * descriptor \param md */
+ memory_desc_wrapper(const memory_desc_t *md): md_(md) {}
+ memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {}
+
+ /* implementing attributes */
+ int ndims() const { return md_->ndims; }
+ const dims_t &dims() const { return md_->dims; }
+ data_type_t data_type() const { return md_->data_type; }
+
+ const dims_t &padded_dims() const { return md_->padded_dims; }
+ const dims_t &padded_offsets() const { return md_->padded_offsets; }
+ dim_t offset0() const { return md_->offset0; }
+
+ format_kind_t format_kind() const { return md_->format_kind; }
+
+ bool is_blocking_desc() const
+ { return format_kind() == format_kind::blocked; }
+ bool is_wino_desc() const
+ { return format_kind() == format_kind::wino; }
+ bool is_rnn_packed_desc() const
+ { return format_kind() == format_kind::rnn_packed; }
+
+ const blocking_desc_t &blocking_desc() const {
+ assert(is_blocking_desc());
+ return md_->format_desc.blocking;
+ }
+ const wino_desc_t &wino_desc() const {
+ assert(is_wino_desc());
+ return md_->format_desc.wino_desc;
+ }
+ const rnn_packed_desc_t &rnn_packed_desc() const {
+ assert(is_rnn_packed_desc());
+ return md_->format_desc.rnn_packed_desc;
+ }
+
+ const memory_extra_desc_t &extra() const { return md_->extra; }
+
+    /* some useful functions */
+
+ /** returns the number of elements including padding if \param with_padding
+ * is true, and the number of data elements otherwise */
+ dim_t nelems(bool with_padding = false) const {
+ if (is_zero()) return 0;
+ return utils::array_product(
+ with_padding ? padded_dims() : dims(), ndims());
+ }
+
+ /** returns true if memory descriptor is zero */
+ bool is_zero() const { return ndims() == 0; }
+
+    /** returns true if memory descriptor contains zero as one of its dims */
+ bool has_zero_dim() const { return nelems() == 0; }
+
+ /** return the size of data type (a shortcut) */
+ size_t data_type_size() const
+ { return types::data_type_size(data_type()); }
+
+ /** return the size of data type of additional buffer */
+ size_t additional_buffer_data_size() const {
+ if (extra().flags & memory_extra_flags::compensation_conv_s8s8)
+ return sizeof(int32_t);
+ return 0;
+ }
+
+ /** return true if memory format has additional buffer */
+ bool is_additional_buffer() const {
+ return (extra().flags & memory_extra_flags::compensation_conv_s8s8);
+ }
+
+ /** returns the size of additional buffer */
+ size_t additional_buffer_size() const {
+ if (extra().flags & memory_extra_flags::compensation_conv_s8s8) {
+ int cmask = extra().compensation_mask;
+ assert(cmask == 1 || cmask == 3);
+ dim_t prod = 1;
+ for (int d = 0; d < ndims(); ++d)
+ if (cmask & (1<<d)) prod *= padded_dims()[d];
+ return prod * additional_buffer_data_size();
+ }
+
+ return 0;
+ }
+
+ /** returns the size required to store described memory
+ * note: if offset0 != 0 returns 0 (need to specify the behavior) */
+ size_t size() const {
+ if (is_zero() || has_zero_dim() || format_kind() == format_kind::any)
+ return 0;
+
+ if (format_kind() == format_kind::wino) {
+ return wino_desc().size;
+ } else if (format_kind() == format_kind::rnn_packed) {
+ return rnn_packed_desc().size;
+ } else {
+ if (offset0() != 0) return 0;
+
+ dims_t blocks = {0};
+ compute_blocks(blocks);
+
+ const auto &bd = blocking_desc();
+
+ size_t max_size = 0;
+ for (int d = 0; d < ndims(); ++d)
+ max_size = nstl::max<size_t>(max_size,
+ padded_dims()[d] / blocks[d] * bd.strides[d]);
+
+ if (max_size == 1 && bd.inner_nblks != 0) {
+ max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
+ }
+
+ return max_size * data_type_size() + additional_buffer_size();
+ }
+ }
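+    /* a hand-worked example for size() above (illustration only): for an f32
+     * descriptor with dims = {2, 17, 4, 4} laid out as aBcd16b
+     * (padded_dims = {2, 32, 4, 4}, strides = {512, 256, 64, 16}, one inner
+     * block of 16 over dim 1), max_size = max(2*512, 2*256, 4*64, 4*16) = 1024
+     * elements, hence size() = 1024 * sizeof(float) = 4096 bytes, assuming no
+     * additional compensation buffer. */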
+
+ /** returns true if data is dense in memory */
+ bool is_dense(bool with_padding = false) const {
+ if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
+ return false;
+ return nelems(with_padding) * data_type_size() == size();
+ }
+
+ /** returns true if memory desc is fully defined */
+ bool is_defined() const { return format_kind() != format_kind::any; }
+
+ /** returns true if the only (potentially) padded dim is \param dim */
+ bool only_padded_dim(int dim) const {
+ for (int d = 0; d < ndims(); ++d)
+ if (d != dim && dims()[d] != padded_dims()[d])
+ return false;
+ return true;
+ }
+
+ /** returns true if memory desc has blocked layout and block dims are 1s */
+ bool is_plain() const {
+ if (!is_blocking_desc()) return false;
+ return blocking_desc().inner_nblks == 0;
+ }
+
+ /** returns overall block sizes */
+ void compute_blocks(dims_t blocks) const {
+ if (!is_blocking_desc()) {
+ utils::array_set(blocks, 0, ndims());
+ return;
+ }
+
+ utils::array_set(blocks, 1, ndims());
+
+ const auto &bd = blocking_desc();
+ for (int iblk = 0; iblk < bd.inner_nblks; ++iblk)
+ blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk];
+ }
+
+ /* comparison section */
+
+ bool operator==(const memory_desc_wrapper &rhs) const
+ { return *this->md_ == *rhs.md_; }
+ bool operator!=(const memory_desc_wrapper &rhs) const
+ { return !operator==(rhs); }
+ bool operator==(const memory_desc_t &rhs) const
+ { return operator==(memory_desc_wrapper(rhs)); }
+ bool operator!=(const memory_desc_t &rhs) const
+ { return !operator==(rhs); }
+
+    /** returns true if the data (w/o padding if with_padding == false and w/
+     * padding otherwise) has the same physical structure, i.e. dimensions,
+     * strides, and blocked structure. Depending on the with_data_type flag,
+     * data_type is or is not taken into account. dim_start allows checking
+     * similarity for the logical part of the data [dim_start .. ndims()].
+     * CAUTION: format kinds any and undef are not similar to anything, hence
+     * the following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
+ /* TODO: revise */
+ bool similar_to(const memory_desc_wrapper &rhs,
+ bool with_padding = true, bool with_data_type = true,
+ int dim_start = 0) const;
+
+ /** returns true if one memory can be reordered to another */
+ bool consistent_with(const memory_desc_wrapper &rhs) const;
+
+ /** returns true if the memory desc corresponds to the given format tag and
+ * strides.
+ * @sa memory_desc_matches_tag */
+ bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
+ return memory_desc_matches_tag(*md_, tag, strides);
+ }
+
+ /** returns matching tag (or undef if match is not found)
+ * XXX: This is a workaround that eventually should go away! */
+ template <typename... Tags>
+ format_tag_t matches_one_of_tag(Tags ...tags) const {
+ for (const auto tag: {tags...}) {
+ if (memory_desc_matches_tag(*md_, tag))
+ return tag;
+ }
+ return format_tag::undef;
+ }
+
+ /* offset section */
+
+    /** returns the physical offset for a logical one. The logical offset is
+     * represented by an array \param pos. If \param is_pos_padded is true,
+     * \param pos represents the position in the already padded area */
+ dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
+ assert(is_blocking_desc());
+ const blocking_desc_t &blk = blocking_desc();
+
+ dims_t pos_copy = {0};
+ for (int d = 0; d < ndims(); ++d)
+ pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]);
+
+ dim_t phys_offset = offset0();
+
+ if (blk.inner_nblks > 0) {
+ dim_t blk_stride = 1;
+ for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) {
+ const int d = blk.inner_idxs[iblk];
+ const dim_t p = pos_copy[d] % blk.inner_blks[iblk];
+
+ phys_offset += p * blk_stride;
+
+ pos_copy[d] /= blk.inner_blks[iblk];
+
+ blk_stride *= blk.inner_blks[iblk];
+ }
+ }
+
+ for (int d = 0; d < ndims(); ++d) {
+ const dim_t p = pos_copy[d];
+ phys_offset += p * blk.strides[d];
+ }
+
+ return phys_offset;
+ }
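+    /* e.g. (a hand-worked sketch) for a 4D aBcd16b memory with strides
+     * {512, 256, 64, 16} and a single inner block of 16 over dim 1,
+     * off_v({1, 5, 2, 3}) = 1*512 + 0*256 + 2*64 + 3*16 + (5 % 16) = 693,
+     * assuming zero padded_offsets and offset0. */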
+
+ /** returns physical offset by logical one. logical offset is represented by
+ * a scalar \param l_offset. if \param is_pos_padded is true, \param
+ * l_offset represents logical offset in already padded area */
+ dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const {
+ assert(is_blocking_desc());
+ dims_t pos;
+ for (int rd = 0; rd < ndims(); ++rd) {
+ const int d = ndims() - 1 - rd;
+ const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d];
+ pos[d] = l_offset % cur_dim;
+ l_offset /= cur_dim;
+ }
+ return off_v(pos, is_pos_padded);
+ }
+
+ /** returns physical offset by logical one. logical offset is represented by
+ * a tuple of indices (\param xn, ..., \param x1, \param x0) */
+ template<typename... Args>
+ dim_t off(Args... args) const {
+ assert(sizeof...(args) == ndims());
+ dims_t pos = { args... };
+ return off_v(pos, false);
+ }
+
+ /** returns physical offset by logical one. logical offset is represented by
+ * a tuple of indices (\param xn, ..., \param x1, \param x0) in already
+ * padded area */
+ template<typename... Args>
+ dim_t off_padding(Args... args) const {
+ assert(sizeof...(args) == ndims());
+ dims_t pos = { args... };
+ return off_v(pos, true);
+ }
+
+ /** returns physical offset by logical one. Logical offset is represented by
+     * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is
+     * the user's responsibility to adjust the result to get the offset within
+     * blocks */
+ template<typename ...Args>
+ dim_t blk_off(Args... args) const {
+ return _blk_off<sizeof...(args), Args...>(args...);
+ }
+
+ template<bool skip_first, typename T, typename ...Args>
+ dim_t blk_off(T xn, Args... args) const {
+ return skip_first
+ ? blk_off<Args...>(args...)
+ : blk_off<T, Args...>(xn, args...);
+ }
+
+ /* static functions section */
+ /* TODO: replace with non-static, once md_ becomes non-const ref */
+
+ static status_t compute_blocking(memory_desc_t &memory_desc,
+ format_tag_t tag);
+
+private:
+ /* TODO: put logical_offset in utils */
+ template<typename T>
+ dim_t logical_offset(T x0) const { return x0; }
+
+ template<typename T, typename... Args>
+ dim_t logical_offset(T xn, Args... args) const {
+ const size_t n_args = sizeof...(args);
+ return xn * utils::array_product<n_args>(
+ &dims()[ndims() - n_args]) + logical_offset(args...);
+ }
+
+ template<int ORIG_LEN, typename ...Void>
+ dim_t _blk_off() const { return offset0(); }
+
+ template<int ORIG_LEN, typename T, typename ...Args>
+ dim_t _blk_off(T xc, Args ...args) const {
+ assert(is_blocking_desc());
+ constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
+ return xc * blocking_desc().strides[dc]
+ + _blk_off<ORIG_LEN, Args...>(args...);
+ }
+};
+
+inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
+ bool with_padding, bool with_data_type, int dim_start) const {
+ using namespace utils;
+
+ if (one_of(format_kind(), format_kind::undef, format_kind::any))
+ return false;
+ if (is_wino_desc() || is_rnn_packed_desc())
+ return false;
+
+ const int ds = dim_start;
+ const auto &blk = blocking_desc();
+ const auto &r_blk = rhs.blocking_desc();
+
+ return ndims() == rhs.ndims()
+ && dim_start <= ndims() /* guard */
+ && format_kind() == rhs.format_kind()
+ && IMPLICATION(with_data_type, data_type() == rhs.data_type())
+ && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
+ && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
+ && blk.inner_nblks == r_blk.inner_nblks
+ && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
+ && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
+ && IMPLICATION(with_padding, true
+ && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds,
+ ndims() - ds)
+ && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds,
+ ndims() - ds));
+}
+
+inline bool memory_desc_wrapper::consistent_with(
+ const memory_desc_wrapper &rhs) const {
+ if (ndims() == rhs.ndims()) {
+ for (int d = 0; d < ndims(); ++d) {
+ if (dims()[d] != rhs.dims()[d]) return false;
+ }
+ return true;
+ } else {
+ /* TODO: revise.
+ * is the following possible?
+ * [1, a, b] <--reorder--> [a, b]
+ * [a, 1, b] <--reorder--> [a, b]
+ * not, at least for now */
+ return false;
+ }
+}
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp
new file mode 100644
index 0000000000..ec077b308c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp
@@ -0,0 +1,295 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MEMORY_TRACKING_HPP
+#define MEMORY_TRACKING_HPP
+
+#include <assert.h>
+#include <unordered_map>
+
+#include "nstl.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace memory_tracking {
+
+/* Memory tracking capabilities
+ *
+ * The main purpose of this header file is to provide a uniform way to register
+ * required memory for a scratchpad at a primitive descriptor creation time
+ * and then easily access it having only the base address of the scratchpad.
+ *
+ * Primitives might contain multiple disjoint parts that require temporary
+ * buffers (known as scratchpad) during their execution. A primitive descriptor
+ * should summarize all these needs into a single number -- the buffer size
+ * that would be requested from a user. At execution time, the corresponding
+ * primitive will receive a base pointer to a scratchpad. It then needs to
+ * provide each part of the algorithm with the corresponding piece of memory.
+ * The three main challenges here are:
+ *   1. Track the correct offset (from the base scratchpad address) for each
+ *      piece.
+ *   2. The algorithm might require different memory pieces to be aligned, so
+ *      the scratchpad size is no longer just the sum of the sizes of the
+ *      corresponding subparts.
+ *   3. While a primitive is responsible for its scratchpad, the implementation
+ *      might use some other basic blocks (e.g. cpu_reducer) that also require
+ *      scratchpad memory. So there should be a simple way of passing the
+ *      information back and forth between the main algorithm (a primitive) and
+ * auxiliary stuff that lives completely separately from it (e.g. reducer).
+ *
+ * To address these challenges this header file provides 3 structures:
+ *  1. registry_t -- the class that stores the information about requested
+ * memory. The information includes required size and desired
+ * alignment for each piece. This class is also responsible
+ * for computing the right offset to a given piece using the
+ * base pointer.
+ * This class is basically a ledger with all entries.
+ * Lives in primitive descriptors.
+ *
+ * 2. registrar_t -- the interface to a registry_t to book memory. Used at
+ * primitive descriptor creation time only. Contains a
+ * reference to the corresponding *mutable* registry.
+ * Always modifiable.
+ * Allows chaining (using prefixes).
+ *
+ * 3. grantor_t -- the interface to a registry_t to access memory. Used at
+ * primitive execution time only. Contains a reference to
+ * the corresponding *constant* registry and base pointer.
+ * Always constant.
+ * Allows chaining (using prefixes).
+ *
+ * Both registrar_t and grantor_t allow chaining with an extra prefix provided.
+ * The feature is useful when a primitive offloads a part of its computations
+ * to some other primitives which require their own scratchpad space
+ * (e.g. a reducer). Prefixes are used to avoid key collisions in cases when
+ * multiple sub-primitives (e.g. multiple reducers) are used.
+ *
+ * A short example below demonstrates how to use the aforementioned classes. In
+ * it the main primitive is a convolution that uses the scratchpad for keeping
+ * a padded bias. It also needs a reducer, which needs its own space as well.
+ *
+ * ``` c++
+ * struct reducer_t {
+ * static void init(registrar_t &scratchpad) {
+ * // preserve space for the reduction (one page aligned)
+ *          scratchpad.book(key_reducer_space, sizeof(float) * 980 * 1024, 4096);
+ * }
+ *
+ * void exec(const grantor_t &scratchpad) {
+ * // get the pointer to preserved space. scratchpad came from
+ * // upper primitive (convolution in this example)
+ * auto space = scratchpad.get<float>(key_reducer_space);
+ *
+ * space[:] += ...;
+ * }
+ * };
+ *
+ * struct conv_t {
+ * struct pd_t {
+ * void init() {
+ * registrar_t scratchpad(scratchpad_registry_);
+ *
+ *              // preserve space for the padded bias (using default alignment)
+ * scratchpad.book(key_conv_padded_bias, 128);
+ *
+ *              // create a proxy registrar for the reducer. All entries made
+ *              // by the reducer will live in the convolution's registry, but
+ *              // will have their own `prefix`, so there is no interference
+ *              // with conv's buffers.
+ * registrar_t reducer_scratchpad(scratchpad, prefix_reducer);
+ *
+ * reducer_t::init(reducer_scratchpad);
+ * }
+ *
+ * registry_t scratchpad_registry_;
+ * }
+ *
+ * void exec() {
+ * // get the base pointer to a scratchpad memory from a user
+ * void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD);
+ *
+ * // create a grantor to the scratchpad (and provide the base
+ * // pointer).
+ * grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr);
+ *
+ * // access the padded_bias (need only key name and the grantor)
+ * auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+ *
+ *          // to give the `right` grantor to the reducer we need to add the
+ *          // corresponding prefix, so that the reducer is able to access
+ *          // its keys. The call is very similar to the one in pd_t::init,
+ *          // the only difference being the types: grantor_t vs registrar_t.
+ * grantor_t reducer_scratchpad(scratchpad, prefix_reducer);
+ * reducer->exec(reducer_scratchpad);
+ * }
+ * };
+ * ```
+ */
+
+
+/* namespace with common keys and prefixes */
+namespace names {
+enum {
+ key_none = 0,
+ key_bnorm_tmp_mean,
+ key_bnorm_tmp_var,
+ key_bnorm_tmp_diff_ss,
+ key_bnorm_tmp_stats,
+ key_bnorm_reduction,
+ key_concat_iptrs,
+ key_concat_istrides,
+ key_concat_nelems,
+ key_concat_optrs,
+ key_conv_adjusted_scales,
+ key_conv_bia_reduction,
+ key_conv_gemm_col,
+ key_conv_gemm_imtr,
+ key_conv_int_dat_in_acc_dt,
+ key_conv_padded_bias,
+ key_conv_rtus_space,
+ key_conv_tr_diff_dst,
+ key_conv_tr_diff_dst_bctx,
+ key_conv_tr_src,
+ key_conv_tr_src_bctx,
+ key_conv_wei_reduction,
+ key_conv_wei_bia_reduction,
+ key_conv_wei_bia_reduction_bctx,
+ key_iprod_int_dat_in_acc_dt,
+ key_reducer_space,
+ key_reducer_space_bctx,
+ key_reorder_wino_plain,
+ key_reorder_wino_transform_space,
+ key_reorder_rnn_weights_quantization,
+ key_reorder_rnn_weights_reduction,
+ key_rnn_space,
+ key_rnn_ptrs_bia,
+ key_rnn_ptrs_wei_layer,
+ key_rnn_ptrs_wei_iter,
+ key_softmax_reduction,
+ key_wino_U,
+ key_wino_V,
+ key_wino_M,
+ key_barrier,
+};
+
+enum {
+ prefix_none = 0,
+ prefix_reducer_bia,
+ prefix_reducer_wei,
+};
+}
+
+// level 0: 00 00 00 xxx
+// level 1: 00 00 aa xxx
+// level 2: 00 aa bb xxx
+// level 3: aa bb cc xxx
+// max # of levels: 3 + 1 (base_level)
+// here:
+// xxx : [1 .. MAX_KEY) : key
+// aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3
+
+using key_t = uint32_t;
+enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), };
+
+/// generates global key based on a prefix and a local key
+inline key_t make_key(key_t prefix, key_t key) { return prefix + key; }
+
+/// generates global prefix based on the global parent and the local ones
+inline key_t make_prefix(key_t parent_prefix, key_t prefix)
+{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; }
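+
+/* a hand-worked sketch of the arithmetic above: with MAX_KEY = 0x400 and
+ * MAX_PREFIX = 0x80, a reducer registered under prefix_reducer_bia (== 1)
+ * gets the global prefix make_prefix(0, 1) = 0x400, and its local
+ * key_reducer_space becomes the global key 0x400 + key_reducer_space.
+ * Chaining one level deeper multiplies the parent prefix by MAX_PREFIX,
+ * i.e. shifts it up by 7 bits, matching the level diagram above. */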
+
+struct registrar_t;
+struct grantor_t;
+
+struct registry_t {
+ void book(const key_t &key, size_t size, size_t alignment) {
+ if (size == 0) return;
+ assert(offset_map_.count(key) == 0);
+
+ size = utils::rnd_up(size, minimal_alignment);
+ alignment = nstl::max<size_t>(alignment, minimal_alignment);
+ offset_map_[key] = entry_t{size_, size, alignment};
+
+ size_ += size + alignment - minimal_alignment;
+ }
+
+ void *get(const key_t &key, void *base_ptr) const {
+ if (base_ptr == nullptr) { assert(size() == 0); return nullptr; }
+ if (offset_map_.count(key) != 1) return nullptr;
+
+ const auto &e = offset_map_.at(key);
+ base_ptr = utils::align_ptr<void>(base_ptr, minimal_alignment);
+ char *ptr = (char *)base_ptr + e.offset;
+ return utils::align_ptr<void>(ptr, e.alignment);
+ }
+
+ size_t size() const
+ { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; }
+
+ registrar_t registrar();
+ grantor_t grantor(void *base_ptr) const;
+
+protected:
+ enum { minimal_alignment = 64 };
+ struct entry_t { size_t offset, size, alignment; };
+
+ std::unordered_map<key_t, entry_t> offset_map_;
+ size_t size_ = 0;
+};
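+
+/* booking arithmetic sketch (illustration, not a contract): book(key, 100, 64)
+ * rounds the size up to 128 bytes, records offset 0 for the entry, and grows
+ * size_ to 128; size() then reports 128 + 63 = 191 bytes, the extra slack
+ * letting get() align the user-provided base pointer up to minimal_alignment
+ * before applying the stored offsets. */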
+
+struct registrar_t {
+ enum { default_alignment = 64 };
+
+ registrar_t(registry_t &registry): registry_(registry), prefix_(0) {}
+ registrar_t(registrar_t &parent, const key_t &prefix)
+ : registry_(parent.registry_)
+ , prefix_(make_prefix(parent.prefix_, prefix)) {}
+
+ void book(const key_t &key, size_t size,
+ size_t alignment = default_alignment)
+ { registry_.book(make_key(prefix_, key), size, alignment); }
+
+protected:
+ registry_t &registry_;
+ const key_t prefix_;
+};
+
+struct grantor_t {
+ grantor_t(const registry_t &registry, void *base_ptr)
+ : registry_(registry), prefix_(0), base_ptr_(base_ptr) {}
+ grantor_t(const grantor_t &parent, const key_t &prefix)
+ : registry_(parent.registry_)
+ , prefix_(make_prefix(parent.prefix_, prefix))
+ , base_ptr_(parent.base_ptr_) {}
+
+ template <typename T = void> T *get(const key_t &key) const
+ { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); }
+
+protected:
+ const registry_t &registry_;
+ const key_t prefix_;
+ void *base_ptr_;
+};
+
+inline registrar_t registry_t::registrar() { return registrar_t(*this); }
+inline grantor_t registry_t::grantor(void *base_ptr) const
+{ return grantor_t(*this, base_ptr); }
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
new file mode 100644
index 0000000000..2ef4a8fddc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp
@@ -0,0 +1,131 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <stdio.h>
+#include <cinttypes>
+
+#include "mkldnn_debug.h"
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#define DPRINT(...) do { \
+ int l = snprintf(str + written_len, str_len, __VA_ARGS__); \
+ if (l < 0) return l; \
+ if ((size_t)l >= str_len) return -1; \
+ written_len += l; str_len -= l; \
+} while(0)
+
+int mkldnn_md2fmt_str(char *str, size_t str_len,
+ const mkldnn_memory_desc_t *mdesc) {
+ using namespace mkldnn::impl;
+
+ if (str == nullptr || str_len <= 1u)
+ return -1;
+
+ int written_len = 0;
+
+ if (mdesc == nullptr) {
+ DPRINT("%s::%s::",
+ mkldnn_dt2str(data_type::undef),
+ mkldnn_fmt_kind2str(format_kind::undef));
+ return written_len;
+ }
+
+ memory_desc_wrapper md(mdesc);
+
+ DPRINT("%s:", mkldnn_dt2str(md.data_type()));
+
+ bool padded_dims = false, padded_offsets = false;
+ for (int d = 0; d < md.ndims(); ++d) {
+ if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true;
+ if (md.padded_offsets()[d] != 0) padded_offsets = true;
+ }
+ bool offset0 = md.offset0();
+ DPRINT("%s%s%s:",
+ padded_dims ? "p" : "",
+ padded_offsets ? "o" : "",
+ offset0 ? "0" : "");
+
+ DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind()));
+
+ if (!md.is_blocking_desc()) {
+ /* TODO: extend */
+ DPRINT("%s:", "");
+ } else {
+ const auto &blk = md.blocking_desc();
+
+ dims_t blocks;
+ md.compute_blocks(blocks);
+
+ char dim_chars[MKLDNN_MAX_NDIMS + 1];
+
+ bool plain = true;
+ for (int d = 0; d < md.ndims(); ++d) {
+ dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d;
+ if (blocks[d] != 1) plain = false;
+ }
+
+ dims_t strides;
+ utils::array_copy(strides, blk.strides, md.ndims());
+ utils::simultaneous_sort(strides, dim_chars, md.ndims(),
+ [](dim_t a, dim_t b) { return b - a; });
+
+ dim_chars[md.ndims()] = '\0';
+ DPRINT("%s", dim_chars);
+
+ if (!plain) {
+ for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
+ DPRINT("%d%c", (int)blk.inner_blks[iblk],
+ 'a' + (char)blk.inner_idxs[iblk]);
+ }
+ }
+
+ DPRINT("%s", ":");
+ }
+
+ DPRINT("f%lx", (long)md.extra().flags);
+
+ return written_len;
+}
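+
+/* Example output (a sketch; the exact string is not a stable contract): an f32
+ * nChw16c (aBcd16b) descriptor with dims 2x17x4x4, channels padded to 32,
+ * prints roughly "f32:p:blocked:aBcd16b:f0" -- data type, padding/offset
+ * markers, format kind, dims sorted by stride with inner blocks, and the
+ * extra flags. */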
+
+int mkldnn_md2dim_str(char *str, size_t str_len,
+ const mkldnn_memory_desc_t *mdesc) {
+ using namespace mkldnn::impl;
+
+ if (str == nullptr || str_len <= 1)
+ return -1;
+
+ int written_len = 0;
+
+ if (mdesc == nullptr || mdesc->ndims == 0) {
+ DPRINT("%s", "");
+ return written_len;
+ }
+
+ memory_desc_wrapper md(mdesc);
+
+ for (int d = 0; d < md.ndims() - 1; ++d)
+ DPRINT("%" PRId64 "x", md.dims()[d]);
+ DPRINT("%" PRId64, md.dims()[md.ndims() - 1]);
+
+ return written_len;
+}
+
+#undef DPRINT
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp
new file mode 100644
index 0000000000..16a8f7ea5e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp
@@ -0,0 +1,365 @@
+/*******************************************************************************
+* Copyright 2018-2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/* DO NOT EDIT, AUTO-GENERATED */
+
+#include <assert.h>
+
+#include "mkldnn_debug.h"
+#include "mkldnn_types.h"
+
+const char *mkldnn_status2str(mkldnn_status_t v) {
+ if (v == mkldnn_success) return "success";
+ if (v == mkldnn_out_of_memory) return "out_of_memory";
+ if (v == mkldnn_try_again) return "try_again";
+ if (v == mkldnn_invalid_arguments) return "invalid_arguments";
+ if (v == mkldnn_not_ready) return "not_ready";
+ if (v == mkldnn_unimplemented) return "unimplemented";
+ if (v == mkldnn_iterator_ends) return "iterator_ends";
+ if (v == mkldnn_runtime_error) return "runtime_error";
+ if (v == mkldnn_not_required) return "not_required";
+ assert(!"unknown status");
+ return "unknown status";
+}
+
+const char *mkldnn_dt2str(mkldnn_data_type_t v) {
+ if (v == mkldnn_data_type_undef) return "undef";
+ if (v == mkldnn_f32) return "f32";
+ if (v == mkldnn_s32) return "s32";
+ if (v == mkldnn_s8) return "s8";
+ if (v == mkldnn_u8) return "u8";
+ assert(!"unknown dt");
+ return "unknown dt";
+}
+
+const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) {
+ if (v == mkldnn_format_kind_undef) return "undef";
+ if (v == mkldnn_format_kind_any) return "any";
+ if (v == mkldnn_blocked) return "blocked";
+ if (v == mkldnn_format_kind_wino) return "wino";
+ if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed";
+ assert(!"unknown fmt_kind");
+ return "unknown fmt_kind";
+}
+
+const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) {
+ if (v == mkldnn_format_tag_undef) return "undef";
+ if (v == mkldnn_format_tag_any) return "format_tag_any";
+ if (v == mkldnn_a) return "a";
+ if (v == mkldnn_ab) return "ab";
+ if (v == mkldnn_abc) return "abc";
+ if (v == mkldnn_abcd) return "abcd";
+ if (v == mkldnn_abcde) return "abcde";
+ if (v == mkldnn_abcdef) return "abcdef";
+ if (v == mkldnn_abdec) return "abdec";
+ if (v == mkldnn_acb) return "acb";
+ if (v == mkldnn_acbde) return "acbde";
+ if (v == mkldnn_acdb) return "acdb";
+ if (v == mkldnn_acdeb) return "acdeb";
+ if (v == mkldnn_ba) return "ba";
+ if (v == mkldnn_bac) return "bac";
+ if (v == mkldnn_bacd) return "bacd";
+ if (v == mkldnn_bcda) return "bcda";
+ if (v == mkldnn_cba) return "cba";
+ if (v == mkldnn_cdba) return "cdba";
+ if (v == mkldnn_cdeba) return "cdeba";
+ if (v == mkldnn_decab) return "decab";
+ if (v == mkldnn_Abc16a) return "Abc16a";
+ if (v == mkldnn_ABc16a16b) return "ABc16a16b";
+ if (v == mkldnn_aBc16b) return "aBc16b";
+ if (v == mkldnn_ABc16b16a) return "ABc16b16a";
+ if (v == mkldnn_Abc4a) return "Abc4a";
+ if (v == mkldnn_aBc4b) return "aBc4b";
+ if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b";
+ if (v == mkldnn_ABc4b4a) return "ABc4b4a";
+ if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a";
+ if (v == mkldnn_ABc8a8b) return "ABc8a8b";
+ if (v == mkldnn_aBc8b) return "aBc8b";
+ if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b";
+ if (v == mkldnn_ABc8b8a) return "ABc8b8a";
+ if (v == mkldnn_Abcd16a) return "Abcd16a";
+ if (v == mkldnn_ABcd16a16b) return "ABcd16a16b";
+ if (v == mkldnn_aBcd16b) return "aBcd16b";
+ if (v == mkldnn_ABcd16b16a) return "ABcd16b16a";
+ if (v == mkldnn_aBCd16b16c) return "aBCd16b16c";
+ if (v == mkldnn_aBCd16c16b) return "aBCd16c16b";
+ if (v == mkldnn_Abcd4a) return "Abcd4a";
+ if (v == mkldnn_aBcd4b) return "aBcd4b";
+ if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b";
+ if (v == mkldnn_ABcd4b4a) return "ABcd4b4a";
+ if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c";
+ if (v == mkldnn_aBCd4c4b) return "aBCd4c4b";
+ if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a";
+ if (v == mkldnn_ABcd8a8b) return "ABcd8a8b";
+ if (v == mkldnn_aBcd8b) return "aBcd8b";
+ if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b";
+ if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b";
+ if (v == mkldnn_ABcd8b8a) return "ABcd8b8a";
+ if (v == mkldnn_aBCd8b8c) return "aBCd8b8c";
+ if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c";
+ if (v == mkldnn_aBCd8c8b) return "aBCd8c8b";
+ if (v == mkldnn_Abcde16a) return "Abcde16a";
+ if (v == mkldnn_ABcde16a16b) return "ABcde16a16b";
+ if (v == mkldnn_aBcde16b) return "aBcde16b";
+ if (v == mkldnn_ABcde16b16a) return "ABcde16b16a";
+ if (v == mkldnn_aBCde16b16c) return "aBCde16b16c";
+ if (v == mkldnn_aBCde16c16b) return "aBCde16c16b";
+ if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c";
+ if (v == mkldnn_Abcde4a) return "Abcde4a";
+ if (v == mkldnn_aBcde4b) return "aBcde4b";
+ if (v == mkldnn_ABcde4b4a) return "ABcde4b4a";
+ if (v == mkldnn_aBCde4b4c) return "aBCde4b4c";
+ if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c";
+ if (v == mkldnn_aBCde4c4b) return "aBCde4c4b";
+ if (v == mkldnn_Abcde8a) return "Abcde8a";
+ if (v == mkldnn_ABcde8a8b) return "ABcde8a8b";
+ if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b";
+ if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b";
+ if (v == mkldnn_ABcde8b8a) return "ABcde8b8a";
+ if (v == mkldnn_aBCde8b8c) return "aBCde8b8c";
+ if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c";
+ if (v == mkldnn_aBCde8c8b) return "aBCde8c8b";
+ if (v == mkldnn_aBcdef16b) return "aBcdef16b";
+ if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c";
+ if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b";
+ if (v == mkldnn_aBcdef4b) return "aBcdef4b";
+ if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b";
+ if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c";
+ if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c";
+ if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b";
+ if (v == mkldnn_aBdc16b) return "aBdc16b";
+ if (v == mkldnn_aBdc4b) return "aBdc4b";
+ if (v == mkldnn_aBdc8b) return "aBdc8b";
+ if (v == mkldnn_aBdec16b) return "aBdec16b";
+ if (v == mkldnn_aBdec4b) return "aBdec4b";
+ if (v == mkldnn_aBdec8b) return "aBdec8b";
+ if (v == mkldnn_aBdefc16b) return "aBdefc16b";
+ if (v == mkldnn_aBdefc4b) return "aBdefc4b";
+ if (v == mkldnn_aBdefc8b) return "aBdefc8b";
+ if (v == mkldnn_Acb16a) return "Acb16a";
+ if (v == mkldnn_Acb4a) return "Acb4a";
+ if (v == mkldnn_Acb8a) return "Acb8a";
+ if (v == mkldnn_aCBd16b16c) return "aCBd16b16c";
+ if (v == mkldnn_aCBde16b16c) return "aCBde16b16c";
+ if (v == mkldnn_Acdb16a) return "Acdb16a";
+ if (v == mkldnn_Acdb4a) return "Acdb4a";
+ if (v == mkldnn_Acdb8a) return "Acdb8a";
+ if (v == mkldnn_Acdeb16a) return "Acdeb16a";
+ if (v == mkldnn_Acdeb4a) return "Acdeb4a";
+ if (v == mkldnn_Acdeb8a) return "Acdeb8a";
+ if (v == mkldnn_BAc16a16b) return "BAc16a16b";
+ if (v == mkldnn_BAcd16a16b) return "BAcd16a16b";
+ if (v == mkldnn_format_tag_last) return "format_tag_last";
+ if (v == mkldnn_x) return "x";
+ if (v == mkldnn_nc) return "nc";
+ if (v == mkldnn_cn) return "cn";
+ if (v == mkldnn_ncw) return "ncw";
+ if (v == mkldnn_nwc) return "nwc";
+ if (v == mkldnn_nchw) return "nchw";
+ if (v == mkldnn_nhwc) return "nhwc";
+ if (v == mkldnn_chwn) return "chwn";
+ if (v == mkldnn_ncdhw) return "ncdhw";
+ if (v == mkldnn_ndhwc) return "ndhwc";
+ if (v == mkldnn_oi) return "oi";
+ if (v == mkldnn_io) return "io";
+ if (v == mkldnn_oiw) return "oiw";
+ if (v == mkldnn_wio) return "wio";
+ if (v == mkldnn_oihw) return "oihw";
+ if (v == mkldnn_hwio) return "hwio";
+ if (v == mkldnn_ihwo) return "ihwo";
+ if (v == mkldnn_iohw) return "iohw";
+ if (v == mkldnn_oidhw) return "oidhw";
+ if (v == mkldnn_dhwio) return "dhwio";
+ if (v == mkldnn_goiw) return "goiw";
+ if (v == mkldnn_goihw) return "goihw";
+ if (v == mkldnn_hwigo) return "hwigo";
+ if (v == mkldnn_giohw) return "giohw";
+ if (v == mkldnn_goidhw) return "goidhw";
+ if (v == mkldnn_tnc) return "tnc";
+ if (v == mkldnn_ntc) return "ntc";
+ if (v == mkldnn_ldsnc) return "ldsnc";
+ if (v == mkldnn_ldigo) return "ldigo";
+ if (v == mkldnn_ldgoi) return "ldgoi";
+ if (v == mkldnn_ldgo) return "ldgo";
+ if (v == mkldnn_nCdhw16c) return "nCdhw16c";
+ if (v == mkldnn_nCdhw4c) return "nCdhw4c";
+ if (v == mkldnn_nCdhw8c) return "nCdhw8c";
+ if (v == mkldnn_nChw16c) return "nChw16c";
+ if (v == mkldnn_nChw4c) return "nChw4c";
+ if (v == mkldnn_nChw8c) return "nChw8c";
+ if (v == mkldnn_nCw16c) return "nCw16c";
+ if (v == mkldnn_nCw4c) return "nCw4c";
+ if (v == mkldnn_nCw8c) return "nCw8c";
+ if (v == mkldnn_IOw16o16i) return "IOw16o16i";
+ if (v == mkldnn_OIw16i16o) return "OIw16i16o";
+ if (v == mkldnn_OIw16o16i) return "OIw16o16i";
+ if (v == mkldnn_Oiw16o) return "Oiw16o";
+ if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i";
+ if (v == mkldnn_OIw4i4o) return "OIw4i4o";
+ if (v == mkldnn_Oiw4o) return "Oiw4o";
+ if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i";
+ if (v == mkldnn_OIw8i8o) return "OIw8i8o";
+ if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o";
+ if (v == mkldnn_OIw8o8i) return "OIw8o8i";
+ if (v == mkldnn_Owi16o) return "Owi16o";
+ if (v == mkldnn_Owi4o) return "Owi4o";
+ if (v == mkldnn_Owi8o) return "Owi8o";
+ if (v == mkldnn_IOhw16o16i) return "IOhw16o16i";
+ if (v == mkldnn_Ohwi16o) return "Ohwi16o";
+ if (v == mkldnn_Ohwi4o) return "Ohwi4o";
+ if (v == mkldnn_Ohwi8o) return "Ohwi8o";
+ if (v == mkldnn_OIhw16i16o) return "OIhw16i16o";
+ if (v == mkldnn_OIhw16o16i) return "OIhw16o16i";
+ if (v == mkldnn_Oihw16o) return "Oihw16o";
+ if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i";
+ if (v == mkldnn_OIhw4i4o) return "OIhw4i4o";
+ if (v == mkldnn_Oihw4o) return "Oihw4o";
+ if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i";
+ if (v == mkldnn_OIhw8i8o) return "OIhw8i8o";
+ if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o";
+ if (v == mkldnn_OIhw8o8i) return "OIhw8o8i";
+ if (v == mkldnn_Odhwi16o) return "Odhwi16o";
+ if (v == mkldnn_Odhwi4o) return "Odhwi4o";
+ if (v == mkldnn_Odhwi8o) return "Odhwi8o";
+ if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o";
+ if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i";
+ if (v == mkldnn_Oidhw16o) return "Oidhw16o";
+ if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o";
+ if (v == mkldnn_Oidhw4o) return "Oidhw4o";
+ if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i";
+ if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o";
+ if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i";
+ if (v == mkldnn_Goiw16g) return "Goiw16g";
+ if (v == mkldnn_gIOw16o16i) return "gIOw16o16i";
+ if (v == mkldnn_gOIw16i16o) return "gOIw16i16o";
+ if (v == mkldnn_gOIw16o16i) return "gOIw16o16i";
+ if (v == mkldnn_gOiw16o) return "gOiw16o";
+ if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i";
+ if (v == mkldnn_gOIw4i4o) return "gOIw4i4o";
+ if (v == mkldnn_gOiw4o) return "gOiw4o";
+ if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i";
+ if (v == mkldnn_gOIw8i8o) return "gOIw8i8o";
+ if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o";
+ if (v == mkldnn_gOIw8o8i) return "gOIw8o8i";
+ if (v == mkldnn_gOwi16o) return "gOwi16o";
+ if (v == mkldnn_gOwi4o) return "gOwi4o";
+ if (v == mkldnn_gOwi8o) return "gOwi8o";
+ if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i";
+ if (v == mkldnn_gOhwi16o) return "gOhwi16o";
+ if (v == mkldnn_gOhwi4o) return "gOhwi4o";
+ if (v == mkldnn_gOhwi8o) return "gOhwi8o";
+ if (v == mkldnn_Goihw16g) return "Goihw16g";
+ if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o";
+ if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i";
+ if (v == mkldnn_gOihw16o) return "gOihw16o";
+ if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i";
+ if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i";
+ if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o";
+ if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i";
+ if (v == mkldnn_gOihw4o) return "gOihw4o";
+ if (v == mkldnn_Goihw8g) return "Goihw8g";
+ if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i";
+ if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o";
+ if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o";
+ if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i";
+ if (v == mkldnn_gOdhwi16o) return "gOdhwi16o";
+ if (v == mkldnn_gOdhwi4o) return "gOdhwi4o";
+ if (v == mkldnn_gOdhwi8o) return "gOdhwi8o";
+ if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o";
+ if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i";
+ if (v == mkldnn_gOidhw16o) return "gOidhw16o";
+ if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o";
+ if (v == mkldnn_gOidhw4o) return "gOidhw4o";
+ if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i";
+ if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o";
+ if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i";
+ assert(!"unknown fmt_tag");
+ return "unknown fmt_tag";
+}
+
+const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) {
+ if (v == mkldnn_prop_kind_undef) return "undef";
+ if (v == mkldnn_forward_training) return "forward_training";
+ if (v == mkldnn_forward_inference) return "forward_inference";
+ if (v == mkldnn_forward_scoring) return "forward_scoring";
+ if (v == mkldnn_forward) return "forward";
+ if (v == mkldnn_backward) return "backward";
+ if (v == mkldnn_backward_data) return "backward_data";
+ if (v == mkldnn_backward_weights) return "backward_weights";
+ if (v == mkldnn_backward_bias) return "backward_bias";
+ assert(!"unknown prop_kind");
+ return "unknown prop_kind";
+}
+
+const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) {
+ if (v == mkldnn_undefined_primitive) return "undef";
+ if (v == mkldnn_reorder) return "reorder";
+ if (v == mkldnn_shuffle) return "shuffle";
+ if (v == mkldnn_concat) return "concat";
+ if (v == mkldnn_sum) return "sum";
+ if (v == mkldnn_convolution) return "convolution";
+ if (v == mkldnn_deconvolution) return "deconvolution";
+ if (v == mkldnn_eltwise) return "eltwise";
+ if (v == mkldnn_softmax) return "softmax";
+ if (v == mkldnn_pooling) return "pooling";
+ if (v == mkldnn_lrn) return "lrn";
+ if (v == mkldnn_batch_normalization) return "batch_normalization";
+ if (v == mkldnn_inner_product) return "inner_product";
+ if (v == mkldnn_rnn) return "rnn";
+ assert(!"unknown prim_kind");
+ return "unknown prim_kind";
+}
+
+const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
+ if (v == mkldnn_alg_kind_undef) return "undef";
+ if (v == mkldnn_convolution_direct) return "convolution_direct";
+ if (v == mkldnn_convolution_winograd) return "convolution_winograd";
+ if (v == mkldnn_convolution_auto) return "convolution_auto";
+ if (v == mkldnn_deconvolution_direct) return "deconvolution_direct";
+ if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd";
+ if (v == mkldnn_eltwise_relu) return "eltwise_relu";
+ if (v == mkldnn_eltwise_tanh) return "eltwise_tanh";
+ if (v == mkldnn_eltwise_elu) return "eltwise_elu";
+ if (v == mkldnn_eltwise_square) return "eltwise_square";
+ if (v == mkldnn_eltwise_abs) return "eltwise_abs";
+ if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt";
+ if (v == mkldnn_eltwise_linear) return "eltwise_linear";
+ if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu";
+ if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu";
+ if (v == mkldnn_eltwise_logistic) return "eltwise_logistic";
+ if (v == mkldnn_pooling_max) return "pooling_max";
+ if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding";
+ if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding";
+ if (v == mkldnn_pooling_avg) return "pooling_avg";
+ if (v == mkldnn_lrn_across_channels) return "lrn_across_channels";
+ if (v == mkldnn_lrn_within_channel) return "lrn_within_channel";
+ if (v == mkldnn_vanilla_rnn) return "vanilla_rnn";
+ if (v == mkldnn_vanilla_lstm) return "vanilla_lstm";
+ if (v == mkldnn_vanilla_gru) return "vanilla_gru";
+ if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset";
+ assert(!"unknown alg_kind");
+ return "unknown alg_kind";
+}
+
+const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) {
+ if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right";
+ if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left";
+ if (v == mkldnn_bidirectional_concat) return "bidirectional_concat";
+ if (v == mkldnn_bidirectional_sum) return "bidirectional_sum";
+ if (v == mkldnn_unidirectional) return "unidirectional";
+ assert(!"unknown rnn_direction");
+ return "unknown rnn_direction";
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
new file mode 100644
index 0000000000..7e5789e2c3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp
@@ -0,0 +1,115 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_THREAD_HPP
+#define MKLDNN_THREAD_HPP
+
+#include "utils.hpp"
+#include "z_magic.hpp"
+
+#define MKLDNN_THR_SEQ 0
+#define MKLDNN_THR_OMP 1
+#define MKLDNN_THR_TBB 2
+
+/* Ideally the condition below should never be hit (if the library is built
+ * using the regular cmake). For third-party projects that build the library
+ * from the sources on their own, try to guess the right threading... */
+#if !defined(MKLDNN_THR)
+# define MKLDNN_THR MKLDNN_THR_TBB
+#endif
+
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+#define MKLDNN_THR_SYNC 1
+inline int mkldnn_get_max_threads() { return 1; }
+inline int mkldnn_get_num_threads() { return 1; }
+inline int mkldnn_get_thread_num() { return 0; }
+inline int mkldnn_in_parallel() { return 0; }
+inline void mkldnn_thr_barrier() {}
+
+#define PRAGMA_OMP(...)
+
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+#include <omp.h>
+#define MKLDNN_THR_SYNC 1
+
+inline int mkldnn_get_max_threads() { return omp_get_max_threads(); }
+inline int mkldnn_get_num_threads() { return omp_get_num_threads(); }
+inline int mkldnn_get_thread_num() { return omp_get_thread_num(); }
+inline int mkldnn_in_parallel() { return omp_in_parallel(); }
+inline void mkldnn_thr_barrier() {
+# pragma omp barrier
+}
+
+#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__))
+
+#elif MKLDNN_THR == MKLDNN_THR_TBB
+#include "tbb/task_arena.h"
+#include "tbb/parallel_for.h"
+#define MKLDNN_THR_SYNC 0
+
+inline int mkldnn_get_max_threads()
+{ return tbb::this_task_arena::max_concurrency(); }
+inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); }
+inline int mkldnn_get_thread_num()
+{ return tbb::this_task_arena::current_thread_index(); }
+inline int mkldnn_in_parallel() { return 0; }
+inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); }
+
+#define PRAGMA_OMP(...)
+
+#endif
+
+/* MSVC still supports omp 2.0 only */
+#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+# define collapse(x)
+# define PRAGMA_OMP_SIMD(...)
+#else
+# define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
+#endif // defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+
+namespace mkldnn {
+namespace impl {
+
+inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; }
+
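+/* balance211 splits n work items across a team of threads so that chunk sizes
+ * differ by at most one. A hand-worked sketch: n = 10, team = 4 gives
+ * n1 = 3, n2 = 2, T1 = 2, i.e. the half-open ranges [0,3), [3,6), [6,8),
+ * [8,10) -- the first T1 threads take the larger chunk. */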
+template <typename T, typename U>
+inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
+ T n_min = 1;
+ T &n_my = n_end;
+ if (team <= 1 || n == 0) {
+ n_start = 0;
+ n_my = n;
+ } else if (n_min == 1) {
+ // team = T1 + T2
+ // n = T1*n1 + T2*n2 (n1 - n2 = 1)
+ T n1 = utils::div_up(n, (T)team);
+ T n2 = n1 - 1;
+ T T1 = n - n2 * (T)team;
+ n_my = (T)tid < T1 ? n1 : n2;
+ n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
+ }
+
+ n_end += n_start;
+}
+
+} // namespace impl
+} // namespace mkldnn
+
+#include "mkldnn_thread_parallel_nd.hpp"
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp
new file mode 100644
index 0000000000..50f9b29622
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp
@@ -0,0 +1,277 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP
+#define MKLDNN_THREAD_PARALLEL_ND_HPP
+
+/* This header must be included by mkldnn_thread.hpp only */
+
+/* Functions:
+ * - parallel(nthr, f) - executes f in parallel using at most
+ * nthr threads. If nthr equals 0
+ *                                     mkldnn_get_max_threads() threads are
+ * used
+ * - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already
+ * created threads
+ * - parallel_nd(dims..., f) - creates a parallel section and then
+ * calls for_nd
+ * - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then
+ * calls for_nd (mostly for convenience)
+ */
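+
+/* Typical usage (an illustrative sketch; src/dst, the dims N/C/H/W and the
+ * memory_desc_wrapper objects src_d/dst_d are hypothetical):
+ *
+ *     parallel_nd(N, C, H, W, [&](int n, int c, int h, int w) {
+ *         dst[dst_d.off(n, c, h, w)] = src[src_d.off(n, c, h, w)];
+ *     });
+ *
+ * This splits the N*C*H*W iteration space across threads (via balance211 or a
+ * TBB blocked range) and invokes the lambda once per logical point. */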
+
+namespace mkldnn {
+namespace impl {
+
+/* general parallelization */
+template <typename F>
+void parallel(int nthr, F f) {
+ if (nthr == 0) nthr = mkldnn_get_max_threads();
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+ assert(nthr == 1);
+ f(0, 1);
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+ if (nthr == 1) { f(0, 1); return; }
+# pragma omp parallel num_threads(nthr)
+ f(mkldnn_get_thread_num(), mkldnn_get_num_threads());
+#elif MKLDNN_THR == MKLDNN_THR_TBB
+ if (nthr == 1) { f(0, 1); return; }
+ tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner());
+#endif
+}
+
+/* for_nd section */
+
+template <typename T0, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
+ T0 start{0}, end{0};
+ balance211(D0, nthr, ithr, start, end);
+ for (T0 d0 = start; d0 < end; ++d0) f(d0);
+}
+
+template <typename T0, typename T1, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) {
+ const size_t work_amount = (size_t)D0 * D1;
+ if (work_amount == 0) return;
+ size_t start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ T0 d0{0}; T1 d1{0};
+ utils::nd_iterator_init(start, d0, D0, d1, D1);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ f(d0, d1);
+ utils::nd_iterator_step(d0, D0, d1, D1);
+ }
+}
+
+template <typename T0, typename T1, typename T2, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+ const T2 &D2, F f) {
+ const size_t work_amount = (size_t)D0 * D1 * D2;
+ if (work_amount == 0) return;
+ size_t start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ T0 d0{0}; T1 d1{0}; T2 d2{0};
+ utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ f(d0, d1, d2);
+ utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
+ }
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+ const T2 &D2, const T3 &D3, F f) {
+ const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
+ if (work_amount == 0) return;
+ size_t start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
+ utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ f(d0, d1, d2, d3);
+ utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
+ }
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+ typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+ const T2 &D2, const T3 &D3, const T4 &D4, F f) {
+ const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
+ if (work_amount == 0) return;
+ size_t start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
+ utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ f(d0, d1, d2, d3, d4);
+ utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+ }
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+ typename T5, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+ const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
+ const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
+ if (work_amount == 0) return;
+ size_t start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
+ utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
+ d5, D5);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ f(d0, d1, d2, d3, d4, d5);
+ utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
+ }
+}
+
+// Computes the overall work amount; the trailing lambda in the parameter pack contributes a factor of 1.
+template <typename T>
+constexpr size_t get_work_amount(const T &v) { return 1; }
+template <typename T, typename ...Args>
+constexpr size_t get_work_amount(const T &v, Args &&...args)
+{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
+
+/* parallel_nd and parallel_nd_in_omp section */
+
+#if MKLDNN_THR != MKLDNN_THR_TBB
+template <typename ...Args>
+void parallel_nd(Args &&...args) {
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+ for_nd(0, 1, utils::forward<Args>(args)...);
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+ const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
+# pragma omp parallel if (do_parallel)
+ {
+ const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
+ const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
+ for_nd(ithr, nthr, utils::forward<Args>(args)...);
+ }
+#endif
+}
+#else // MKLDNN_THR != MKLDNN_THR_TBB
+
+// gcc 4.8 has a bug with passing a parameter pack to lambdas,
+// so all the cases have to be instantiated explicitly.
+
+template <typename T0, typename F>
+void parallel_nd(const T0 &D0, F f) {
+ const size_t work_amount = (size_t)D0;
+ if (work_amount == 0) return;
+ tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+ for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+ f(T0(iwork));
+ }
+ }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, F f) {
+ const size_t work_amount = (size_t)D0 * D1;
+ if (work_amount == 0) return;
+ tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+ T0 d0{0}; T1 d1{0};
+ utils::nd_iterator_init(r.begin(), d0, D0, d1, D1);
+ for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+ f(d0, d1);
+ utils::nd_iterator_step(d0, D0, d1, D1);
+ }
+ }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) {
+ const size_t work_amount = (size_t)D0 * D1 * D2;
+ if (work_amount == 0) return;
+ tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+ T0 d0{0}; T1 d1{0}; T2 d2{0};
+ utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2);
+ for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+ f(d0, d1, d2);
+ utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
+ }
+ }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
+ const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
+ if (work_amount == 0) return;
+ tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+ T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
+ utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3);
+ for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+ f(d0, d1, d2, d3);
+ utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
+ }
+ }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+ typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
+ const T4 &D4, F f) {
+ const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
+ if (work_amount == 0) return;
+ tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+ T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
+ utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+ for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+ f(d0, d1, d2, d3, d4);
+ utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+ }
+ }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+ typename T5, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
+ const T4 &D4, const T5 &D5, F f) {
+ const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
+ if (work_amount == 0) return;
+ tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+ T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
+ utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
+ d5, D5);
+ for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+ f(d0, d1, d2, d3, d4, d5);
+ utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
+ }
+ }, tbb::static_partitioner());
+}
+#endif
+
+template <typename ...Args>
+void parallel_nd_in_omp(Args &&...args) {
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+ for_nd(0, 1, utils::forward<Args>(args)...);
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+ for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
+ utils::forward<Args>(args)...);
+#elif MKLDNN_THR == MKLDNN_THR_TBB
+ assert(!"unsupported parallel_nd_in_omp()");
+#endif
+}
+
+} // namespace impl
+} // namespace mkldnn
+
+#endif
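
For context, a minimal usage sketch of the helpers declared above (illustrative only, not part of the patch; it assumes compilation inside the library tree, where mkldnn_thread.hpp is the header that pulls this file in):

    // Scales an NCHW tensor in parallel; one lambda call per (n, c, h, w).
    #include "mkldnn_thread.hpp"

    void scale_nchw(float *data, int N, int C, int H, int W, float alpha) {
        using namespace mkldnn::impl;
        parallel_nd(N, C, H, W, [&](int n, int c, int h, int w) {
            const size_t off = (((size_t)n * C + c) * H + h) * W + w;
            data[off] *= alpha;
        });
    }

The iteration space N*C*H*W is balanced across the active threads by balance211(), so the lambda only computes the offset for the indices it is handed.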
diff --git a/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp
new file mode 100644
index 0000000000..aa671a0b6e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp
@@ -0,0 +1,77 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_TRAITS_HPP
+#define MKLDNN_TRAITS_HPP
+
+#include <assert.h>
+#include <stdint.h>
+
+#include "mkldnn.h"
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+template <data_type_t> struct prec_traits {}; /* ::type -> float */
+template <typename> struct data_traits {}; /* ::data_type -> f32 */
+template <int> struct typesize_traits {}; /* ::type -> float (selected by size) */
+template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */
+
+template <> struct prec_traits<data_type::f32> { typedef float type; };
+template <> struct prec_traits<data_type::s32> { typedef int32_t type; };
+template <> struct prec_traits<data_type::s8> { typedef int8_t type; };
+template <> struct prec_traits<data_type::u8> { typedef uint8_t type; };
+
+template <> struct data_traits<float>
+{ static constexpr data_type_t data_type = data_type::f32; };
+template <> struct data_traits<int32_t>
+{ static constexpr data_type_t data_type = data_type::s32; };
+template <> struct data_traits<int8_t>
+{ static constexpr data_type_t data_type = data_type::s8; };
+template <> struct data_traits<uint8_t>
+{ static constexpr data_type_t data_type = data_type::u8; };
+
+template <> struct typesize_traits<4> { typedef float type; };
+template <> struct typesize_traits<2> { typedef int16_t type; };
+template <> struct typesize_traits<1> { typedef uint8_t type; };
+
+#define PKIND_TRAITS_INST(op) \
+template <> struct pkind_traits<primitive_kind::op> { \
+ typedef CONCAT2(op, _desc_t) desc_type; \
+ static constexpr query_t query_d = query::CONCAT2(op, _d); \
+}
+PKIND_TRAITS_INST(convolution);
+PKIND_TRAITS_INST(deconvolution);
+PKIND_TRAITS_INST(shuffle);
+PKIND_TRAITS_INST(eltwise);
+PKIND_TRAITS_INST(softmax);
+PKIND_TRAITS_INST(pooling);
+PKIND_TRAITS_INST(lrn);
+PKIND_TRAITS_INST(batch_normalization);
+PKIND_TRAITS_INST(inner_product);
+PKIND_TRAITS_INST(rnn);
+#undef PKIND_TRAITS_INST
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
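
As a brief illustration of how these trait maps are consumed (a sketch, not part of the patch): prec_traits turns a data_type_t enumerator into a C++ type at compile time, and data_traits maps a C++ type back to its enumerator.

    #include <cstdint>
    #include "mkldnn_traits.hpp"

    using namespace mkldnn::impl;

    // data_type::s32 resolves to int32_t at compile time.
    using s32_t = typename prec_traits<data_type::s32>::type;
    static_assert(sizeof(s32_t) == 4, "s32 maps to a 32-bit integer");

    // float maps back to the f32 enumerator, usable for further dispatch.
    static_assert(data_traits<float>::data_type == data_type::f32,
            "float maps to data_type::f32");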
diff --git a/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
new file mode 100644
index 0000000000..f89ea999e2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/nstl.hpp
@@ -0,0 +1,193 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef NSTL_HPP
+#define NSTL_HPP
+
+#include <stdint.h>
+#include <limits.h>
+#include <float.h>
+
+#include <vector>
+#include <map>
+
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+void *malloc(size_t size, int alignment);
+void free(void *p);
+
+struct c_compatible {
+ enum { default_alignment = 64 };
+ static void *operator new(size_t sz) {
+ return malloc(sz, default_alignment);
+ }
+ static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; }
+ static void *operator new[](size_t sz) {
+ return malloc(sz, default_alignment);
+ }
+ static void operator delete(void *p) { free(p); }
+ static void operator delete[](void *p) { free(p); }
+};
+
+namespace nstl {
+
+template<typename T>
+inline const T abs(const T& a) {
+ return a >= 0 ? a : -a;
+}
+
+template<typename T>
+inline const T& max(const T& a, const T& b) {
+ return a > b ? a : b;
+}
+
+template<typename T>
+inline const T& min(const T& a, const T& b) {
+ return a < b ? a : b;
+}
+
+template<typename T> void swap(T& t1, T& t2) {
+ T tmp(t1);
+ t1 = t2;
+ t2 = tmp;
+}
+
+// Rationale: MKL-DNN needs a numeric limits implementation that does not
+// generate dependencies on C++ run-time libraries.
+
+template<typename T> struct numeric_limits;
+
+template<> struct numeric_limits<float> {
+ static constexpr float lowest() { return -FLT_MAX; }
+ static constexpr float max() { return FLT_MAX; }
+};
+
+template<> struct numeric_limits<int32_t> {
+ static constexpr int lowest() { return INT32_MIN; }
+ static constexpr int max() { return INT32_MAX; }
+};
+
+template<> struct numeric_limits<int16_t> {
+ static constexpr int16_t lowest() { return INT16_MIN; }
+ static constexpr int16_t max() { return INT16_MAX; }
+};
+
+template<> struct numeric_limits<int8_t> {
+ static constexpr int8_t lowest() { return INT8_MIN; }
+ static constexpr int8_t max() { return INT8_MAX; }
+};
+
+template<> struct numeric_limits<uint8_t> {
+ static constexpr uint8_t lowest() { return 0; }
+ static constexpr uint8_t max() { return UINT8_MAX; }
+};
+
+template<typename T> struct is_integral
+{ static constexpr bool value = false; };
+template<> struct is_integral<int32_t> { static constexpr bool value = true; };
+template<> struct is_integral<int16_t> { static constexpr bool value = true; };
+template<> struct is_integral<int8_t> { static constexpr bool value = true; };
+template<> struct is_integral<uint8_t> { static constexpr bool value = true; };
+
+template <typename T, typename U> struct is_same
+{ static constexpr bool value = false; };
+template <typename T> struct is_same<T, T>
+{ static constexpr bool value = true; };
+
+// Rationale: MKL-DNN needs container implementations that do not generate
+// dependencies on C++ run-time libraries.
+//
+// Implementation philosophy: the caller is responsible for checking that an
+// operation is valid. The only functions that have to return a status are those that
+// depend on memory allocation or similar operations.
+//
+// This means that e.g. an operator [] does not have to check for boundaries.
+// The caller should have checked the boundaries. If it did not, we crash and
+// burn: this is a bug in MKL-DNN and throwing an exception would not have been
+// recoverable.
+//
+// On the other hand, insert() or resize() or a similar operation needs to
+// return a status because the outcome depends on factors external to the
+// caller. The situation is probably not recoverable, but MKL-DNN
+// needs to be nice and report "out of memory" to the users.
+
+enum nstl_status_t {
+ success = 0,
+ out_of_memory
+};
+
+template <typename T> class vector: public c_compatible {
+private:
+ std::vector<T> _impl;
+public:
+ typedef typename std::vector<T>::iterator iterator;
+ typedef typename std::vector<T>::const_iterator const_iterator;
+ typedef typename std::vector<T>::size_type size_type;
+ vector() {}
+ vector(size_type n): _impl(n) {}
+ vector(size_type n, const T &value): _impl(n, value) {}
+ template <typename input_iterator>
+ vector(input_iterator first, input_iterator last): _impl(first, last) {}
+ ~vector() {}
+ size_type size() const { return _impl.size(); }
+ T& operator[] (size_type i) { return _impl[i]; }
+ const T& operator[] (size_type i) const { return _impl[i]; }
+ iterator begin() { return _impl.begin(); }
+ const_iterator begin() const { return _impl.begin(); }
+ iterator end() { return _impl.end(); }
+ const_iterator end() const { return _impl.end(); }
+ template <typename input_iterator>
+ nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end)
+ {
+ _impl.insert(pos, begin, end);
+ return success;
+ }
+ void clear() { _impl.clear(); }
+ void push_back(const T& t) { _impl.push_back(t); }
+ void resize(size_type count) { _impl.resize(count); }
+ void reserve(size_type count) { _impl.reserve(count); }
+};
+
+template <typename Key, typename T> class map: public c_compatible {
+private:
+ std::map<Key, T> _impl;
+public:
+ typedef typename std::map<Key, T>::iterator iterator;
+ typedef typename std::map<Key, T>::const_iterator const_iterator;
+ typedef typename std::map<Key, T>::size_type size_type;
+ map() {}
+ ~map() {}
+ size_type size() const { return _impl.size(); }
+ T& operator[](const Key &k) { return _impl[k]; }
+ const T& operator[](const Key &k) const { return _impl[k]; }
+ iterator begin() { return _impl.begin(); }
+ const_iterator begin() const { return _impl.begin(); }
+ iterator end() { return _impl.end(); }
+ const_iterator end() const { return _impl.end(); }
+ // clears the underlying std::map; never fails, so no status is returned
+ void clear() { _impl.clear(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
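
A short sketch of the status-based error handling described in the rationale above (illustrative only, not part of the patch): operations that may allocate report failure through nstl_status_t instead of throwing.

    #include "nstl.hpp"

    using namespace mkldnn::impl;

    // Appends src to dst; the caller inspects the returned status,
    // there is no exception path.
    nstl::nstl_status_t append_all(nstl::vector<int> &dst,
            const nstl::vector<int> &src) {
        return dst.insert(dst.end(), src.begin(), src.end());
    }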
diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
new file mode 100644
index 0000000000..be96e654ff
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/pooling.cpp
@@ -0,0 +1,114 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t pooling_desc_init(pooling_desc_t *pool_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t kernel, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ bool args_ok = true
+ && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l)
+ && one_of(alg_kind, pooling_max,
+ pooling_avg_include_padding,
+ pooling_avg_exclude_padding)
+ && one_of(padding_kind, padding_kind::padding_zero);
+ if (!args_ok) return invalid_arguments;
+
+ if (padding_r == nullptr) padding_r = padding_l;
+
+ auto pd = pooling_desc_t();
+ pd.primitive_kind = primitive_kind::pooling;
+ pd.prop_kind = prop_kind;
+ pd.alg_kind = alg_kind;
+ pd.src_desc.ndims = src_desc->ndims;
+
+ const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+
+ pd.diff_src_desc = pd.src_desc = zero_md();
+ pd.diff_dst_desc = pd.dst_desc = zero_md();
+
+ (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc;
+ (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc;
+
+ int sp_dims = src_desc->ndims - 2;
+ utils::array_copy(pd.strides, strides, sp_dims);
+ utils::array_copy(pd.kernel, kernel, sp_dims);
+ utils::array_copy(pd.padding[0], padding_l, sp_dims);
+ utils::array_copy(pd.padding[1], padding_r, sp_dims);
+
+ pd.padding_kind = padding_kind;
+ if (one_of(alg_kind, pooling_max, pooling_avg_include_padding,
+ pooling_avg_exclude_padding)) {
+ pd.accum_data_type = types::default_accum_data_type(
+ src_desc->data_type, dst_desc->data_type);
+ } else {
+ pd.accum_data_type = dst_desc->data_type;
+ }
+
+ bool consistency = true
+ && utils::one_of(src_desc->ndims, 4, 5)
+ && utils::one_of(dst_desc->ndims, 4, 5)
+ && src_desc->dims[0] == dst_desc->dims[0]
+ && src_desc->dims[1] == dst_desc->dims[1];
+ for (int i = 2; i < src_desc->ndims; ++i)
+ consistency = consistency && (
+ (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2]
+ + padding_r[i - 2]) / strides[i - 2] + 1
+ == dst_desc->dims[i]);
+ if (!consistency) return invalid_arguments;
+
+ *pool_desc = pd;
+ return success;
+}
+}
+
+status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc,
+ prop_kind_t prop_kind, alg_kind_t alg_kind,
+ const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
+ const dims_t strides, const dims_t kernel, const dims_t padding_l,
+ const dims_t padding_r, padding_kind_t padding_kind) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc,
+ dst_desc, strides, kernel, padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc,
+ alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
+ const memory_desc_t *diff_dst_desc, const dims_t strides,
+ const dims_t kernel, const dims_t padding_l, const dims_t padding_r,
+ padding_kind_t padding_kind) {
+ return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind,
+ diff_src_desc, diff_dst_desc, strides, kernel, padding_l,
+ padding_r, padding_kind);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
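
The consistency loop above encodes the usual pooling shape relation. As a worked example (illustrative, not part of the patch): with a source extent of 32, kernel 3, stride 2 and padding 1 on both sides, the expected destination extent is (32 - 3 + 1 + 1) / 2 + 1 = 16.

    // Hypothetical helper mirroring the check in pooling_desc_init().
    inline int pooled_extent(int src, int kernel, int pad_l, int pad_r,
            int stride) {
        return (src - kernel + pad_l + pad_r) / stride + 1; // (32-3+1+1)/2+1 == 16
    }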
diff --git a/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
new file mode 100644
index 0000000000..4c9c009412
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp
@@ -0,0 +1,238 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef POOLING_PD_HPP
+#define POOLING_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct pooling_fwd_pd_t;
+
+struct pooling_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::pooling;
+
+ pooling_pd_t(engine_t *engine,
+ const pooling_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const pooling_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , ws_md_()
+ {}
+
+ const pooling_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::pooling_d:
+ *(const pooling_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common pooling aux functions */
+
+ dim_t MB() const { return src_desc().dims[0]; }
+ dim_t C() const { return src_desc().dims[1]; }
+
+ dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; }
+ dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; }
+ dim_t IW() const { return src_desc().dims[ndims() - 1]; }
+
+ dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; }
+ dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; }
+ dim_t OW() const { return dst_desc().dims[ndims() - 1]; }
+
+ dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; }
+ dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; }
+ dim_t KW() const { return desc_.kernel[ndims() - 3]; }
+
+ dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
+ dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
+ dim_t KSW() const { return desc_.strides[ndims() - 3]; }
+
+ dim_t padFront() const
+ { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
+ dim_t padBack() const
+ { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
+ dim_t padT() const
+ { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
+ dim_t padB() const
+ { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
+ dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
+ dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
+
+ int ndims() const { return src_desc().ndims; }
+ bool is_3d() const { return ndims() == 5; }
+
+ bool has_zero_dim_memory() const
+ { return memory_desc_wrapper(src_desc()).has_zero_dim(); }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+protected:
+ pooling_desc_t desc_;
+ const pooling_fwd_pd_t *hint_fwd_pd_;
+
+ memory_desc_t ws_md_;
+
+ void init_default_ws() {
+ ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md();
+ ws_md_.data_type = indices_data_type();
+ }
+
+ data_type_t indices_data_type() const {
+ /* use u8 indices when the kernel volume fits into 8 bits */
+ const int u8_max = nstl::numeric_limits<
+ typename prec_traits<data_type::u8>::type>::max();
+ return utils::array_product(desc()->kernel, ndims()) <= u8_max
+ ? data_type::u8 : data_type::s32;
+ }
+
+private:
+ const memory_desc_t &src_desc() const
+ { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; }
+ const memory_desc_t &dst_desc() const
+ { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; }
+};
+
+struct pooling_fwd_pd_t: public pooling_pd_t {
+ typedef pooling_fwd_pd_t base_class;
+ typedef pooling_fwd_pd_t hint_class;
+
+ pooling_fwd_pd_t(engine_t *engine,
+ const pooling_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const pooling_fwd_pd_t *hint_fwd_pd)
+ : pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , src_md_(desc_.src_desc)
+ , dst_md_(desc_.dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override
+ { return 1 + (workspace_md() != nullptr); }
+
+protected:
+ memory_desc_t src_md_;
+ memory_desc_t dst_md_;
+
+ virtual status_t set_default_params() {
+ if (dst_md()->format_kind != format_kind::any)
+ return status::success;
+
+ if (src_md()->format_kind != format_kind::blocked)
+ return status::unimplemented;
+
+ return memory_desc_init_by_blocking_desc(dst_md_,
+ src_md_.format_desc.blocking);
+ }
+};
+
+struct pooling_bwd_pd_t: public pooling_pd_t {
+ typedef pooling_bwd_pd_t base_class;
+ typedef pooling_fwd_pd_t hint_class;
+
+ pooling_bwd_pd_t(engine_t *engine,
+ const pooling_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const pooling_fwd_pd_t *hint_fwd_pd)
+ : pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_md_(desc_.diff_src_desc)
+ , diff_dst_md_(desc_.diff_dst_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_DIFF_DST)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::input;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_src_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_dst_md_ : nullptr; }
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+ virtual int n_inputs() const override
+ { return 1 + (workspace_md() != nullptr); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t diff_src_md_;
+ memory_desc_t diff_dst_md_;
+
+ virtual status_t set_default_params() {
+ if (diff_src_md()->format_kind != format_kind::any)
+ return status::success;
+
+ if (diff_dst_md()->format_kind != format_kind::blocked)
+ return status::unimplemented;
+
+ return memory_desc_init_by_blocking_desc(diff_src_md_,
+ diff_dst_md_.format_desc.blocking);
+ }
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
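
Regarding indices_data_type() above: for max pooling the workspace stores, per output element, the position of the maximum inside its window, so the index only needs to span the window volume. As a worked example (an illustrative reading, not part of the patch), a 15x15 window has 225 candidate positions and fits the u8 limit of 255, while a 16x16 window has 256 positions and falls back to s32 indices.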
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
new file mode 100644
index 0000000000..fdf6522f62
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.cpp
@@ -0,0 +1,103 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "primitive.hpp"
+#include "type_helpers.hpp"
+#include "stream.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::primitive_kind;
+
+namespace {
+// XXX: this is a huge hammer. This disables any and all MSan checks on
+// primitive outputs.
+//
+// A proper approach would be an implementation-specific unpoisoning.
+void unpoison_outputs(const exec_args_t &args) {
+ for(const auto &arg: args) {
+ if (arg.second.is_const) continue;
+ auto *mem = arg.second.mem;
+ void *p;
+ mem->get_data_handle(&p);
+ size_t s = memory_desc_wrapper(*mem->md()).size();
+ msan_unpoison(p, s);
+ }
+}
+}
+
+status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) {
+ if (primitive_desc) delete primitive_desc;
+ return success;
+}
+
+status_t mkldnn_primitive_create(primitive_t **primitive,
+ const primitive_desc_t *primitive_desc) {
+ if (utils::any_null(primitive, primitive_desc))
+ return invalid_arguments;
+ return primitive_desc->create_primitive(primitive);
+}
+
+status_t mkldnn_primitive_execute(const primitive_t *primitive,
+ stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) {
+ bool ok = true
+ && !utils::any_null(primitive, stream)
+ && primitive->engine() == stream->engine()
+ && IMPLICATION(nargs > 0, c_args != nullptr);
+ if (!ok) return invalid_arguments;
+
+ exec_args_t args;
+ status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args);
+ if (status != status::success) return status;
+
+ exec_ctx_t ctx(stream, std::move(args));
+
+ if (mkldnn_verbose()->level) {
+ double ms = get_msec();
+ status = primitive->execute(ctx);
+ ms = get_msec() - ms;
+ printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms);
+ fflush(0);
+ } else {
+ status = primitive->execute(ctx);
+ }
+
+ if (msan_enabled) unpoison_outputs(ctx.args());
+
+ return status;
+}
+
+status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive,
+ const primitive_desc_t **primitive_desc) {
+ if (utils::any_null(primitive, primitive_desc))
+ return invalid_arguments;
+ return safe_ptr_assign<const primitive_desc_t>(*primitive_desc,
+ primitive->pd());
+}
+
+status_t mkldnn_primitive_destroy(primitive_t *primitive) {
+ if (primitive != nullptr)
+ delete primitive;
+ return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
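
For orientation, a hedged sketch of how this execute entry point is reached from the C API (illustrative only, not part of the patch; the primitive, stream and memory handles are assumed to have been created beforehand and status checks are omitted):

    // pool_prim, stream, src_mem and dst_mem are assumed to exist already.
    mkldnn_exec_arg_t args[2];
    args[0].arg = MKLDNN_ARG_SRC; args[0].memory = src_mem;
    args[1].arg = MKLDNN_ARG_DST; args[1].memory = dst_mem;
    mkldnn_status_t st = mkldnn_primitive_execute(pool_prim, stream, 2, args);

cvt_primtive_args() then classifies each MKLDNN_ARG_* index as an input or output according to the primitive descriptor's arg_usage().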
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
new file mode 100644
index 0000000000..3b506d6d1f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive.hpp
@@ -0,0 +1,76 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_HPP
+#define PRIMITIVE_HPP
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+#include "primitive_exec_types.hpp"
+
+/** \brief A pure virtual primitive class
+ *
+ * Primitive contains links to its inputs & outputs, though it does not track
+ * their readiness on execution step.
+ *
+ * @remark @b Rationale.
+ * Dependencies are essential throughout the whole MKL-DNN library, so it
+ * makes sense to include them on the very low level. On the other hand,
+ * tracking them should be the task of a dedicated entity, such as a scheduler
+ * or a stream. A primitive itself should know nothing about the
+ * environment it is running in.
+ *
+ * @note
+ * To make the user experience better we should provide an API which allows
+ * achieving the best (or good enough) performance when creating primitives
+ * in natural order: i.e. from bottom to top for forward pass and from top to
+ * bottom for backward pass. Please consider restriction [1] in Level 0.
+ */
+struct mkldnn_primitive: public mkldnn::impl::c_compatible {
+ mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd)
+ : pd_(pd->clone()) {}
+ virtual ~mkldnn_primitive() { delete pd_; }
+
+ /** returns primitive's engine */
+ mkldnn::impl::engine_t *engine() const { return pd_->engine(); }
+ /** returns primitive's descriptor */
+ const mkldnn::impl::primitive_desc_t *pd() const { return pd_; }
+ /** returns primitive's kind */
+ mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); }
+
+ /** executes primitive with execution context @p ctx */
+ virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx)
+ const = 0;
+
+protected:
+ const mkldnn::impl::primitive_desc_t *pd_;
+
+private:
+ mkldnn_primitive() = delete;
+ mkldnn_primitive(const mkldnn_primitive &) = delete;
+ mkldnn_primitive(mkldnn_primitive &&) = delete;
+ mkldnn_primitive &operator=(const mkldnn_primitive &) = delete;
+ mkldnn_primitive &operator=(mkldnn_primitive &&) = delete;
+};
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp
new file mode 100644
index 0000000000..9fd638842c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp
@@ -0,0 +1,290 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_attr.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+
+namespace mkldnn {
+namespace impl {
+
+status_t scales_t::set(dim_t count, int mask, const float *scales) {
+ cleanup();
+
+ count_ = count;
+ mask_ = mask;
+
+ if (count_ == 1) {
+ scales_ = scales_buf_;
+ utils::array_set(scales_, scales[0], scales_buf_size);
+ } else {
+ scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
+ if (scales_ == nullptr)
+ return status::out_of_memory;
+
+ for (dim_t c = 0; c < count_; ++c)
+ scales_[c] = scales[c];
+ }
+
+ return status::success;
+}
+
+}
+}
+
+status_t post_ops_t::append_sum(float scale) {
+ if (len_ == capacity)
+ return out_of_memory;
+
+ entry_[len_].kind = primitive_kind::sum;
+ entry_[len_].sum.scale = scale;
+
+ len_++;
+
+ return success;
+}
+
+status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
+ float beta) {
+ using namespace mkldnn::impl::alg_kind;
+ bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
+ eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+ eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic);
+ if (!known_alg)
+ return invalid_arguments;
+
+ if (len_ == capacity)
+ return out_of_memory;
+
+ entry_[len_].kind = primitive_kind::eltwise;
+ entry_[len_].eltwise.scale = scale;
+ entry_[len_].eltwise.alg = alg;
+ entry_[len_].eltwise.alpha = alpha;
+ entry_[len_].eltwise.beta = beta;
+
+ len_++;
+
+ return success;
+}
+
+status_t primitive_attr_t::set_scratchpad_mode(
+ scratchpad_mode_t scratchpad_mode) {
+ using namespace mkldnn::impl::scratchpad_mode;
+
+ const bool ok = one_of(scratchpad_mode, library, user);
+ if (!ok)
+ return invalid_arguments;
+
+ scratchpad_mode_ = scratchpad_mode;
+ return success;
+}
+
+status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
+ this->post_ops_ = post_ops;
+ return success;
+}
+
+/* Public C API */
+
+status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
+ if (attr == nullptr)
+ return invalid_arguments;
+
+ return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
+ new mkldnn_primitive_attr);
+}
+
+status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
+ const primitive_attr_t *existing_attr) {
+ if (any_null(attr, existing_attr))
+ return invalid_arguments;
+
+ return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
+ existing_attr->clone());
+}
+
+status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
+ if (attr)
+ delete attr;
+
+ return success;
+}
+
+status_t mkldnn_primitive_attr_get_scratchpad_mode(
+ const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) {
+ if (any_null(attr, scratchpad_mode))
+ return invalid_arguments;
+
+ *scratchpad_mode = attr->scratchpad_mode_;
+
+ return success;
+}
+
+status_t mkldnn_primitive_attr_set_scratchpad_mode(
+ primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) {
+ if (any_null(attr))
+ return invalid_arguments;
+
+ return attr->set_scratchpad_mode(scratchpad_mode);
+}
+
+status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
+ dim_t *count, int *mask, const float **scales) {
+ if (any_null(attr, count, mask, scales))
+ return invalid_arguments;
+
+ *count = attr->output_scales_.count_;
+ *mask = attr->output_scales_.mask_;
+ *scales = attr->output_scales_.scales_;
+
+ return success;
+}
+
+status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
+ dim_t count, int mask, const float *scales) {
+ bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
+ if (!ok)
+ return invalid_arguments;
+
+ return attr->output_scales_.set(count, mask, scales);
+}
+
+status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
+ const post_ops_t **post_ops) {
+ if (any_null(attr, post_ops))
+ return invalid_arguments;
+
+ *post_ops = &attr->post_ops_;
+ return success;
+}
+
+status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
+ const post_ops_t *post_ops) {
+ if (any_null(attr, post_ops))
+ return invalid_arguments;
+
+ return attr->set_post_ops(*post_ops);
+}
+
+status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
+ if (post_ops == nullptr)
+ return invalid_arguments;
+
+ return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
+}
+
+status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
+ if (post_ops)
+ delete post_ops;
+
+ return success;
+}
+
+int mkldnn_post_ops_len(const post_ops_t *post_ops) {
+ if (post_ops)
+ return post_ops->len_;
+
+ return 0;
+}
+
+primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
+ int index) {
+ bool ok = post_ops && 0 <= index && index < post_ops->len_;
+ if (!ok)
+ return primitive_kind::undefined;
+
+ return post_ops->entry_[index].kind;
+}
+
+status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
+ if (post_ops == nullptr)
+ return invalid_arguments;
+
+ return post_ops->append_sum(scale);
+}
+
+namespace {
+bool simple_get_params_check(const post_ops_t *post_ops, int index,
+ primitive_kind_t kind) {
+ bool ok = true
+ && post_ops != nullptr
+ && 0 <= index
+ && index < post_ops->len_
+ && post_ops->entry_[index].kind == kind;
+ return ok;
+}
+}
+
+status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
+ float *scale) {
+ bool ok = true
+ && simple_get_params_check(post_ops, index, primitive_kind::sum)
+ && !any_null(scale);
+ if (!ok)
+ return invalid_arguments;
+
+ *scale = post_ops->entry_[index].sum.scale;
+ return success;
+}
+
+status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
+ alg_kind_t kind, float alpha, float beta) {
+ if (post_ops == nullptr)
+ return invalid_arguments;
+
+ return post_ops->append_eltwise(scale, kind, alpha, beta);
+}
+
+status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
+ int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
+ bool ok = true
+ && simple_get_params_check(post_ops, index, primitive_kind::eltwise)
+ && !any_null(scale, alpha, beta);
+ if (!ok)
+ return invalid_arguments;
+
+ const auto &e = post_ops->entry_[index].eltwise;
+ *scale = e.scale;
+ *alg = e.alg;
+ *alpha = e.alpha;
+ *beta = e.beta;
+
+ return success;
+}
+
+status_t mkldnn_primitive_attr_set_rnn_data_qparams(
+ primitive_attr_t *attr, const float scale, const float shift) {
+ if (attr == nullptr)
+ return invalid_arguments;
+
+ return attr->rnn_data_qparams_.set(scale, shift);
+}
+
+status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
+ primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
+ bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
+ if (!ok)
+ return invalid_arguments;
+
+ return attr->rnn_weights_qparams_.set(count, mask, scales);
+}
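
An illustrative sketch of the post-ops C API exercised by the functions above (not part of the patch; error handling omitted): a ReLU eltwise post-op followed by a sum accumulation, attached to an attribute object.

    mkldnn_post_ops_t ops;
    mkldnn_primitive_attr_t attr;

    mkldnn_post_ops_create(&ops);
    mkldnn_post_ops_append_eltwise(ops, 1.f, mkldnn_eltwise_relu, 0.f, 0.f);
    mkldnn_post_ops_append_sum(ops, 1.f);

    mkldnn_primitive_attr_create(&attr);
    mkldnn_primitive_attr_set_post_ops(attr, ops);

    // set_post_ops() stores a copy, so the post-ops handle can be destroyed.
    mkldnn_post_ops_destroy(ops);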
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
new file mode 100644
index 0000000000..e2130c7ab1
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp
@@ -0,0 +1,183 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_ATTR_HPP
+#define PRIMITIVE_ATTR_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct rnn_data_qparams_t : public c_compatible {
+ rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
+ bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
+
+ status_t set(float scale, float shift) {
+ scale_ = scale;
+ shift_ = shift;
+ return status::success;
+ }
+
+ float scale_;
+ float shift_;
+};
+
+struct scales_t: public c_compatible {
+ scales_t(): count_(1), mask_(0), scales_(scales_buf_)
+ { set(1.); }
+
+ scales_t(const scales_t &rhs): scales_t()
+ { set(rhs.count_, rhs.mask_, rhs.scales_); }
+
+ ~scales_t() { cleanup(); }
+
+ scales_t &operator=(const scales_t &rhs) {
+ if (&rhs == this)
+ return *this;
+ status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
+ assert(status == status::success);
+ (void)status;
+ return *this;
+ }
+
+ bool has_default_values() const {
+ for (dim_t c = 0; c < count_; ++c) {
+ if(scales_[c] != 1.) return false;
+ }
+ return true;
+ }
+
+ status_t set(dim_t count, int mask, const float *scales);
+ status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
+
+ dim_t count_;
+ int mask_;
+ float *scales_;
+
+private:
+ enum { scales_buf_size = 16 };
+ float scales_buf_[scales_buf_size];
+
+ void cleanup() {
+ if (scales_ != scales_buf_ && scales_ != nullptr)
+ impl::free(scales_);
+
+ count_ = 1;
+ mask_ = 0;
+ scales_ = scales_buf_;
+ }
+};
+
+}
+}
+
+struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
+ struct entry_t {
+ struct eltwise_t {
+ mkldnn::impl::alg_kind_t alg;
+ float scale, alpha, beta;
+ };
+
+ mkldnn::impl::primitive_kind_t kind;
+ union {
+ struct { float scale; } sum;
+ eltwise_t eltwise;
+ };
+
+ bool is_eltwise(bool require_scale_one = true) const {
+ using namespace mkldnn::impl;
+ return kind == primitive_kind::eltwise
+ && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
+ }
+
+ bool is_relu(bool require_scale_one = true,
+ bool require_nslope_zero = true) const {
+ using namespace mkldnn::impl;
+ return is_eltwise(require_scale_one)
+ && eltwise.alg == alg_kind::eltwise_relu
+ && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
+ }
+
+ bool is_sum(bool require_scale_one = true) const {
+ using namespace mkldnn::impl;
+ return kind == primitive_kind::sum
+ && IMPLICATION(require_scale_one, sum.scale == 1.f);
+ }
+ };
+
+ mkldnn_post_ops(): len_(0) {}
+
+ mkldnn::impl::status_t append_sum(float scale);
+ mkldnn::impl::status_t append_eltwise(float scale,
+ mkldnn::impl::alg_kind_t alg, float alpha, float beta);
+
+ int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
+ int stop = -1) const {
+ if (stop == -1) stop = len_;
+ stop = mkldnn::impl::nstl::min(stop, len_);
+ for (int idx = start; idx < stop; ++idx)
+ if (entry_[idx].kind == kind) return idx;
+ return -1;
+ }
+
+ bool has_default_values() const { return len_ == 0; }
+
+ bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
+ { return find(kind, index, index + 1) == index; }
+
+ enum { capacity = 4 };
+
+ int len_;
+ entry_t entry_[capacity];
+};
+
+struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
+ mkldnn_primitive_attr()
+ : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library)
+ {}
+
+ mkldnn_primitive_attr *clone() const
+ { return new mkldnn_primitive_attr(*this); }
+
+ /** Returns true if the attributes have default values.
+ *
+ * @note The scratchpad_mode_ is not taken into account */
+ bool has_default_values() const {
+ return true
+ && output_scales_.has_default_values()
+ && post_ops_.has_default_values()
+ && rnn_data_qparams_.has_default_values()
+ && rnn_weights_qparams_.has_default_values();
+ }
+
+ mkldnn::impl::status_t set_scratchpad_mode(
+ mkldnn::impl::scratchpad_mode_t scratchpad_mode);
+ mkldnn::impl::status_t set_post_ops(
+ const mkldnn::impl::post_ops_t &post_ops);
+
+ mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
+ mkldnn::impl::scales_t output_scales_;
+ mkldnn::impl::post_ops_t post_ops_;
+ mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
+ mkldnn::impl::scales_t rnn_weights_qparams_;
+};
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
new file mode 100644
index 0000000000..723c41e05a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp
@@ -0,0 +1,78 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
+ auto safe_ret_md = [&](const memory_desc_t *_) {
+ if (_ == nullptr) return not_required;
+ *(const memory_desc_t **)result = _;
+ return success;
+ };
+
+ switch (what) {
+ case query::engine: *(engine_t**)result = engine(); break;
+ case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
+
+ case query::scratchpad_engine:
+ *(engine_t**)result = scratchpad_engine(); break;
+
+ case query::memory_consumption_s64:
+ *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break;
+
+ case query::op_d:
+ if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
+ *(const_c_op_desc_t *)result
+ = static_cast<const_c_op_desc_t>(op_desc()); break;
+
+ case query::src_md: return safe_ret_md(src_md(idx));
+ case query::diff_src_md: return safe_ret_md(diff_src_md(idx));
+ case query::dst_md: return safe_ret_md(dst_md(idx));
+ case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx));
+ case query::weights_md: return safe_ret_md(weights_md(idx));
+ case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx));
+ case query::workspace_md:
+ if (idx != 0) return status::invalid_arguments;
+ return safe_ret_md(workspace_md(idx));
+ case query::scratchpad_md:
+ if (idx != 0) return status::invalid_arguments;
+ return safe_ret_md(scratchpad_md(idx));
+
+ case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
+ case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
+
+ case query::impl_info_str: *(const char **)result = name(); break;
+
+ default: return unimplemented;
+ }
+ return success;
+}
+
+status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
+ const primitive_attr_t **attr) {
+ if (utils::any_null(primitive_desc, attr))
+ return invalid_arguments;
+
+ *attr = primitive_desc->attr();
+ return success;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp
new file mode 100644
index 0000000000..536dcfa1d0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp
@@ -0,0 +1,174 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_DESC_HPP
+#define PRIMITIVE_DESC_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "primitive_attr.hpp"
+#include "verbose.hpp"
+
+struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
+ using md_t = mkldnn::impl::memory_desc_t;
+
+ mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_attr_t *attr,
+ mkldnn::impl::primitive_kind_t kind)
+ : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
+
+ mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
+ mkldnn::impl::primitive_kind_t kind)
+ : engine_(engine), kind_(kind) { info_[0] = '\0'; }
+
+ virtual mkldnn_primitive_desc *clone() const = 0;
+ virtual ~mkldnn_primitive_desc() {}
+
+ const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
+ mkldnn::impl::engine_t *engine() const { return engine_; }
+ mkldnn::impl::primitive_kind_t kind() const { return kind_; }
+
+ virtual void init_info() {}
+ const char *info() const { return info_; }
+
+ mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
+ { return scratchpad_registry_; }
+ const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
+ { return scratchpad_registry_; }
+ virtual mkldnn::impl::engine_t *scratchpad_engine() const
+ { return engine_; }
+
+ virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; }
+
+ enum class arg_usage_t { unused, input, output };
+ virtual arg_usage_t arg_usage(
+ mkldnn::impl::primitive_arg_index_t arg) const {
+ using mkldnn::impl::types::is_zero_md;
+ if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
+ return arg_usage_t::output;
+ return arg_usage_t::unused;
+ }
+
+# define DECLARE_MD_STUB(stub) \
+ virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \
+ { return nullptr; }
+
+ DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md);
+ DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md);
+ DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md);
+ DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md);
+ DECLARE_MD_STUB(workspace_md);
+# undef DECLARE_MD_STUB
+
+ const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const {
+ return idx == 0 ? &scratchpad_md_ : nullptr;
+ }
+
+ virtual void init_scratchpad_md() {
+ auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user);
+ mkldnn::impl::dims_t dims = { size };
+ mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims,
+ mkldnn::impl::data_type::u8, mkldnn_x);
+ }
+
+ /** returns the scratchpad size for the given scratchpad mode. */
+ mkldnn::impl::dim_t scratchpad_size(
+ mkldnn::impl::scratchpad_mode_t mode) const {
+ if (mode != attr_.scratchpad_mode_) return 0;
+ return scratchpad_registry().size();
+ }
+
+ virtual int n_inputs() const { return 0; }
+ virtual int n_outputs() const { return 0; }
+
+ virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
+ void *result) const;
+
+ virtual mkldnn::impl::status_t create_primitive(
+ mkldnn::impl::primitive_t **primitive) const = 0;
+
+ virtual const char *name() const { return "mkldnn_primitive_desc"; }
+
+ /* static magic */
+
+ template<typename pd_t>
+ static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
+ const mkldnn::impl::op_desc_t *adesc,
+ const mkldnn::impl::primitive_attr_t *attr,
+ mkldnn::impl::engine_t *engine,
+ const mkldnn::impl::primitive_desc_t *hint_fwd) {
+ using namespace mkldnn::impl;
+ using namespace mkldnn::impl::status;
+ using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
+ if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
+ assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
+ auto hint =
+ reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
+ auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
+ if (_pd == nullptr) return out_of_memory;
+ if (_pd->init() != success) { delete _pd; return unimplemented; }
+ _pd->init_info();
+ _pd->init_scratchpad_md();
+ *pd = _pd;
+ return success;
+ }
+
+protected:
+ mkldnn::impl::engine_t *engine_;
+ mkldnn::impl::primitive_attr_t attr_;
+ mkldnn::impl::primitive_kind_t kind_;
+
+ mkldnn::impl::memory_desc_t scratchpad_md_;
+
+ char info_[MKLDNN_VERBOSE_BUF_LEN];
+
+ mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
+
+protected:
+    /** compares the workspace md between fwd_pd and this (makes sense to use
+     * for bwd_pd). Expectation: this has already set its workspace, and that
+     * workspace must exactly match the one from fwd_pd */
+ bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const {
+ using namespace mkldnn::impl;
+ if (!workspace_md()) return true; // the impl lives fine w/o workspace
+ return fwd_pd && fwd_pd->workspace_md()
+ && *fwd_pd->workspace_md() == *workspace_md();
+ }
+};
+
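+/* Convenience macro that provides clone(), create_primitive(), and name()
+ * for concrete primitive descriptors; create_primitive() also reports the
+ * creation time through mkldnn_verbose at verbosity level >= 2. */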
+#define DECLARE_COMMON_PD_t(impl_name, ...) \
+ virtual pd_t *clone() const override { return new pd_t(*this); } \
+ virtual status_t create_primitive(primitive_t **p) const override { \
+ double ms = get_msec(); \
+ auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
+ ms = get_msec() - ms; \
+ if (mkldnn_verbose()->level >= 2) { \
+ printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
+ fflush(0); \
+ } \
+ return ret; \
+ } \
+ virtual const char *name() const override { return impl_name; }
+#define DECLARE_COMMON_PD_T(impl_name, ...) \
+ DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp
new file mode 100644
index 0000000000..43e5a31ef3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp
@@ -0,0 +1,90 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "memory.hpp"
+#include "primitive.hpp"
+#include "primitive_exec_types.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
+ const mkldnn_exec_arg_t *c_args, exec_args_t &args) {
+ using namespace status;
+
+ if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments;
+
+ int n_inputs = 0;
+ int n_outputs = 0;
+
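+    // Walk the C argument list, classify each argument through the primitive
+    // descriptor, and record it in the map; duplicate indices are rejected.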
+ for (int i = 0; i < nargs; ++i) {
+ primitive_arg_index_t arg = c_args[i].arg;
+ auto *mem = c_args[i].memory;
+
+ switch (pd->arg_usage(arg)) {
+ case primitive_desc_t::arg_usage_t::input:
+ if (args.count(arg) != 0) return invalid_arguments;
+ args[arg] = {mem, true};
+ n_inputs++;
+ break;
+ case primitive_desc_t::arg_usage_t::output:
+ if (args.count(arg) != 0) return invalid_arguments;
+ args[arg] = {mem, false};
+ n_outputs++;
+ break;
+ case primitive_desc_t::arg_usage_t::unused:
+ break;
+ }
+ }
+
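+    // When the primitive descriptor requires a scratchpad it is passed as an
+    // extra output argument, hence the adjusted output count below.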
+ bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md());
+
+ if (n_inputs != pd->n_inputs()) return invalid_arguments;
+ if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0))
+ return invalid_arguments;
+
+ return success;
+}
+
+const void *exec_ctx_t::input(primitive_arg_index_t arg) const {
+ if (args_.count(arg) != 1) return nullptr;
+ const auto ma = args_.at(arg);
+ assert(ma.is_const);
+ void *ptr;
+ status_t status = ma.mem->get_data_handle(&ptr);
+ assert(status == status::success); MAYBE_UNUSED(status);
+ return ptr;
+}
+
+void *exec_ctx_t::output(primitive_arg_index_t arg) const {
+ if (args_.count(arg) != 1) return nullptr;
+ const auto ma = args_.at(arg);
+ assert(!ma.is_const);
+ void *ptr;
+ status_t status = ma.mem->get_data_handle(&ptr);
+ assert(status == status::success); MAYBE_UNUSED(status);
+ return ptr;
+}
+
+const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const {
+ assert(args_.count(arg) == 1);
+ const auto ma = args_.at(arg);
+ assert(!ma.is_const);
+ return ma.mem;
+}
+
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
new file mode 100644
index 0000000000..0645891da7
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp
@@ -0,0 +1,68 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_EXEC_TYPES_HPP
+#define PRIMITIVE_EXEC_TYPES_HPP
+
+#include <unordered_map>
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "memory.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct memory_arg_t {
+ memory_t *mem;
+ bool is_const;
+};
+
+using exec_args_t = std::unordered_map<primitive_arg_index_t, memory_arg_t>;
+
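+/** Converts the C array of execution arguments into an exec_args_t map,
+ *  validating argument kinds and counts against the primitive descriptor. */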
+status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
+ const mkldnn_exec_arg_t *c_args, exec_args_t &args);
+
+/** Primitive execution context (helps pass the stream, memories, and events). */
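+/* A sketch of typical usage (the memory objects are illustrative):
+ *     exec_args_t args;
+ *     args[MKLDNN_ARG_SRC] = {src_mem, true};   // input
+ *     args[MKLDNN_ARG_DST] = {dst_mem, false};  // output
+ *     exec_ctx_t ctx(stream, std::move(args));
+ */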
+struct exec_ctx_t {
+ exec_ctx_t(const exec_ctx_t &) = default;
+ exec_ctx_t(exec_ctx_t &&) = default;
+
+ exec_ctx_t(stream_t *stream): stream_(stream) {}
+ exec_ctx_t(stream_t *stream, exec_args_t &&args)
+ : stream_(stream)
+ , args_(std::move(args)) {}
+
+ stream_t *stream() const { return stream_; }
+ const exec_args_t &args() const { return args_; }
+
+    /* tentative solution... TODO: replace with functions that return memory_t */
+ const void *input(primitive_arg_index_t arg) const;
+ void *output(primitive_arg_index_t arg) const;
+
+ const memory_t *memory(primitive_arg_index_t arg) const;
+
+private:
+ stream_t *stream_;
+ exec_args_t args_;
+};
+
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp
new file mode 100644
index 0000000000..5a1cd7d379
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp
@@ -0,0 +1,89 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+#include "primitive_iterator.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_primitive_desc_iterator_create(
+ primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc,
+ const primitive_attr_t *attr, engine_t *engine,
+ const primitive_desc_t *hint_fwd_pd) {
+ const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
+
+ auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd);
+ if (it == nullptr) return out_of_memory;
+
+ ++(*it);
+ if (*it == it->end()) {
+ delete it;
+ return unimplemented;
+ }
+
+ *iterator = it;
+ return success;
+}
+
+status_t mkldnn_primitive_desc_iterator_next(
+ primitive_desc_iterator_t *iterator) {
+ if (iterator == nullptr) return invalid_arguments;
+ ++(*iterator);
+ return *iterator == iterator->end() ? iterator_ends : success;
+}
+
+primitive_desc_t *mkldnn_primitive_desc_iterator_fetch(
+ const primitive_desc_iterator_t *iterator) {
+ if (iterator == nullptr) return nullptr;
+ return *(*iterator);
+}
+
+status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc,
+ const primitive_desc_t *existing_primitive_desc) {
+ if (utils::any_null(primitive_desc, existing_primitive_desc))
+ return invalid_arguments;
+ return safe_ptr_assign<primitive_desc_t>(*primitive_desc,
+ existing_primitive_desc->clone());
+}
+
+status_t mkldnn_primitive_desc_iterator_destroy(
+ primitive_desc_iterator_t *iterator) {
+ if (iterator != nullptr)
+ delete iterator;
+ return success;
+}
+
+status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc,
+ const_c_op_desc_t c_op_desc, const primitive_attr_t *attr,
+ engine_t *engine, const primitive_desc_t *hint_fwd_pd) {
+ const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
+
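+    // Creating a primitive descriptor amounts to taking the first entry
+    // produced by an iterator over the engine's implementation list.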
+ mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd);
+ ++it;
+ if (it == it.end()) return unimplemented;
+
+ return safe_ptr_assign<primitive_desc_t>(*primitive_desc, *it);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp
new file mode 100644
index 0000000000..4e88ab3aa5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp
@@ -0,0 +1,79 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+#ifndef PRIMITIVE_ITERATOR_HPP
+#define PRIMITIVE_ITERATOR_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+
+struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible {
+ using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
+
+ mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc,
+ const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd)
+ : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc)
+ , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd)
+ , impl_list_(engine_->get_implementation_list()), last_idx_(0)
+ {
+ while (impl_list_[last_idx_] != nullptr) ++last_idx_;
+ }
+ ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; }
+
+ bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
+ { return idx_ == rhs.idx_ && engine_ == rhs.engine_; }
+ bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
+ { return !operator==(rhs); }
+
+ mkldnn::impl::primitive_desc_iterator_t end() const
+ { return mkldnn_primitive_desc_iterator(engine_, last_idx_); }
+
+ mkldnn::impl::primitive_desc_iterator_t &operator++() {
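+        // Advance to the next implementation in the engine's list that
+        // accepts the operation descriptor; stop at last_idx_ (== end())
+        // when no remaining entry succeeds.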
+ if (pd_) { delete pd_; pd_ = nullptr; }
+ while (++idx_ != last_idx_) {
+ auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_,
+ hint_fwd_pd_);
+ if (s == mkldnn::impl::status::success) break;
+ }
+ return *this;
+ }
+
+ mkldnn::impl::primitive_desc_t *operator*() const {
+ if (*this == end() || pd_ == nullptr) return nullptr;
+ return pd_->clone();
+ }
+
+protected:
+ int idx_;
+ mkldnn::impl::engine_t *engine_;
+ mkldnn::impl::primitive_desc_t *pd_;
+ const mkldnn::impl::op_desc_t *op_desc_;
+ const mkldnn::impl::primitive_attr_t attr_;
+ const mkldnn::impl::primitive_desc_t *hint_fwd_pd_;
+ const pd_create_f *impl_list_;
+ int last_idx_;
+
+private:
+ mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx)
+ : idx_(last_idx), engine_(engine), pd_(nullptr)
+ , op_desc_(nullptr), hint_fwd_pd_(nullptr)
+ , impl_list_(nullptr), last_idx_(last_idx) {}
+};
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/query.cpp b/thirdparty/oidn/mkl-dnn/src/common/query.cpp
new file mode 100644
index 0000000000..835cd73581
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/query.cpp
@@ -0,0 +1,59 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc,
+ query_t what, int index, void *result) {
+ if (any_null(primitive_desc, result))
+ return invalid_arguments;
+
+ return primitive_desc->query(what, index, result);
+}
+
+const memory_desc_t *mkldnn_primitive_desc_query_md(
+ const primitive_desc_t *primitive_desc, query_t what, int index) {
+ const memory_desc_t *res_md = nullptr;
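+    // Only queries from the memory descriptor family (derived from
+    // query::some_md) return a descriptor; some_md itself is just the
+    // range marker and is rejected.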
+ bool args_ok = true
+ && primitive_desc != nullptr
+ && (what & query::some_md) == query::some_md
+ && what != query::some_md
+ && mkldnn_primitive_desc_query(primitive_desc,
+ what, index, &res_md) == success;
+ return args_ok ? res_md : nullptr;
+}
+
+int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc,
+ query_t what, int index) {
+ int res_s32;
+ bool args_ok = primitive_desc != nullptr
+ && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32)
+ && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32)
+ == success;
+ return args_ok ? res_s32 : 0;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
new file mode 100644
index 0000000000..d11f1a0361
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/reorder.cpp
@@ -0,0 +1,68 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "reorder_pd.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_reorder_primitive_desc_create(
+ primitive_desc_t **reorder_pd,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md,
+ const primitive_attr_t *attr) {
+ if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md))
+ return invalid_arguments;
+
+ auto s_ek = src_engine->kind();
+ auto d_ek = dst_engine->kind();
+ if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek)))
+ return invalid_arguments;
+
+ auto r_pd = reinterpret_cast<reorder_pd_t **>(reorder_pd);
+ auto s_mdw = memory_desc_wrapper(*src_md);
+ auto d_mdw = memory_desc_wrapper(*dst_md);
+
+ if (!s_mdw.consistent_with(d_mdw))
+ return invalid_arguments;
+
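+    // When the engines differ, the non-CPU engine provides the reorder
+    // implementation (the check above guarantees the other side is CPU);
+    // otherwise both arguments refer to the same engine.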
+ auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine;
+
+ const primitive_attr_t dummy_attr;
+ if (attr == NULL)
+ attr = &dummy_attr;
+
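+    // Scan the engine's reorder implementation list and take the first
+    // implementation that accepts the given descriptors and attributes.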
+ for (auto r = e->get_reorder_implementation_list(); *r; ++r) {
+ if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md)
+ == success) {
+ (*r_pd)->init_info();
+ (*r_pd)->init_scratchpad_md();
+ return success;
+ }
+ }
+ return unimplemented;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp
new file mode 100644
index 0000000000..963cb0f58a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp
@@ -0,0 +1,85 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef REORDER_PD_HPP
+#define REORDER_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "primitive_attr.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct reorder_pd_t: public primitive_desc_t {
+ reorder_pd_t(engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md)
+ : primitive_desc_t(engine, attr, primitive_kind::reorder)
+ , src_engine_(src_engine)
+ , dst_engine_(dst_engine)
+ , scratchpad_engine_(nullptr)
+ , src_md_(*src_md)
+ , dst_md_(*dst_md)
+ {}
+
+ virtual const op_desc_t *op_desc() const override { return nullptr; }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_FROM)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_TO)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &src_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override { return 1; }
+
+ float alpha() const { return attr()->output_scales_.scales_[0]; }
+ float beta() const {
+ const int sum_idx = attr()->post_ops_.find(primitive_kind::sum);
+ return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale;
+ }
+ virtual mkldnn::impl::engine_t *scratchpad_engine() const override
+ { return scratchpad_engine_; }
+
+protected:
+ engine_t *src_engine_;
+ engine_t *dst_engine_;
+ engine_t *scratchpad_engine_;
+
+ memory_desc_t src_md_;
+ memory_desc_t dst_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
new file mode 100644
index 0000000000..36967431a6
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/rnn.cpp
@@ -0,0 +1,400 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "cpu/gemm/os_blas.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::types;
+using namespace mkldnn::impl::utils;
+
+namespace {
+memory_desc_t copy_maybe_null(const memory_desc_t *md) {
+ return md ? *md : zero_md();
+}
+
+rnn_desc_t zero_rnn_desc() {
+ auto rd = rnn_desc_t();
+ rd.src_layer_desc = zero_md();
+ rd.src_iter_desc = zero_md();
+ rd.weights_layer_desc = zero_md();
+ rd.weights_iter_desc = zero_md();
+ rd.bias_desc = zero_md();
+ rd.dst_layer_desc = zero_md();
+ rd.dst_iter_desc = zero_md();
+ rd.diff_src_layer_desc = zero_md();
+ rd.diff_src_iter_desc = zero_md();
+ rd.diff_weights_layer_desc = zero_md();
+ rd.diff_weights_iter_desc = zero_md();
+ rd.diff_bias_desc = zero_md();
+ rd.diff_dst_layer_desc = zero_md();
+ rd.diff_dst_iter_desc = zero_md();
+ return rd;
+}
+}
+
+/* Public C API */
+
+status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc,
+ mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f,
+ unsigned int flags, float alpha, float clipping) {
+ using namespace mkldnn::impl::alg_kind;
+
+ bool args_ok = true
+ && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru,
+ gru_linear_before_reset)
+ && IMPLICATION(cell_kind == vanilla_rnn,
+ one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic));
+ if (!args_ok)
+ return invalid_arguments;
+
+ auto rcd = mkldnn_rnn_cell_desc_t();
+
+ rcd.cell_kind = cell_kind;
+ rcd.activation_kind = act_f;
+ rcd.flags = flags;
+ rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0;
+ rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0;
+
+ *rnn_cell_desc = rcd;
+
+ return success;
+}
+
+int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) {
+ switch (rnn_cell_desc->cell_kind) {
+ case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
+ case mkldnn::impl::alg_kind::vanilla_gru: return 3;
+ case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3;
+ case mkldnn::impl::alg_kind::vanilla_lstm: return 4;
+ default: assert(!"unknown cell kind"); return 0;
+ }
+ return 0;
+}
+
+int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) {
+ switch (rnn_cell_desc->cell_kind) {
+ case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
+ case mkldnn::impl::alg_kind::vanilla_gru: return 1;
+ case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1;
+ case mkldnn::impl::alg_kind::vanilla_lstm: return 2;
+ default: assert(!"unknown cell kind"); return 0;
+ }
+ return 0;
+}
+
+status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc,
+ prop_kind_t prop_kind, const memory_desc_t *src_layer_desc,
+ const memory_desc_t *src_iter_desc,
+ const memory_desc_t *weights_layer_desc,
+ const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_layer_desc,
+ const memory_desc_t *dst_iter_desc) {
+ using namespace data_type;
+ data_type_t src_layer_dt = src_layer_desc->data_type;
+ data_type_t dst_layer_dt = dst_layer_desc->data_type;
+ data_type_t weights_iter_dt = weights_iter_desc->data_type;
+ data_type_t weights_layer_dt = weights_layer_desc->data_type;
+
+ bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt,
+ weights_layer_dt)
+ && IMPLICATION(!is_zero_md(src_iter_desc),
+ src_iter_desc->data_type == f32)
+ && IMPLICATION(!is_zero_md(dst_iter_desc),
+ dst_iter_desc->data_type == f32)
+ && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
+
+#if USE_MKL_PACKED_GEMM
+ bool is_u8u8u8 = src_layer_dt == u8
+ && IMPLICATION(!is_zero_md(src_iter_desc),
+ src_iter_desc->data_type == u8)
+ && IMPLICATION(!is_zero_md(dst_iter_desc),
+ dst_iter_desc->data_type == u8)
+ && one_of(dst_layer_dt, u8, f32)
+ && everyone_is(s8, weights_iter_dt, weights_layer_dt)
+ && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
+
+ bool is_f32u8f32 = src_layer_dt == u8
+ && IMPLICATION(!is_zero_md(src_iter_desc),
+ src_iter_desc->data_type == f32)
+ && IMPLICATION(!is_zero_md(dst_iter_desc),
+ dst_iter_desc->data_type == f32)
+ && one_of(dst_layer_dt, u8, f32)
+ && everyone_is(s8, weights_iter_dt, weights_layer_dt)
+ && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
+
+ bool is_inference = prop_kind == prop_kind::forward_inference;
+ bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm;
+
+ return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference))
+ ? success
+ : unimplemented;
+#else
+ return is_f32 ? success : unimplemented;
+#endif
+}
+
+status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc,
+ rnn_direction_t direction, int L, int D, int T, int N, int S, int G,
+ int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc,
+ const memory_desc_t *src_iter_desc,
+ const memory_desc_t *weights_layer_desc,
+ const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_layer_desc,
+ const memory_desc_t *dst_iter_desc) {
+ bool args_ok;
+
+ // * algorithm specific
+ args_ok = true
+ && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru,
+ DIC == SIC);
+ if (!args_ok) return invalid_arguments;
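+    // GRU with linear-before-reset carries one additional bias term, which
+    // is accounted for in the bias dimension check among the gate checks.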
+ int extra_bias =
+ rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset;
+
+ // * on num layers
+ args_ok = true
+ && L == weights_layer_desc->dims[0]
+ && L == weights_iter_desc->dims[0]
+ && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0])
+ && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0])
+ && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on num directions
+ args_ok = true
+ && D == weights_layer_desc->dims[1]
+ && D == weights_iter_desc->dims[1]
+ && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1])
+ && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1])
+ && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on num iterations
+ args_ok = true
+ && T == src_layer_desc->dims[0]
+ && T == dst_layer_desc->dims[0];
+ if (!args_ok) return invalid_arguments;
+
+ // * on mb
+ args_ok = true
+ && N == src_layer_desc->dims[1]
+ && N == dst_layer_desc->dims[1]
+ && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3])
+ && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on num gates
+ args_ok = true
+ && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc)
+ && G == weights_layer_desc->dims[3]
+ && G == weights_iter_desc->dims[3]
+ && IMPLICATION(!is_zero_md(bias_desc),
+ G + extra_bias == bias_desc->dims[2]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on num states
+ args_ok = true
+ && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc)
+ && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2])
+ && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on slc
+ args_ok = true
+ && SLC == weights_layer_desc->dims[2]
+ && SLC == src_layer_desc->dims[2];
+ if (!args_ok) return invalid_arguments;
+
+ // * on sic
+ args_ok = true
+ && SIC == weights_iter_desc->dims[2]
+ && IMPLICATION(!is_zero_md(src_iter_desc),
+ SIC == src_iter_desc->dims[4]);
+ if (!args_ok) return invalid_arguments;
+
+ // * on dlc
+ int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1;
+ args_ok = true
+ && DLC == dlc_multiplier * DIC
+ && DLC == dst_layer_desc->dims[2];
+ if (!args_ok) return invalid_arguments;
+
+ // * on dic
+ args_ok = true
+ && DIC == weights_layer_desc->dims[4]
+ && DIC == weights_iter_desc->dims[4]
+ && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3])
+ && IMPLICATION(!is_zero_md(dst_iter_desc),
+ DIC == dst_iter_desc->dims[4]);
+ if (!args_ok) return invalid_arguments;
+
+ // * unrolling/fusion conditions
+ args_ok = true
+ && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC)
+ && IMPLICATION(T > 1, SIC == DIC);
+ if (!args_ok) return invalid_arguments;
+
+ return success;
+}
+
+status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
+ prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
+ const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
+ const memory_desc_t *src_iter_desc,
+ const memory_desc_t *weights_layer_desc,
+ const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_layer_desc,
+ const memory_desc_t *dst_iter_desc) {
+ bool args_ok = true && rnn_cell_desc != nullptr
+ && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
+ dst_layer_desc);
+ if (!args_ok) return invalid_arguments;
+
+    // check dimension consistency
+ int L = weights_layer_desc->dims[0];
+ int T = src_layer_desc->dims[0];
+ int N = src_layer_desc->dims[1];
+    const int D = one_of(direction, mkldnn_unidirectional_left2right,
+            mkldnn_unidirectional_right2left) ? 1 : 2;
+ int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
+ int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
+ int SLC = src_layer_desc->dims[2];
+ int SIC = weights_iter_desc->dims[2];
+ int DLC = dst_layer_desc->dims[2];
+ int DIC = weights_layer_desc->dims[4];
+
+ CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
+ G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
+ weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
+ dst_iter_desc));
+
+ CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind,
+ src_layer_desc, src_iter_desc, weights_layer_desc,
+ weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc));
+
+ // Create the descriptor
+ mkldnn_rnn_desc_t rd = zero_rnn_desc();
+
+ rd.primitive_kind = primitive_kind::rnn;
+ rd.prop_kind = prop_kind;
+ rd.cell_desc = *rnn_cell_desc;
+ rd.direction = direction;
+ rd.src_layer_desc = copy_maybe_null(src_layer_desc);
+ rd.src_iter_desc = copy_maybe_null(src_iter_desc);
+ rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
+ rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
+ rd.bias_desc = copy_maybe_null(bias_desc);
+ rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
+ rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
+
+ *rnn_desc = rd;
+
+ return success;
+}
+
+status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
+ prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
+ const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
+ const memory_desc_t *src_iter_desc,
+ const memory_desc_t *weights_layer_desc,
+ const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+ const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
+ const memory_desc_t *diff_src_layer_desc,
+ const memory_desc_t *diff_src_iter_desc,
+ const memory_desc_t *diff_weights_layer_desc,
+ const memory_desc_t *diff_weights_iter_desc,
+ const memory_desc_t *diff_bias_desc,
+ const memory_desc_t *diff_dst_layer_desc,
+ const memory_desc_t *diff_dst_iter_desc) {
+ bool args_ok = true
+ && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
+ dst_layer_desc, diff_src_layer_desc,
+ diff_weights_layer_desc, diff_weights_iter_desc,
+ diff_dst_layer_desc);
+ if (!args_ok)
+ return invalid_arguments;
+
+ auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) {
+ return is_zero_md(a_md) == is_zero_md(b_md);
+ };
+
+ args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
+ && xnor_md(dst_iter_desc, diff_dst_iter_desc)
+ && xnor_md(src_iter_desc, diff_src_iter_desc);
+ if (!args_ok)
+ return invalid_arguments;
+
+    // check dimension consistency
+ int L = weights_layer_desc->dims[0];
+ int T = src_layer_desc->dims[0];
+ int N = src_layer_desc->dims[1];
+    const int D = one_of(direction, mkldnn_unidirectional_left2right,
+            mkldnn_unidirectional_right2left) ? 1 : 2;
+ int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
+ int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
+ int SLC = src_layer_desc->dims[2];
+ int SIC = weights_iter_desc->dims[2];
+ int DLC = dst_layer_desc->dims[2];
+ int DIC = weights_layer_desc->dims[4];
+
+ status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
+ G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
+ weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
+ dst_iter_desc);
+ if (st != success) return st;
+
+ st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
+ G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc,
+ diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc,
+ diff_dst_layer_desc, diff_dst_iter_desc);
+ if (st != success) return st;
+
+ mkldnn_rnn_desc_t rd = zero_rnn_desc();
+
+ rd.primitive_kind = primitive_kind::rnn;
+ rd.prop_kind = prop_kind;
+ rd.cell_desc = *rnn_cell_desc;
+ rd.direction = direction;
+
+ rd.src_layer_desc = copy_maybe_null(src_layer_desc);
+ rd.src_iter_desc = copy_maybe_null(src_iter_desc);
+ rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
+ rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
+ rd.bias_desc = copy_maybe_null(bias_desc);
+ rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
+ rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
+ rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc);
+ rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc);
+ rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc);
+ rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc);
+ rd.diff_bias_desc = copy_maybe_null(diff_bias_desc);
+ rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc);
+ rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc);
+
+ *rnn_desc = rd;
+
+ return success;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
new file mode 100644
index 0000000000..1ee2ba1114
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp
@@ -0,0 +1,280 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef RNN_PD_HPP
+#define RNN_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct rnn_fwd_pd_t;
+
+struct rnn_pd_t : public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::rnn;
+
+ rnn_pd_t(engine_t *engine,
+ const rnn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const rnn_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , src_layer_md_(desc_.src_layer_desc)
+ , src_iter_md_(desc_.src_iter_desc)
+ , weights_layer_md_(desc_.weights_layer_desc)
+ , weights_iter_md_(desc_.weights_iter_desc)
+ , bias_md_(desc_.bias_desc)
+ , dst_layer_md_(desc_.dst_layer_desc)
+ , dst_iter_md_(desc_.dst_iter_desc)
+ , ws_md_()
+ {}
+
+ const rnn_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::rnn_d: *(const rnn_desc_t **)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override {
+ if (index == 0) return &src_layer_md_;
+ if (index == 1 && with_src_iter()) return &src_iter_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *weights_md(int index = 0) const override {
+ if (index == 0) return &weights_layer_md_;
+ if (index == 1) return &weights_iter_md_;
+ if (index == 2 && with_bias()) return &bias_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *dst_md(int index = 0) const override {
+ if (index == 0) return &dst_layer_md_;
+ if (index == 1 && with_dst_iter()) return &dst_iter_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *workspace_md(int index = 0) const override
+ { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+    /* common rnn aux functions */
+
+ bool is_training() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::backward);
+ }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
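+    /* dimension shortcuts: T - time steps, MB - minibatch, L - layers,
+     * D - directions, G - gates, SLC/SIC - source layer/iteration channels,
+     * DLC/DIC - destination layer/iteration channels */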
+ dim_t T() const { return desc_.src_layer_desc.dims[0]; }
+ dim_t MB() const { return desc_.src_layer_desc.dims[1]; }
+
+ dim_t L() const { return desc_.weights_layer_desc.dims[0]; }
+ dim_t D() const { return desc_.weights_layer_desc.dims[1]; }
+
+ dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; }
+
+ dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; }
+ dim_t G() const { return desc_.weights_layer_desc.dims[3]; }
+ dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; }
+
+ dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; }
+
+ bool with_bias() const
+ { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
+
+ bool with_src_iter() const
+ { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); }
+
+ bool with_dst_iter() const
+ { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); }
+
+ mkldnn::impl::alg_kind_t cell_kind() const
+ { return desc_.cell_desc.cell_kind; }
+ mkldnn::impl::alg_kind_t activation_kind() const
+ { return desc_.cell_desc.activation_kind; }
+
+ bool is_lbr() const
+ { return cell_kind() == mkldnn_gru_linear_before_reset; }
+
+ mkldnn_rnn_direction_t direction() const { return desc_.direction; }
+
+protected:
+ rnn_desc_t desc_;
+ const rnn_fwd_pd_t *hint_fwd_pd_;
+
+ memory_desc_t src_layer_md_;
+ memory_desc_t src_iter_md_;
+ memory_desc_t weights_layer_md_;
+ memory_desc_t weights_iter_md_;
+ memory_desc_t bias_md_;
+ memory_desc_t dst_layer_md_;
+ memory_desc_t dst_iter_md_;
+
+ memory_desc_t ws_md_;
+};
+
+struct rnn_fwd_pd_t: public rnn_pd_t {
+ typedef rnn_fwd_pd_t base_class;
+ typedef rnn_fwd_pd_t hint_class;
+
+ rnn_fwd_pd_t(engine_t *engine,
+ const rnn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const rnn_fwd_pd_t *hint_fwd_pd)
+ : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC_LAYER)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter())
+ return arg_usage_t::input;
+
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
+ MKLDNN_ARG_WEIGHTS_ITER))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_BIAS && with_bias())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST_LAYER)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter())
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && is_training())
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
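+    /* inputs: src_layer, weights_layer, weights_iter [+ bias] [+ src_iter];
+     * outputs: dst_layer [+ dst_iter] [+ workspace when training] */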
+ virtual int n_inputs() const override
+ { return 3 + with_bias() + with_src_iter(); }
+ virtual int n_outputs() const override
+ { return 1 + with_dst_iter() + is_training(); }
+};
+
+struct rnn_bwd_pd_t : public rnn_pd_t {
+ typedef rnn_bwd_pd_t base_class;
+ typedef rnn_fwd_pd_t hint_class;
+
+ rnn_bwd_pd_t(engine_t *engine,
+ const rnn_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const rnn_fwd_pd_t *hint_fwd_pd)
+ : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_src_layer_md_(desc_.diff_src_layer_desc)
+ , diff_src_iter_md_(desc_.diff_src_iter_desc)
+ , diff_weights_layer_md_(desc_.diff_weights_layer_desc)
+ , diff_weights_iter_md_(desc_.diff_weights_iter_desc)
+ , diff_bias_md_(desc_.diff_bias_desc)
+ , diff_dst_layer_md_(desc_.diff_dst_layer_desc)
+ , diff_dst_iter_md_(desc_.diff_dst_iter_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER,
+ MKLDNN_ARG_DIFF_DST_LAYER))
+ return arg_usage_t::input;
+
+ if (with_src_iter()) {
+ if (arg == MKLDNN_ARG_SRC_ITER)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC_ITER)
+ return arg_usage_t::output;
+ }
+
+ if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
+ MKLDNN_ARG_WEIGHTS_ITER))
+ return arg_usage_t::input;
+
+ if (with_bias()) {
+ if (arg == MKLDNN_ARG_BIAS)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_BIAS)
+ return arg_usage_t::output;
+ }
+
+ if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER)
+ && with_dst_iter())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_WORKSPACE)
+ return arg_usage_t::input;
+
+ if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER,
+ MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
+ MKLDNN_ARG_DIFF_WEIGHTS_ITER))
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override {
+ if (index == 0) return &diff_src_layer_md_;
+ if (index == 1 && with_src_iter()) return &diff_src_iter_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *diff_weights_md(
+ int index = 0) const override {
+ if (index == 0) return &diff_weights_layer_md_;
+ if (index == 1) return &diff_weights_iter_md_;
+ if (index == 2 && with_bias()) return &diff_bias_md_;
+ return nullptr;
+ }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override {
+ if (index == 0) return &diff_dst_layer_md_;
+ if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_;
+ return nullptr;
+ }
+
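+    /* inputs: src_layer, dst_layer, diff_dst_layer, weights_layer,
+     * weights_iter, workspace [+ src_iter] [+ bias] [+ dst_iter and
+     * diff_dst_iter]; outputs: diff_src_layer, diff_weights_layer,
+     * diff_weights_iter [+ diff_src_iter] [+ diff_bias] */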
+ virtual int n_inputs() const override
+ { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); }
+ virtual int n_outputs() const override
+ { return 3 + with_src_iter() + with_bias(); }
+
+protected:
+ memory_desc_t diff_src_layer_md_;
+ memory_desc_t diff_src_iter_md_;
+ memory_desc_t diff_weights_layer_md_;
+ memory_desc_t diff_weights_iter_md_;
+ memory_desc_t diff_bias_md_;
+ memory_desc_t diff_dst_layer_md_;
+ memory_desc_t diff_dst_iter_md_;
+};
+
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp
new file mode 100644
index 0000000000..6bc14fc72a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp
@@ -0,0 +1,112 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "scratchpad.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/* Allocating memory buffers on a 2 MiB (huge page) boundary to reduce
+   TLB/page misses */
+const size_t page_size = 2097152;
+
+/*
+   Implementation of the scratchpad_t interface that is compatible with
+   concurrent execution
+*/
+struct concurent_scratchpad_t : public scratchpad_t {
+ concurent_scratchpad_t(size_t size) {
+ size_ = size;
+ scratchpad_ = (char *) malloc(size, page_size);
+ assert(scratchpad_ != nullptr);
+ }
+
+ ~concurent_scratchpad_t() {
+ free(scratchpad_);
+ }
+
+ virtual char *get() const {
+ return scratchpad_;
+ }
+
+private:
+ char *scratchpad_;
+ size_t size_;
+};
+
+/*
+ Implementation of the scratchpad_t interface that uses a global
+ scratchpad
+*/
+
+struct global_scratchpad_t : public scratchpad_t {
+ global_scratchpad_t(size_t size) {
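+        // The shared buffer only ever grows; it is reference counted and
+        // released when the last global_scratchpad_t user is destroyed.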
+ if (size > size_) {
+ if (scratchpad_ != nullptr) free(scratchpad_);
+ size_ = size;
+ scratchpad_ = (char *) malloc(size, page_size);
+ assert(scratchpad_ != nullptr);
+ }
+ reference_count_++;
+ }
+
+ ~global_scratchpad_t() {
+ reference_count_--;
+ if (reference_count_ == 0) {
+ free(scratchpad_);
+ scratchpad_ = nullptr;
+ size_ = 0;
+ }
+ }
+
+ virtual char *get() const {
+ return scratchpad_;
+ }
+
+private:
+ /*
+ Using thread-local here is unnecessary and even buggy! All threads
+ actually share the same scratchpad, which is created and queried only
+ on the main thread. If the scratchpad is queried on some thread other
+ than the one it was created on (e.g. the application calls the API from
+ multiple threads), thread-local causes a segfault because the scratchpad
+ is uninitialized on the current thread.
+ */
+ /*thread_local*/ static char *scratchpad_;
+ /*thread_local*/ static size_t size_;
+ /*thread_local*/ static unsigned int reference_count_;
+};
+
+/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr;
+/*thread_local*/ size_t global_scratchpad_t::size_ = 0;
+/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0;
+
+
+/*
+ Scratchpad creation routine
+*/
+scratchpad_t *create_scratchpad(size_t size) {
+#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC
+ return new global_scratchpad_t(size);
+#else
+ return new concurent_scratchpad_t(size);
+#endif
+}
+
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp
new file mode 100644
index 0000000000..f7a246bc99
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/scratchpad.hpp
@@ -0,0 +1,36 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef COMMON_SCRATCHPAD_HPP
+#define COMMON_SCRATCHPAD_HPP
+
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct scratchpad_t {
+ virtual ~scratchpad_t() {}
+ virtual char *get() const = 0;
+};
+
+scratchpad_t *create_scratchpad(size_t size);
+
+}
+}
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp
new file mode 100644
index 0000000000..e32e735224
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle.cpp
@@ -0,0 +1,72 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t shuffle_desc_init(shuffle_desc_t *shuffle_desc, prop_kind_t prop_kind,
+ const memory_desc_t *data_desc, int axis, dim_t group_size) {
+ bool args_ok = true
+ && !any_null(shuffle_desc, data_desc)
+ && one_of(prop_kind, forward_training, forward_inference,
+ backward, backward_data)
+ && axis >= 0 && axis < data_desc->ndims
+ && group_size > 0 && group_size <= data_desc->dims[axis];
+ if (!args_ok) return invalid_arguments;
+
+ auto sd = shuffle_desc_t();
+ sd.primitive_kind = primitive_kind::shuffle;
+ sd.prop_kind = prop_kind;
+ sd.data_desc = *data_desc;
+ sd.axis = axis;
+ sd.group_size = group_size;
+
+ bool consistency = true
+ && sd.data_desc.dims[axis] % sd.group_size == 0;
+ if (!consistency) return invalid_arguments;
+
+ *shuffle_desc = sd;
+ return success;
+}
+}
+
+status_t mkldnn_shuffle_forward_desc_init(shuffle_desc_t *shuffle_desc,
+ prop_kind_t prop_kind, const memory_desc_t *data_desc, int axis,
+ dim_t group_size) {
+ if (!one_of(prop_kind, forward_training, forward_inference))
+ return invalid_arguments;
+ return shuffle_desc_init(shuffle_desc, prop_kind, data_desc, axis,
+ group_size);
+}
+
+status_t mkldnn_shuffle_backward_desc_init(shuffle_desc_t *shuffle_desc,
+ const memory_desc_t *diff_data_desc, int axis, dim_t group_size) {
+ return shuffle_desc_init(shuffle_desc, backward_data, diff_data_desc, axis,
+ group_size);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp
new file mode 100644
index 0000000000..cc5553fe7f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/shuffle_pd.hpp
@@ -0,0 +1,121 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SHUFFLE_PD_HPP
+#define SHUFFLE_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct shuffle_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::shuffle;
+
+ typedef shuffle_pd_t base_class;
+ typedef shuffle_pd_t hint_class;
+
+ shuffle_pd_t(engine_t *engine,
+ const shuffle_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const shuffle_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , data_md_(desc_.data_desc)
+ {}
+
+ const shuffle_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::shuffle_d:
+ *(const shuffle_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (is_fwd()) {
+ if (arg == MKLDNN_ARG_SRC)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+ } else {
+ if (arg == MKLDNN_ARG_DIFF_DST)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+ }
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 && is_fwd() ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 && is_fwd() ? &data_md_ : nullptr; }
+
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 && !is_fwd() ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 && !is_fwd() ? &data_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override { return 1; }
+
+ /* shuffle aux functions */
+
+ dim_t MB() const { return data_md()->dims[0]; }
+ dim_t C() const { return ndims() >= 2 ? data_md()->dims[1] : 1; }
+ dim_t D() const { return ndims() >= 5 ? data_md()->dims[ndims() - 3] : 1; }
+ dim_t H() const { return ndims() >= 4 ? data_md()->dims[ndims() - 2] : 1; }
+ dim_t W() const { return ndims() >= 3 ? data_md()->dims[ndims() - 1] : 1; }
+
+ int ndims() const { return data_md()->ndims; }
+
+ int axis() const { return desc_.axis; }
+ dim_t group_size() const { return desc_.group_size; }
+ dim_t axis_size() const { return data_md()->dims[axis()]; }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+ const memory_desc_t *data_md() const { return &data_md_; }
+
+protected:
+ shuffle_desc_t desc_;
+ const shuffle_pd_t *hint_fwd_pd_;
+ memory_desc_t data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp
new file mode 100644
index 0000000000..82848e3d1f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/softmax.cpp
@@ -0,0 +1,68 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "memory_desc_wrapper.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t softmax_desc_init(softmax_desc_t *softmax_desc, prop_kind_t prop_kind,
+ const memory_desc_t *data_desc, const memory_desc_t *diff_desc, int softmax_axis) {
+ bool args_ok = true
+ && !any_null(softmax_desc, data_desc)
+ && 0 <= softmax_axis
+ && softmax_axis < data_desc->ndims;
+ if (!args_ok) return invalid_arguments;
+
+ auto sd = softmax_desc_t();
+ sd.primitive_kind = primitive_kind::softmax;
+ sd.prop_kind = prop_kind;
+
+ bool is_bwd = (sd.prop_kind == backward_data);
+ sd.data_desc = *data_desc;
+ sd.diff_desc = is_bwd ? *diff_desc : zero_md();
+ sd.softmax_axis = softmax_axis;
+
+ *softmax_desc = sd;
+ return success;
+}
+}
+
+status_t mkldnn_softmax_forward_desc_init(softmax_desc_t *softmax_desc,
+ prop_kind_t prop_kind, const memory_desc_t *data_desc,
+ int softmax_axis) {
+ if (!one_of(prop_kind, forward_inference, forward_training))
+ return invalid_arguments;
+ return softmax_desc_init(softmax_desc, prop_kind, data_desc, nullptr, softmax_axis);
+}
+
+status_t mkldnn_softmax_backward_desc_init(softmax_desc_t *softmax_desc,
+ const memory_desc_t *diff_desc, const mkldnn_memory_desc_t *data_desc,
+ int softmax_axis) {
+ return softmax_desc_init(softmax_desc, prop_kind::backward_data,
+ data_desc, diff_desc, softmax_axis);
+}
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp
new file mode 100644
index 0000000000..8a16ce901c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/softmax_pd.hpp
@@ -0,0 +1,161 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SOFTMAX_PD_HPP
+#define SOFTMAX_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct softmax_fwd_pd_t;
+
+struct softmax_pd_t: public primitive_desc_t {
+ static constexpr auto base_pkind = primitive_kind::softmax;
+
+ softmax_pd_t(engine_t *engine,
+ const softmax_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const softmax_fwd_pd_t *hint_fwd_pd)
+ : primitive_desc_t(engine, attr, base_pkind)
+ , desc_(*adesc)
+ , hint_fwd_pd_(hint_fwd_pd)
+ , data_md_(desc_.data_desc)
+ {}
+
+ const softmax_desc_t *desc() const { return &desc_; }
+ virtual const op_desc_t *op_desc() const override
+ { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual status_t query(query_t what, int idx, void *result) const override {
+ switch (what) {
+ case query::softmax_d:
+ *(const softmax_desc_t**)result = desc(); break;
+ default: return primitive_desc_t::query(what, idx, result);
+ }
+ return status::success;
+ }
+
+ /* common softmax aux functions */
+
+ dim_t MB() const { return data_desc().dims[0]; }
+ dim_t C() const { return data_desc().dims[1]; }
+ dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+ dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+ dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+ int ndims() const { return data_desc().ndims; }
+
+ bool is_fwd() const {
+ return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ }
+
+protected:
+ softmax_desc_t desc_;
+ const softmax_fwd_pd_t *hint_fwd_pd_;
+
+ memory_desc_t data_md_;
+
+private:
+ const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+struct softmax_fwd_pd_t: public softmax_pd_t {
+ typedef softmax_fwd_pd_t base_class;
+ typedef softmax_fwd_pd_t hint_class;
+
+ softmax_fwd_pd_t(engine_t *engine,
+ const softmax_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const softmax_fwd_pd_t *hint_fwd_pd)
+ : softmax_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg == MKLDNN_ARG_SRC)
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return 1; }
+ virtual int n_outputs() const override
+ { return 1 + (workspace_md() != nullptr); }
+};
+
+struct softmax_bwd_pd_t: public softmax_pd_t {
+ typedef softmax_bwd_pd_t base_class;
+ typedef softmax_fwd_pd_t hint_class;
+
+ softmax_bwd_pd_t(engine_t *engine,
+ const softmax_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const softmax_fwd_pd_t *hint_fwd_pd)
+ : softmax_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , diff_data_md_(desc_.diff_desc)
+ {}
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (utils::one_of(arg, MKLDNN_ARG_DST, MKLDNN_ARG_DIFF_DST))
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DIFF_SRC)
+ return arg_usage_t::output;
+
+ if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+ return arg_usage_t::input;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+ virtual const memory_desc_t *diff_src_md(int index = 0) const override
+ { return index == 0 ? &diff_data_md_ : nullptr; }
+
+ virtual int n_inputs() const override
+ { return 2 + (workspace_md() != nullptr); }
+ virtual int n_outputs() const override { return 1; }
+
+protected:
+ memory_desc_t diff_data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.cpp b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp
new file mode 100644
index 0000000000..00af8935c0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/stream.cpp
@@ -0,0 +1,46 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "stream.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+/* API */
+
+status_t mkldnn_stream_create(stream_t **stream, engine_t *engine,
+ unsigned flags) {
+ bool args_ok = true
+ && !utils::any_null(stream, engine)
+ && flags == stream_flags::default_flags;
+ if (!args_ok)
+ return invalid_arguments;
+
+ return safe_ptr_assign<stream_t>(*stream, new stream_t(engine, flags));
+}
+
+status_t mkldnn_stream_destroy(stream_t *stream) {
+ delete stream;
+ return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/stream.hpp b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp
new file mode 100644
index 0000000000..f010e5f6ed
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/stream.hpp
@@ -0,0 +1,44 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef STREAM_HPP
+#define STREAM_HPP
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+
+struct mkldnn_stream: public mkldnn::impl::c_compatible {
+ mkldnn_stream(mkldnn::impl::engine_t *engine, unsigned flags)
+ : engine_(engine), flags_(flags) {}
+ virtual ~mkldnn_stream() {}
+
+ /** returns stream's engine */
+ mkldnn::impl::engine_t *engine() const { return engine_; }
+
+ /** returns stream's kind */
+ unsigned flags() const { return flags_; }
+
+protected:
+ mkldnn::impl::engine_t *engine_;
+ unsigned flags_;
+};
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum.cpp b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp
new file mode 100644
index 0000000000..365663c0f8
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/sum.cpp
@@ -0,0 +1,79 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "sum_pd.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_sum_primitive_desc_create(primitive_desc_t **sum_pd,
+ const memory_desc_t *dst_md, int n, const float *scales,
+ const memory_desc_t *src_mds, const primitive_attr_t *attr,
+ engine_t *engine) {
+ bool args_ok = !any_null(sum_pd, src_mds, scales) && n > 0;
+ if (!args_ok) return invalid_arguments;
+
+ const primitive_attr_t dummy_attr;
+ if (attr == NULL)
+ attr = &dummy_attr;
+
+ const int ndims = src_mds[0].ndims;
+ const dims_t &dims = src_mds[0].dims;
+ const data_type_t dt = src_mds[0].data_type;
+
+ for (int i = 1; i < n; ++i) {
+ if (src_mds[i].ndims != ndims) return invalid_arguments;
+ for (int d = 0; d < ndims; ++d) {
+ if (src_mds[i].dims[d] != dims[d])
+ return invalid_arguments;
+ }
+ if (src_mds[i].data_type != dt) return invalid_arguments;
+ }
+
+ memory_desc_t dummy_dst_md;
+ if (dst_md) {
+ if (dst_md->ndims != ndims) return invalid_arguments;
+ for (int d = 0; d < ndims; ++d) {
+ if (dst_md->dims[d] != dims[d])
+ return invalid_arguments;
+ }
+ } else {
+ dummy_dst_md = src_mds[0];
+ dummy_dst_md.format_kind = format_kind::any;
+ dst_md = &dummy_dst_md;
+ }
+
+ auto s_pd = reinterpret_cast<sum_pd_t **>(sum_pd);
+
+ for (auto s = engine->get_sum_implementation_list(); *s; ++s) {
+ if ((*s)(s_pd, engine, attr, dst_md, n, scales, src_mds) == success) {
+ (*s_pd)->init_info();
+ (*s_pd)->init_scratchpad_md();
+ return success;
+ }
+ }
+ return unimplemented;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp
new file mode 100644
index 0000000000..80254667df
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/sum_pd.hpp
@@ -0,0 +1,143 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SUM_PD_HPP
+#define SUM_PD_HPP
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct sum_pd_t: public primitive_desc_t {
+ sum_pd_t(engine_t *engine, const primitive_attr_t *attr,
+ const memory_desc_t *dst_md, int n, const float *scales,
+ const memory_desc_t *src_mds)
+ : primitive_desc_t(engine, attr, primitive_kind::sum)
+ , n_(n), dst_md_(*dst_md)
+ {
+ scales_.reserve(n_);
+ for (int i = 0; i < n_; ++i) scales_.push_back(scales[i]);
+ src_mds_.reserve(n_);
+ for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
+ }
+
+ virtual void init_info() override { impl::init_info(this, this->info_); }
+
+ virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+ if (arg >= MKLDNN_ARG_MULTIPLE_SRC
+ && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
+ return arg_usage_t::input;
+
+ if (arg == MKLDNN_ARG_DST)
+ return arg_usage_t::output;
+
+ return primitive_desc_t::arg_usage(arg);
+ }
+
+ virtual const memory_desc_t *src_md(int index = 0) const override
+ { return index < n_inputs() ? &src_mds_[index] : nullptr; }
+ virtual const memory_desc_t *dst_md(int index = 0) const override
+ { return index == 0 ? &dst_md_ : nullptr; }
+
+ virtual int n_inputs() const override { return n_; }
+ virtual int n_outputs() const override { return 1; }
+
+ const float *scales() const { return &scales_[0]; }
+
+protected:
+ int n_;
+ nstl::vector<float> scales_;
+ memory_desc_t dst_md_;
+ nstl::vector<memory_desc_t> src_mds_;
+
+protected:
+ /* inits dst_md_ in simple cases. The call may fail. */
+ status_t init() {
+ for (int i = 0; i < n_; ++i) {
+ const memory_desc_wrapper src_d(&src_mds_[i]);
+ if (!src_d.is_blocking_desc() || src_d.is_additional_buffer())
+ return status::unimplemented;
+ }
+ bool ok = true
+ && set_default_params() == status::success
+ && attr()->has_default_values();
+ return ok ? status::success : status::unimplemented;
+ }
+
+ status_t set_default_params() {
+ if (dst_md_.format_kind != format_kind::any)
+ return status::success;
+
+ /* The stupidest ever heuristics (but not the same as we had before):
+ * - Pick the first non-plain format;
+ * - If all formats are plain, pick the format of the first input
+ */
+ for (int i = 0; i < n_; ++i) {
+ const memory_desc_wrapper src_d(src_mds_[i]);
+ if (!src_d.is_plain() && src_d.is_blocking_desc()) {
+ return memory_desc_init_by_blocking_desc(dst_md_,
+ src_d.blocking_desc());
+ }
+ }
+
+ if (src_mds_[0].format_kind != format_kind::blocked)
+ return status::unimplemented;
+
+ dst_md_ = src_mds_[0];
+
+ return status::success;
+ }
+};
+
+#define DECLARE_SUM_PD_t(impl_name, ...) \
+ static status_t create(sum_pd_t **sum_pd, \
+ engine_t *engine, const primitive_attr_t *attr, \
+ const memory_desc_t *dst_md, int n, const float *scales, \
+ const memory_desc_t *src_mds) { \
+ using namespace status; \
+ auto _pd = new pd_t(engine, attr, dst_md, n, scales, src_mds); \
+ if (_pd == nullptr) return out_of_memory; \
+ if (_pd->init() != success) { delete _pd; return unimplemented; } \
+ return safe_ptr_assign<sum_pd_t>(*sum_pd, _pd); \
+ } \
+ virtual status_t create_primitive(primitive_t **p) const override { \
+ double ms = get_msec(); \
+ auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
+ ms = get_msec() - ms; \
+ if (mkldnn_verbose()->level >= 2) { \
+ printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
+ fflush(0); \
+ } \
+ return ret; \
+ } \
+ virtual pd_t *clone() const override { return new pd_t(*this); } \
+ virtual const char *name() const override { return impl_name; } \
+
+#define DECLARE_SUM_PD_T(impl_name, ...) \
+ DECLARE_SUM_PD_t(impl_name, __VA_ARGS__)
+
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp
new file mode 100644
index 0000000000..a408f45980
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/tag_traits.hpp
@@ -0,0 +1,200 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef TAG_TRAITS_HPP
+#define TAG_TRAITS_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+enum class block_dim_t {
+ _,
+ _A, _B,
+ _AB, _BC,
+};
+
+enum class inner_blk_t {
+ _,
+ _4a, _4b,
+ _8a, _8b,
+ _16a, _16b,
+
+ _4b4a, _4b4c, _4c4b,
+ _8a8b, _8b8a, _8b8c, _8c8b,
+ _16a16b, _16a4b, _16b16a, _16b4c, _16b16c, _16c16b,
+
+ _2c8b4c, _8a16b2a, _4b16a4b, _8b16a2b, _8b16c2b, _4c16b4c, _8c16b2c,
+};
+
+/** returns the offset within the block for weights blocked over oc and ic */
+template <inner_blk_t f>
+constexpr int AB_or_BC_blk_off(int x0, int x1) {
+ using ib = inner_blk_t;
+ static_assert(utils::one_of(f, ib::_4b4a, ib::_4b4c, ib::_4c4b, ib::_8a8b,
+ ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b, ib::_16a4b,
+ ib::_16b16a, ib::_16b4c, ib::_16b16c, ib::_16c16b, ib::_2c8b4c,
+ ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b,
+ ib::_4c16b4c, ib::_8c16b2c),
+ "unexpected inner_blk format");
+ return false ? 0
+ : (f == ib::_4b4c) ? 4 * x0 + x1
+ : (f == ib::_4b4a || f == ib::_4c4b) ? 4 * x1 + x0
+ : (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1
+ : (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0
+ : (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1
+ : (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0
+ : (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1
+ : (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2
+ : (f == ib::_4b16a4b || f == ib::_4c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4
+ : (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2
+ : (f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4
+ : INT_MIN;
+}
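
A quick compile-time illustration of the offset formula above, as a hedged
sketch (it assumes compilation against tag_traits.hpp inside the library tree):

    #include "tag_traits.hpp"

    using mkldnn::impl::AB_or_BC_blk_off;
    using ib = mkldnn::impl::inner_blk_t;

    // _16a16b stores 16 consecutive 'b' values per 'a' row, so the in-block
    // offset of (a=2, b=3) is 16*2 + 3 = 35.
    static_assert(AB_or_BC_blk_off<ib::_16a16b>(2, 3) == 35, "row-major block");
    // _16b16a is the transposed variant: offset of (a=2, b=3) is 16*3 + 2 = 50.
    static_assert(AB_or_BC_blk_off<ib::_16b16a>(2, 3) == 50, "column-major block");
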
+
+template <inner_blk_t b> struct inner_blk_traits {
+ using ib = inner_blk_t;
+};
+
+template <format_tag_t> struct tag_traits {
+ // block_dim_t block_dims;
+ // inner_blk_t inner_blks;
+ // int ndims;
+};
+
+#define DECL_TRAITS(_tag, _blk_fmt, _inner_blk, _ndims) \
+template <> struct tag_traits<format_tag::_tag> { \
+ static constexpr block_dim_t block_dims = block_dim_t::_blk_fmt; \
+ static constexpr inner_blk_t inner_blks = inner_blk_t::_inner_blk; \
+ static constexpr int ndims = _ndims; \
+}
+
+DECL_TRAITS(a, _, _, 1);
+DECL_TRAITS(ab, _, _, 2);
+DECL_TRAITS(abc, _, _, 3);
+DECL_TRAITS(abcd, _, _, 4);
+DECL_TRAITS(abcde, _, _, 5);
+DECL_TRAITS(abcdef, _, _, 6);
+DECL_TRAITS(abdec, _, _, 5);
+DECL_TRAITS(acb, _, _, 3);
+DECL_TRAITS(acbde, _, _, 5);
+DECL_TRAITS(acdb, _, _, 4);
+DECL_TRAITS(acdeb, _, _, 5);
+DECL_TRAITS(ba, _, _, 2);
+DECL_TRAITS(bac, _, _, 3);
+DECL_TRAITS(bacd, _, _, 4);
+DECL_TRAITS(bcda, _, _, 4);
+DECL_TRAITS(cba, _, _, 3);
+DECL_TRAITS(cdba, _, _, 4);
+DECL_TRAITS(cdeba, _, _, 5);
+DECL_TRAITS(decab, _, _, 5);
+
+DECL_TRAITS(Abc4a, _A, _4a, 3);
+DECL_TRAITS(aBc4b, _B, _4b, 3);
+DECL_TRAITS(ABc4b16a4b, _AB, _4b16a4b, 3);
+DECL_TRAITS(ABc4b4a, _AB, _4b4a, 3);
+DECL_TRAITS(Abcd4a, _A, _4a, 4);
+DECL_TRAITS(aBcd4b, _B, _4b, 4);
+DECL_TRAITS(ABcd4b4a, _AB, _4b4a, 4);
+DECL_TRAITS(aBCd4c16b4c, _BC, _4c16b4c, 4);
+DECL_TRAITS(aBCd4c4b, _BC, _4c4b, 4);
+DECL_TRAITS(Abcde4a, _A, _4a, 5);
+DECL_TRAITS(aBcde4b, _B, _4b, 5);
+DECL_TRAITS(ABcde4b4a, _AB, _4b4a, 5);
+DECL_TRAITS(aBCde4c4b, _BC, _4c4b, 5);
+DECL_TRAITS(aBcdef4b, _B, _4b, 6);
+DECL_TRAITS(aBCdef4c4b, _BC, _4c4b, 6);
+DECL_TRAITS(aBdc4b, _B, _4b, 4);
+DECL_TRAITS(aBdec4b, _B, _4b, 5);
+DECL_TRAITS(aBdefc4b, _B, _4b, 6);
+DECL_TRAITS(Acb4a, _A, _4a, 3);
+DECL_TRAITS(Acdb4a, _A, _4a, 4);
+DECL_TRAITS(Acdeb4a, _A, _4a, 5);
+
+DECL_TRAITS(Abc16a, _A, _16a, 3);
+DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3);
+DECL_TRAITS(aBc16b, _B, _16b, 3);
+DECL_TRAITS(ABc16b16a, _AB, _16b16a, 3);
+DECL_TRAITS(ABc8a16b2a, _AB, _8a16b2a, 3);
+DECL_TRAITS(ABc8a8b, _AB, _8a8b, 3);
+DECL_TRAITS(aBc8b, _B, _8b, 3);
+DECL_TRAITS(ABc8b16a2b, _AB, _8b16a2b, 3);
+DECL_TRAITS(ABc8b8a, _AB, _8b8a, 3);
+DECL_TRAITS(Abcd16a, _A, _16a, 4);
+DECL_TRAITS(ABcd16a16b, _AB, _16a16b, 4);
+DECL_TRAITS(aBcd16b, _B, _16b, 4);
+DECL_TRAITS(ABcd16b16a, _AB, _16b16a, 4);
+DECL_TRAITS(aBCd16b16c, _BC, _16b16c, 4);
+DECL_TRAITS(aBCd16c16b, _BC, _16c16b, 4);
+DECL_TRAITS(ABcd4b16a4b, _AB, _4b16a4b, 4);
+DECL_TRAITS(ABcd8a16b2a, _AB, _8a16b2a, 4);
+DECL_TRAITS(ABcd8a8b, _AB, _8a8b, 4);
+DECL_TRAITS(aBcd8b, _B, _8b, 4);
+DECL_TRAITS(ABcd8b16a2b, _AB, _8b16a2b, 4);
+DECL_TRAITS(aBCd8b16c2b, _BC, _8b16c2b, 4);
+DECL_TRAITS(ABcd8b8a, _AB, _8b8a, 4);
+DECL_TRAITS(aBCd8b8c, _BC, _8b8c, 4);
+DECL_TRAITS(aBCd8c16b2c, _BC, _8c16b2c, 4);
+DECL_TRAITS(aBCd8c8b, _BC, _8c8b, 4);
+DECL_TRAITS(Abcde16a, _A, _16a, 5);
+DECL_TRAITS(ABcde16a16b, _AB, _16a16b, 5);
+DECL_TRAITS(aBcde16b, _B, _16b, 5);
+DECL_TRAITS(ABcde16b16a, _AB, _16b16a, 5);
+DECL_TRAITS(aBCde16b16c, _BC, _16b16c, 5);
+DECL_TRAITS(aBCde16c16b, _BC, _16c16b, 5);
+DECL_TRAITS(aBCde4c16b4c, _BC, _4c16b4c, 5);
+DECL_TRAITS(Abcde8a, _A, _8a, 5);
+DECL_TRAITS(ABcde8a8b, _AB, _8a8b, 5);
+DECL_TRAITS(aBcde8b, _B, _8b, 5);
+DECL_TRAITS(ABcde8b16a2b, _AB, _8b16a2b, 5);
+DECL_TRAITS(aBCde8b16c2b, _BC, _8b16c2b, 5);
+DECL_TRAITS(ABcde8b8a, _AB, _8b8a, 5);
+DECL_TRAITS(aBCde8b8c, _BC, _8b8c, 5);
+DECL_TRAITS(aBCde2c8b4c, _BC, _2c8b4c, 5);
+DECL_TRAITS(aBCde8c16b2c, _BC, _8c16b2c, 5);
+DECL_TRAITS(aBCde4b4c, _BC, _4b4c, 5);
+DECL_TRAITS(aBCde8c8b, _BC, _8c8b, 5);
+DECL_TRAITS(aBcdef16b, _B, _16b, 6);
+DECL_TRAITS(aBCdef16b16c, _BC, _16b16c, 6);
+DECL_TRAITS(aBCdef16c16b, _BC, _16c16b, 6);
+DECL_TRAITS(aBCdef8b8c, _BC, _8b8c, 6);
+DECL_TRAITS(aBCdef8c16b2c, _BC, _8c16b2c, 6);
+DECL_TRAITS(aBCdef8c8b, _BC, _8c8b, 6);
+DECL_TRAITS(aBdc16b, _B, _16b, 4);
+DECL_TRAITS(aBdc8b, _B, _8b, 4);
+DECL_TRAITS(aBdec16b, _B, _16b, 5);
+DECL_TRAITS(aBdec8b, _B, _8b, 5);
+DECL_TRAITS(aBdefc16b, _B, _16b, 6);
+DECL_TRAITS(aBdefc8b, _B, _8b, 6);
+DECL_TRAITS(Acb16a, _A, _16a, 3);
+DECL_TRAITS(Acb8a, _A, _8a, 3);
+DECL_TRAITS(aCBd16b16c, _BC, _16b16c, 4);
+DECL_TRAITS(aCBde16b16c, _BC, _16b16c, 5);
+DECL_TRAITS(Acdb16a, _A, _16a, 4);
+DECL_TRAITS(Acdb8a, _A, _8a, 4);
+DECL_TRAITS(Acdeb16a, _A, _16a, 5);
+DECL_TRAITS(Acdeb8a, _A, _8a, 5);
+DECL_TRAITS(BAc16a16b, _AB, _16a16b, 3);
+DECL_TRAITS(BAcd16a16b, _AB, _16a16b, 4);
+
+} // namespace impl
+} // namespace mkldnn
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp
new file mode 100644
index 0000000000..4f06368738
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/type_helpers.hpp
@@ -0,0 +1,348 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef TYPE_HELPERS_HPP
+#define TYPE_HELPERS_HPP
+
+#include <assert.h>
+#include <math.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "mkldnn_traits.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+#include "math_utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+template <typename T>
+status_t safe_ptr_assign(T * &lhs, T* rhs) {
+ if (rhs == nullptr) return status::out_of_memory;
+ lhs = rhs;
+ return status::success;
+}
+
+template <typename T, typename U> struct is_subset
+{ static constexpr bool value = false; };
+template <typename T> struct is_subset<T, T>
+{ static constexpr bool value = true; };
+template <typename T> struct is_subset<T,
+ typename utils::enable_if<nstl::is_integral<T>::value, float>::type>
+{ static constexpr bool value = true; };
+#define ISSPEC(t1, t2) template <> \
+ struct is_subset<t1, t2> { static constexpr bool value = true; }
+ISSPEC(int16_t, int32_t);
+ISSPEC(int8_t, int32_t);
+ISSPEC(uint8_t, int32_t);
+ISSPEC(int8_t, int16_t);
+ISSPEC(uint8_t, int16_t);
+#undef ISSPEC
+
+inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs);
+
+namespace types {
+
+inline size_t data_type_size(data_type_t data_type) {
+ using namespace data_type;
+ switch (data_type) {
+ case f32: return sizeof(prec_traits<f32>::type);
+ case s32: return sizeof(prec_traits<s32>::type);
+ case s8: return sizeof(prec_traits<s8>::type);
+ case u8: return sizeof(prec_traits<u8>::type);
+ case data_type::undef:
+ default: assert(!"unknown data_type");
+ }
+ return 0; /* not supposed to be reachable */
+}
+
+inline format_kind_t format_tag_to_kind(format_tag_t tag) {
+ switch (tag) {
+ case format_tag::undef: return format_kind::undef;
+ case format_tag::any: return format_kind::any;
+ case format_tag::last: return format_kind::undef;
+ default: return format_kind::blocked;
+ }
+
+ assert(!"unreachable");
+ return format_kind::undef;
+}
+
+inline bool memory_extra_desc_is_equal(const memory_extra_desc_t &lhs,
+ const memory_extra_desc_t &rhs) {
+ return true
+ && lhs.flags == rhs.flags
+ && IMPLICATION(lhs.flags & memory_extra_flags::compensation_conv_s8s8,
+ lhs.compensation_mask == rhs.compensation_mask)
+ && IMPLICATION(lhs.flags & memory_extra_flags::scale_adjust,
+ lhs.scale_adjust == rhs.scale_adjust);
+}
+
+inline bool blocking_desc_is_equal(const blocking_desc_t &lhs,
+ const blocking_desc_t &rhs, int ndims = MKLDNN_MAX_NDIMS) {
+ using mkldnn::impl::utils::array_cmp;
+ return true
+ && lhs.inner_nblks == rhs.inner_nblks
+ && array_cmp(lhs.strides, rhs.strides, ndims)
+ && array_cmp(lhs.inner_blks, rhs.inner_blks, lhs.inner_nblks)
+ && array_cmp(lhs.inner_idxs, rhs.inner_idxs, lhs.inner_nblks);
+}
+
+inline bool wino_desc_is_equal(const wino_desc_t &lhs,
+ const wino_desc_t &rhs) {
+ return lhs.wino_format == rhs.wino_format
+ && lhs.alpha == rhs.alpha
+ && lhs.ic == rhs.ic
+ && lhs.oc == rhs.oc
+ && lhs.ic_block == rhs.ic_block
+ && lhs.oc_block == rhs.oc_block
+ && lhs.ic2_block == rhs.ic2_block
+ && lhs.oc2_block == rhs.oc2_block
+ && lhs.r == rhs.r;
+}
+
+inline bool rnn_packed_desc_is_equal(
+ const rnn_packed_desc_t &lhs, const rnn_packed_desc_t &rhs) {
+ bool ok = true
+ && lhs.format == rhs.format
+ && lhs.n_parts == rhs.n_parts
+ && lhs.offset_compensation == rhs.offset_compensation
+ && lhs.size == rhs.size
+ && lhs.n == rhs.n;
+ if (!ok)
+ return false;
+
+ for (int i = 0; i < rhs.n_parts; i++)
+ ok = ok && lhs.parts[i] == rhs.parts[i];
+ for (int i = 0; i < rhs.n_parts; i++)
+ ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i];
+ return ok;
+}
+
+inline memory_desc_t zero_md() {
+ auto zero = memory_desc_t();
+ return zero;
+}
+
+inline bool is_zero_md(const memory_desc_t *md) {
+ return md == nullptr || *md == zero_md();
+}
+
+inline data_type_t default_accum_data_type(data_type_t src_dt,
+ data_type_t dst_dt) {
+ using namespace utils;
+ using namespace data_type;
+
+ if (one_of(f32, src_dt, dst_dt)) return f32;
+ if (one_of(s32, src_dt, dst_dt)) return s32;
+
+ if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32;
+
+ assert(!"unimplemented use-case: no default parameters available");
+ return dst_dt;
+}
+
+inline data_type_t default_accum_data_type(data_type_t src_dt,
+ data_type_t wei_dt, data_type_t dst_dt, prop_kind_t prop_kind) {
+ using namespace utils;
+ using namespace data_type;
+ using namespace prop_kind;
+
+ /* prop_kind doesn't matter */
+ if (everyone_is(f32, src_dt, wei_dt, dst_dt)) return f32;
+
+ if (one_of(prop_kind, forward_training, forward_inference)) {
+ if ((src_dt == u8 || src_dt == s8)
+ && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8))
+ return s32;
+ } else if (prop_kind == backward_data) {
+ if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 &&
+ one_of(dst_dt, s8, u8))
+ return s32;
+ }
+
+ assert(!"unimplemented use-case: no default parameters available");
+ return dst_dt;
+}
+
+}
+
+inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
+ using namespace mkldnn::impl::utils;
+ bool base_equal = true
+ && lhs.ndims == rhs.ndims
+ && array_cmp(lhs.dims, rhs.dims, lhs.ndims)
+ && lhs.data_type == rhs.data_type
+ && array_cmp(lhs.padded_dims, rhs.padded_dims, lhs.ndims)
+ && array_cmp(lhs.padded_offsets, rhs.padded_offsets, lhs.ndims)
+ && lhs.offset0 == rhs.offset0
+ && lhs.format_kind == rhs.format_kind;
+ if (!base_equal) return false;
+ if (!types::memory_extra_desc_is_equal(lhs.extra, rhs.extra)) return false;
+ if (lhs.format_kind == format_kind::blocked)
+ return types::blocking_desc_is_equal(lhs.format_desc.blocking,
+ rhs.format_desc.blocking, lhs.ndims);
+ else if (lhs.format_kind == format_kind::wino)
+ return types::wino_desc_is_equal(lhs.format_desc.wino_desc,
+ rhs.format_desc.wino_desc);
+ else if (lhs.format_kind == format_kind::rnn_packed)
+ return types::rnn_packed_desc_is_equal(lhs.format_desc.rnn_packed_desc,
+ rhs.format_desc.rnn_packed_desc);
+ return true;
+}
+
+inline bool operator!=(const memory_desc_t &lhs, const memory_desc_t &rhs) {
+ return !operator==(lhs, rhs);
+}
+
+inline status_t memory_desc_init_by_strides(memory_desc_t &md,
+ const dims_t strides) {
+ return mkldnn_memory_desc_init_by_strides(
+ &md, md.ndims, md.dims, md.data_type, strides);
+}
+
+inline status_t memory_desc_init_by_tag(memory_desc_t &md, format_tag_t tag,
+ const dims_t strides = nullptr) {
+ status_t status = mkldnn_memory_desc_init_by_tag(
+ &md, md.ndims, md.dims, md.data_type, tag);
+ if (status != status::success || strides == nullptr)
+ return status;
+
+ /* TODO: add consistency check */
+
+ for (int d = 0; d < md.ndims; ++d)
+ md.format_desc.blocking.strides[d] = strides[d];
+
+ return status::success;
+}
+
+/** inits memory descriptor based on logical dimensions kept in @p md, and the
+ * blocking structure @p blk.
+ *
+ * @note blk.strides only determine the relative order of the dimensions
+ * (from smallest to largest stride); the actual stride values are recomputed
+ *
+ * TODO: move md related functions to one single place
+ */
+inline status_t memory_desc_init_by_blocking_desc(memory_desc_t &md,
+ const blocking_desc_t &blk) {
+ dims_t blocks = {0};
+ utils::array_set(blocks, 1, md.ndims);
+ dim_t block_size = 1;
+ for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
+ blocks[blk.inner_idxs[iblk]] *= blk.inner_blks[iblk];
+ block_size *= blk.inner_blks[iblk];
+ }
+
+ for (int d = 0; d < md.ndims; ++d) {
+ md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
+ md.padded_offsets[d] = 0;
+ }
+ md.offset0 = 0;
+
+ md.format_kind = format_kind::blocked;
+ auto &mblk = md.format_desc.blocking;
+ mblk = blk;
+
+ const int ndims = nstl::min(MKLDNN_MAX_NDIMS, md.ndims); // make GCC 5 happy
+ utils::array_copy(mblk.strides, blk.strides, ndims);
+
+ int perm[MKLDNN_MAX_NDIMS];
+ for (int d = 0; d < ndims; ++d) perm[d] = d;
+
+ utils::simultaneous_sort(mblk.strides, perm, ndims,
+ [](stride_t a, stride_t b) { return b - a; });
+
+ dim_t stride = block_size;
+ for (int _d = ndims - 1; _d >= 0; --_d) {
+ const int d = perm[_d];
+ md.format_desc.blocking.strides[d] = stride;
+ stride *= md.padded_dims[d] / blocks[d];
+ }
+
+ md.extra = utils::zero<memory_extra_desc_t>();
+
+ return status::success;
+}
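
A minimal usage sketch for memory_desc_init_by_blocking_desc() (the wrapper
function name is illustrative, and it assumes compilation within src/common
where these internal headers live). It shows that only the relative order of
blk.strides matters -- the actual stride values are recomputed from the padded
dims and the inner block:

    #include "type_helpers.hpp"

    static mkldnn::impl::status_t init_aBcd16b_like() {
        using namespace mkldnn::impl;

        // 4D f32 tensor of logical size 2x30x5x7.
        memory_desc_t md = types::zero_md();
        md.ndims = 4;
        md.dims[0] = 2; md.dims[1] = 30; md.dims[2] = 5; md.dims[3] = 7;
        md.data_type = data_type::f32;

        // Block dim 1 by 16 (padded 30 -> 32) and request the plain abcd order
        // for the outer dims; {4000, 300, 20, 1} would give the same result,
        // since only the decreasing order of the strides is taken into account.
        blocking_desc_t blk = {};
        blk.inner_nblks = 1;
        blk.inner_idxs[0] = 1;
        blk.inner_blks[0] = 16;
        blk.strides[0] = 4; blk.strides[1] = 3;
        blk.strides[2] = 2; blk.strides[3] = 1;

        // On success md describes an aBcd16b-like layout with outer strides
        // {1120, 560, 112, 16} and padded_dims {2, 32, 5, 7}.
        return memory_desc_init_by_blocking_desc(md, blk);
    }
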
+
+/** returns true if memory desc @p md corresponds to the given format tag and
+ * strides.
+ * If strides are not passed (or passed as nullptr) the dense structure is
+ * assumed (i.e. the one that mkldnn_memory_desc_init_by_tag() returns).
+ * Strides may contain a `0` value, indicating that the stride must match the
+ * one that mkldnn_memory_desc_init_by_tag() returns.
+ * Strides may contain `-1` values, which are ignored during the comparison.
+ * For instance, this can be used if the stride along the minibatch dimension
+ * doesn't matter. */
+inline bool memory_desc_matches_tag(const memory_desc_t &md, format_tag_t tag,
+ const dims_t strides = nullptr) {
+ if (md.format_kind != types::format_tag_to_kind(tag))
+ return false;
+
+ memory_desc_t md_gold;
+ status_t status = mkldnn_memory_desc_init_by_tag(
+ &md_gold, md.ndims, md.dims, md.data_type, tag);
+ if (status != status::success) return false;
+
+ if (md.format_kind != format_kind::blocked)
+ return false; // unimplemented yet
+
+ const auto &blk = md.format_desc.blocking;
+ const auto &blk_gold = md_gold.format_desc.blocking;
+
+ using utils::array_cmp;
+ bool same_blocks = true
+ && blk.inner_nblks == blk_gold.inner_nblks
+ && array_cmp(blk.inner_blks, blk_gold.inner_blks, blk.inner_nblks)
+ && array_cmp(blk.inner_idxs, blk_gold.inner_idxs, blk.inner_nblks);
+
+ if (!same_blocks)
+ return false;
+
+ if (strides == nullptr)
+ return array_cmp(blk.strides, blk_gold.strides, md.ndims);
+
+ for (int d = 0; d < md.ndims; ++d) {
+ dim_t stride = strides[d];
+ if (stride == -1) continue;
+ if (stride == 0) stride = blk_gold.strides[d];
+ if (blk.strides[d] != stride) return false;
+ }
+
+ return true;
+}
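
A short usage sketch of the stride semantics described above (the helper name
is illustrative and assumes the internal headers are available): a `0` entry
means "use the dense stride for this tag", `-1` means the stride is not
compared at all.

    #include "type_helpers.hpp"

    // True for a dense NCHW (abcd) tensor regardless of its minibatch stride:
    // entry 0 is -1 (ignored), the remaining entries must match the dense
    // strides that mkldnn_memory_desc_init_by_tag() would produce.
    static bool is_dense_nchw_any_mb_stride(const mkldnn::impl::memory_desc_t &md) {
        using namespace mkldnn::impl;
        dims_t strides = {-1, 0, 0, 0};
        return memory_desc_matches_tag(md, format_tag::abcd, strides);
    }
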
+
+/** returns matching tag (or undef if match is not found)
+ * XXX: This is a workaround that eventually should go away! */
+template <typename... Tags>
+format_tag_t memory_desc_matches_one_of_tag(const memory_desc_t &md,
+ Tags ...tags) {
+ for (const auto tag: {tags...}) {
+ if (memory_desc_matches_tag(md, tag))
+ return tag;
+ }
+ return format_tag::undef;
+}
+
+}
+}
+
+#include "memory_desc_wrapper.hpp"
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.cpp b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp
new file mode 100644
index 0000000000..d23f4682dc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/utils.cpp
@@ -0,0 +1,135 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <string.h>
+#ifdef _WIN32
+#include <malloc.h>
+#include <windows.h>
+#endif
+#include <limits.h>
+#include <stdlib.h>
+#include <stdio.h>
+
+#include "mkldnn.h"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+int getenv(const char *name, char *buffer, int buffer_size) {
+ if (name == NULL || buffer_size < 0 || (buffer == NULL && buffer_size > 0))
+ return INT_MIN;
+
+ int result = 0;
+ int term_zero_idx = 0;
+ size_t value_length = 0;
+
+#ifdef _WIN32
+ value_length = GetEnvironmentVariable(name, buffer, buffer_size);
+#else
+ const char *value = ::getenv(name);
+ value_length = value == NULL ? 0 : strlen(value);
+#endif
+
+ if (value_length > INT_MAX)
+ result = INT_MIN;
+ else {
+ int int_value_length = (int)value_length;
+ if (int_value_length >= buffer_size) {
+ result = -int_value_length;
+ } else {
+ term_zero_idx = int_value_length;
+ result = int_value_length;
+#ifndef _WIN32
+ strncpy(buffer, value, value_length);
+#endif
+ }
+ }
+
+ if (buffer != NULL)
+ buffer[term_zero_idx] = '\0';
+ return result;
+}
+
+int getenv_int(const char *name, int default_value)
+{
+ int value = default_value;
+ // # of digits in the longest 32-bit signed int + sign + terminating null
+ const int len = 12;
+ char value_str[len];
+ if (getenv(name, value_str, len) > 0)
+ value = atoi(value_str);
+ return value;
+}
+
+FILE *fopen(const char *filename, const char *mode) {
+#ifdef _WIN32
+ FILE *fp = NULL;
+ return ::fopen_s(&fp, filename, mode) ? NULL : fp;
+#else
+ return ::fopen(filename, mode);
+#endif
+}
+
+void *malloc(size_t size, int alignment) {
+ void *ptr;
+
+#ifdef _WIN32
+ ptr = _aligned_malloc(size, alignment);
+ int rc = ptr ? 0 : -1;
+#else
+ int rc = ::posix_memalign(&ptr, alignment, size);
+#endif
+
+ return (rc == 0) ? ptr : 0;
+}
+
+void free(void *p) {
+#ifdef _WIN32
+ _aligned_free(p);
+#else
+ ::free(p);
+#endif
+}
+
+// Atomic operations
+int32_t fetch_and_add(int32_t *dst, int32_t val) {
+#ifdef _WIN32
+ return InterlockedExchangeAdd(reinterpret_cast<long*>(dst), val);
+#else
+ return __sync_fetch_and_add(dst, val);
+#endif
+}
+
+static int jit_dump_flag = 0;
+static bool jit_dump_flag_initialized = false;
+bool jit_dump_enabled() {
+ if (!jit_dump_flag_initialized) {
+ jit_dump_flag = getenv_int("MKLDNN_JIT_DUMP");
+ jit_dump_flag_initialized = true;
+ }
+ return jit_dump_flag != 0;
+}
+
+}
+}
+
+mkldnn_status_t mkldnn_set_jit_dump(int enabled) {
+ using namespace mkldnn::impl::status;
+ mkldnn::impl::jit_dump_flag = enabled;
+ mkldnn::impl::jit_dump_flag_initialized = true;
+ return success;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/utils.hpp b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp
new file mode 100644
index 0000000000..d5a8ec5139
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/utils.hpp
@@ -0,0 +1,370 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef UTILS_HPP
+#define UTILS_HPP
+
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <assert.h>
+#include <stdint.h>
+
+#if defined(__x86_64__) || defined(_M_X64)
+#define MKLDNN_X86_64
+#endif
+
+#define MSAN_ENABLED 0
+#if defined(__has_feature)
+#if __has_feature(memory_sanitizer)
+#undef MSAN_ENABLED
+#define MSAN_ENABLED 1
+#include <sanitizer/msan_interface.h>
+#endif
+#endif
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+// Sanity check for 64 bits
+static_assert(sizeof(void*) == 8, "Intel(R) MKL-DNN supports 64 bit only");
+
+#define CHECK(f) do { \
+ status_t status = f; \
+ if (status != status::success) \
+ return status; \
+} while (0)
+
+#define IMPLICATION(cause, effect) (!(cause) || !!(effect))
+
+namespace utils {
+
+/* a bunch of std:: analogues to be compliant with any msvs version
+ *
+ * Rationale: msvs c++ (and even some c) headers contain special pragma that
+ * injects an msvs-version check into object files in order to catch
+ * abi-mismatches during static linking. This makes sense if e.g. std:: objects
+ * are passed between the application and the library, which is not the case
+ * for mkl-dnn (since, ideally, there is no c++-rt dependent stuff). */
+
+/* SFINAE helper -- analogue to std::enable_if */
+template<bool expr, class T = void> struct enable_if {};
+template<class T> struct enable_if<true, T> { typedef T type; };
+
+/* analogue std::conditional */
+template <bool, typename, typename> struct conditional {};
+template <typename T, typename F> struct conditional<true, T, F>
+{ typedef T type; };
+template <typename T, typename F> struct conditional<false, T, F>
+{ typedef F type; };
+
+template <bool, typename, bool, typename, typename> struct conditional3 {};
+template <typename T, typename FT, typename FF>
+struct conditional3<true, T, false, FT, FF> { typedef T type; };
+template <typename T, typename FT, typename FF>
+struct conditional3<false, T, true, FT, FF> { typedef FT type; };
+template <typename T, typename FT, typename FF>
+struct conditional3<false, T, false, FT, FF> { typedef FF type; };
+
+template <bool, typename U, U, U> struct conditional_v {};
+template <typename U, U t, U f> struct conditional_v<true, U, t, f>
+{ static constexpr U value = t; };
+template <typename U, U t, U f> struct conditional_v<false, U, t, f>
+{ static constexpr U value = f; };
+
+template <typename T> struct remove_reference { typedef T type; };
+template <typename T> struct remove_reference<T&> { typedef T type; };
+template <typename T> struct remove_reference<T&&> { typedef T type; };
+
+template <typename T>
+inline T&& forward(typename utils::remove_reference<T>::type &t)
+{ return static_cast<T&&>(t); }
+template <typename T>
+inline T&& forward(typename utils::remove_reference<T>::type &&t)
+{ return static_cast<T&&>(t); }
+
+template <typename T>
+inline typename remove_reference<T>::type zero()
+{ auto zero = typename remove_reference<T>::type(); return zero; }
+
+template <typename T, typename P>
+inline bool everyone_is(T val, P item) { return val == item; }
+template <typename T, typename P, typename... Args>
+inline bool everyone_is(T val, P item, Args... item_others) {
+ return val == item && everyone_is(val, item_others...);
+}
+
+template <typename T, typename P>
+constexpr bool one_of(T val, P item) { return val == item; }
+template <typename T, typename P, typename... Args>
+constexpr bool one_of(T val, P item, Args... item_others) {
+ return val == item || one_of(val, item_others...);
+}
+
+template <typename... Args>
+inline bool any_null(Args... ptrs) { return one_of(nullptr, ptrs...); }
+
+template<typename T>
+inline void array_copy(T *dst, const T *src, size_t size) {
+ for (size_t i = 0; i < size; ++i) dst[i] = src[i];
+}
+template<typename T>
+inline bool array_cmp(const T *a1, const T *a2, size_t size) {
+ for (size_t i = 0; i < size; ++i) if (a1[i] != a2[i]) return false;
+ return true;
+}
+template<typename T, typename U>
+inline void array_set(T *arr, const U& val, size_t size) {
+ for (size_t i = 0; i < size; ++i) arr[i] = static_cast<T>(val);
+}
+
+namespace product_impl {
+template<size_t> struct int2type{};
+
+template <typename T>
+constexpr int product_impl(const T *arr, int2type<0>) { return arr[0]; }
+
+template <typename T, size_t num>
+inline T product_impl(const T *arr, int2type<num>) {
+ return arr[0]*product_impl(arr+1, int2type<num-1>()); }
+}
+
+template <size_t num, typename T>
+inline T array_product(const T *arr) {
+ return product_impl::product_impl(arr, product_impl::int2type<num-1>());
+}
+
+template<typename T, typename R = T>
+inline R array_product(const T *arr, size_t size) {
+ R prod = 1;
+ for (size_t i = 0; i < size; ++i) prod *= arr[i];
+ return prod;
+}
+
+/** sorts an array of values using @p comparator. While sorting the array
+ * of values, the function permutes an array of @p keys accordingly.
+ *
+ * @note The array of @p keys can be omitted. In this case the function
+ * sorts the array of @p vals only.
+ */
+template <typename T, typename U, typename F>
+inline void simultaneous_sort(T *vals, U *keys, size_t size, F comparator) {
+ if (size == 0) return;
+
+ for (size_t i = 0; i < size - 1; ++i) {
+ bool swapped = false;
+
+ for (size_t j = 0; j < size - i - 1; j++) {
+ if (comparator(vals[j], vals[j + 1]) > 0) {
+ nstl::swap(vals[j], vals[j + 1]);
+ if (keys) nstl::swap(keys[j], keys[j + 1]);
+ swapped = true;
+ }
+ }
+
+ if (swapped == false) break;
+ }
+}
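
A tiny example of the co-sorting behaviour documented above (a hedged sketch,
assuming the internal utils.hpp header is available): sorting the values drags
the keys along, which is how a permutation is kept in sync with an array of
strides.

    #include "utils.hpp"

    static void sort_example() {
        int vals[4] = {3, 1, 4, 2};
        int keys[4] = {0, 1, 2, 3};
        // The comparator returns > 0 to request a swap, so this sorts vals in
        // descending order; afterwards vals == {4, 3, 2, 1} and keys holds the
        // applied permutation {2, 0, 3, 1}.
        mkldnn::impl::utils::simultaneous_sort(vals, keys, 4,
                [](int a, int b) { return b - a; });
    }
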
+
+template <typename T, typename U>
+inline typename remove_reference<T>::type div_up(const T a, const U b) {
+ assert(b);
+ return (a + b - 1) / b;
+}
+
+template <typename T, typename U>
+inline typename remove_reference<T>::type rnd_up(const T a, const U b) {
+ return div_up(a, b) * b;
+}
+
+template <typename T, typename U>
+inline typename remove_reference<T>::type rnd_dn(const T a, const U b) {
+ return (a / b) * b;
+}
+
+template <typename T> T *align_ptr(T *ptr, uintptr_t alignment)
+{ return (T *)(((uintptr_t)ptr + alignment - 1) & ~(alignment - 1)); }
+
+template <typename T, typename U, typename V>
+inline U this_block_size(const T offset, const U max, const V block_size) {
+ assert(offset < max);
+ // TODO (Roma): can't use nstl::max() due to circular dependency... we
+ // need to fix this
+ const T block_boundary = offset + block_size;
+ if (block_boundary > max)
+ return max - offset;
+ else
+ return block_size;
+}
+
+template<typename T>
+inline T nd_iterator_init(T start) { return start; }
+template<typename T, typename U, typename W, typename... Args>
+inline T nd_iterator_init(T start, U &x, const W &X, Args &&... tuple) {
+ start = nd_iterator_init(start, utils::forward<Args>(tuple)...);
+ x = start % X;
+ return start / X;
+}
+
+inline bool nd_iterator_step() { return true; }
+template<typename U, typename W, typename... Args>
+inline bool nd_iterator_step(U &x, const W &X, Args &&... tuple) {
+ if (nd_iterator_step(utils::forward<Args>(tuple)...) ) {
+ x = (x + 1) % X;
+ return x == 0;
+ }
+ return false;
+}
+
+template<typename U, typename W, typename Y>
+inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X)
+{
+ U max_jump = end - cur;
+ U dim_jump = X - x;
+ if (dim_jump <= max_jump) {
+ x = 0;
+ cur += dim_jump;
+ return true;
+ } else {
+ cur += max_jump;
+ x += max_jump;
+ return false;
+ }
+}
+template<typename U, typename W, typename Y, typename... Args>
+inline bool nd_iterator_jump(U &cur, const U end, W &x, const Y &X,
+ Args &&... tuple)
+{
+ if (nd_iterator_jump(cur, end, utils::forward<Args>(tuple)...)) {
+ x = (x + 1) % X;
+ return x == 0;
+ }
+ return false;
+}
+
+template <typename T>
+inline T pick(size_t i, const T &x0) { return x0; }
+template <typename T, typename ...Args>
+inline T pick(size_t i, const T &x0, Args &&... args) {
+ return i == 0 ? x0 : pick(i - 1, utils::forward<Args>(args)...);
+}
+
+template <typename T>
+T pick_by_prop_kind(prop_kind_t prop_kind, const T &val_fwd_inference,
+ const T &val_fwd_training, const T &val_bwd_d, const T &val_bwd_w) {
+ switch (prop_kind) {
+ case prop_kind::forward_inference: return val_fwd_inference;
+ case prop_kind::forward_training: return val_fwd_training;
+ case prop_kind::backward_data: return val_bwd_d;
+ case prop_kind::backward_weights: return val_bwd_w;
+ default: assert(!"unsupported prop_kind");
+ }
+ return T();
+}
+
+template <typename T>
+T pick_by_prop_kind(prop_kind_t prop_kind,
+ const T &val_fwd, const T &val_bwd_d, const T &val_bwd_w)
+{ return pick_by_prop_kind(prop_kind, val_fwd, val_fwd, val_bwd_d, val_bwd_w); }
+
+template <typename Telem, size_t Tdims>
+struct array_offset_calculator {
+ template <typename... Targs>
+ array_offset_calculator(Telem *base, Targs... Fargs) : _dims{ Fargs... }
+ {
+ _base_ptr = base;
+ }
+ template <typename... Targs>
+ inline Telem &operator()(Targs... Fargs)
+ {
+ return *(_base_ptr + _offset(1, Fargs...));
+ }
+
+private:
+ template <typename... Targs>
+ inline size_t _offset(size_t const dimension, size_t element)
+ {
+ return element;
+ }
+
+ template <typename... Targs>
+ inline size_t _offset(size_t const dimension, size_t theta, size_t element)
+ {
+ return element + (_dims[dimension] * theta);
+ }
+
+ template <typename... Targs>
+ inline size_t _offset(size_t const dimension, size_t theta, size_t element,
+ Targs... Fargs)
+ {
+ size_t t_prime = element + (_dims[dimension] * theta);
+ return _offset(dimension + 1, t_prime, Fargs...);
+ }
+
+ Telem *_base_ptr;
+ const int _dims[Tdims];
+};
+
+}
+
+int32_t fetch_and_add(int32_t *dst, int32_t val);
+inline void yield_thread() {}
+
+// Reads an environment variable 'name' and stores its string value in the
+// 'buffer' of 'buffer_size' bytes on success.
+//
+// - Returns the length of the environment variable string value (excluding
+// the terminating 0) if it is set and its contents (including the terminating
+// 0) can be stored in the 'buffer' without truncation.
+//
+// - Returns the negated length of the environment variable string value and
+// writes "\0" to the buffer (if it is not NULL) if the 'buffer_size' is too
+// small to store the value (including the terminating 0) without truncation.
+//
+// - Returns 0 and writes "\0" to the buffer (if not NULL) if the environment
+// variable is not set.
+//
+// - Returns INT_MIN if the 'name' is NULL.
+//
+// - Returns INT_MIN if the 'buffer_size' is negative.
+//
+// - Returns INT_MIN if the 'buffer' is NULL and 'buffer_size' is greater than
+// zero. Passing NULL 'buffer' with 'buffer_size' set to 0 can be used to
+// retrieve the length of the environment variable value string.
+//
+int getenv(const char *name, char *buffer, int buffer_size);
+// Reads an integer from the environment
+int getenv_int(const char *name, int default_value = 0);
+bool jit_dump_enabled();
+FILE *fopen(const char *filename, const char *mode);
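
A small usage sketch for the environment helpers declared above (the wrapper
function is illustrative; it assumes the internal utils.hpp header):

    #include "utils.hpp"

    static int read_verbose_level() {
        char buf[64];
        // Returns the value length on success, 0 if the variable is unset,
        // the negated length if buf is too small, and INT_MIN on bad arguments.
        int len = mkldnn::impl::getenv("MKLDNN_VERBOSE", buf, (int)sizeof(buf));
        if (len > 0) {
            // buf holds the textual value, e.g. "1" or "2".
        }
        // Or read it directly as an integer, falling back to 0 when unset:
        return mkldnn::impl::getenv_int("MKLDNN_VERBOSE", 0);
    }
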
+
+constexpr int msan_enabled = MSAN_ENABLED;
+inline void msan_unpoison(void *ptr, size_t size) {
+#if MSAN_ENABLED
+ __msan_unpoison(ptr, size);
+#endif
+}
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp
new file mode 100644
index 0000000000..89a57772cf
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.cpp
@@ -0,0 +1,665 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <stdlib.h>
+#ifndef _WIN32
+#include <sys/time.h>
+#endif
+
+#include "mkldnn.h"
+#include "mkldnn_version.h"
+#include "c_types_map.hpp"
+#include "verbose.hpp"
+#include "cpu/cpu_isa_traits.hpp"
+
+#include "batch_normalization_pd.hpp"
+#include "pooling_pd.hpp"
+#include "concat_pd.hpp"
+#include "reorder_pd.hpp"
+#include "convolution_pd.hpp"
+#include "rnn_pd.hpp"
+#include "deconvolution_pd.hpp"
+#include "shuffle_pd.hpp"
+#include "eltwise_pd.hpp"
+#include "softmax_pd.hpp"
+#include "inner_product_pd.hpp"
+#include "sum_pd.hpp"
+#include "lrn_pd.hpp"
+
+/* MKL-DNN CPU ISA info */
+#define ISA_ANY "No instruction set specific optimizations"
+#define SSE42 "Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2)"
+#define AVX "Intel(R) Advanced Vector Extensions (Intel(R) AVX)"
+#define AVX2 "Intel(R) Advanced Vector Extensions 2 (Intel(R) AVX2)"
+#define AVX512_COMMON "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
+ "AVX-512)"
+#define AVX512_CORE "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
+ "AVX-512) with AVX512BW, AVX512VL, and AVX512DQ extensions"
+#define AVX512_CORE_VNNI "Intel(R) AVX512-Deep Learning Boost (Intel(R) " \
+ "AVX512-DL Boost)"
+#define AVX512_MIC "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
+ "AVX-512) with AVX512CD, AVX512ER, and AVX512PF extensions"
+#define AVX512_MIC_4OPS "Intel(R) Advanced Vector Extensions 512 (Intel(R) " \
+ "AVX-512) with AVX512_4FMAPS and AVX512_4VNNIW extensions"
+
+namespace mkldnn {
+namespace impl {
+
+static verbose_t verbose;
+static bool initialized;
+static bool version_printed = false;
+
+const verbose_t *mkldnn_verbose() {
+#if !defined(DISABLE_VERBOSE)
+ if (!initialized) {
+ const int len = 2;
+ char val[len] = {0};
+ if (getenv("MKLDNN_VERBOSE", val, len) == 1)
+ verbose.level = atoi(val);
+ initialized = true;
+ }
+ if (!version_printed && verbose.level > 0) {
+ printf("mkldnn_verbose,info,"
+ "Intel(R) MKL-DNN v%d.%d.%d (Git Hash %s),%s\n",
+ mkldnn_version()->major, mkldnn_version()->minor,
+ mkldnn_version()->patch, mkldnn_version()->hash,
+ get_isa_info());
+ version_printed = true;
+ }
+#else
+ verbose.level = 0;
+#endif
+ return &verbose;
+}
+
+double get_msec() {
+#ifdef _WIN32
+ static LARGE_INTEGER frequency;
+ if (frequency.QuadPart == 0)
+ QueryPerformanceFrequency(&frequency);
+ LARGE_INTEGER now;
+ QueryPerformanceCounter(&now);
+ return 1e+3 * now.QuadPart / frequency.QuadPart;
+#else
+ struct timeval time;
+ gettimeofday(&time, NULL);
+ return 1e+3 * time.tv_sec + 1e-3 * time.tv_usec;
+#endif
+}
+
+const char *get_isa_info() {
+ using namespace mkldnn::impl::cpu;
+ if (mayiuse(avx512_mic_4ops)) return AVX512_MIC_4OPS;
+ if (mayiuse(avx512_mic)) return AVX512_MIC;
+ if (mayiuse(avx512_core_vnni)) return AVX512_CORE_VNNI;
+ if (mayiuse(avx512_core)) return AVX512_CORE;
+ if (mayiuse(avx512_common)) return AVX512_COMMON;
+ if (mayiuse(avx2)) return AVX2;
+ if (mayiuse(avx)) return AVX;
+ if (mayiuse(sse42)) return SSE42;
+ return ISA_ANY;
+}
+
+/* init_info section */
+namespace {
+#if !defined(DISABLE_VERBOSE)
+#define MKLDNN_VERBOSE_DAT_LEN 256
+#define MKLDNN_VERBOSE_AUX_LEN 384
+#define MKLDNN_VERBOSE_PRB_LEN 384
+
+#define DECL_DAT_AUX_PRB_STRS() \
+ int dat_written = 0, aux_written = 0, prb_written = 0; \
+ MAYBE_UNUSED((dat_written * aux_written * prb_written)); \
+ char dat_str[MKLDNN_VERBOSE_DAT_LEN] = {'\0'}; MAYBE_UNUSED(dat_str); \
+ char aux_str[MKLDNN_VERBOSE_AUX_LEN] = {'\0'}; MAYBE_UNUSED(aux_str); \
+ char prb_str[MKLDNN_VERBOSE_PRB_LEN] = {'\0'}; MAYBE_UNUSED(prb_str)
+
+#define DFMT "%" PRId64
+
+void clear_buf(char *buf, int &written) {
+ /* TODO: do it better */
+ buf[0] = '#';
+ buf[1] = '\0';
+ written = 1;
+}
+
+#define DPRINT(buf, buf_len, written, ...) do { \
+ int l = snprintf(buf + written, buf_len - written, __VA_ARGS__); \
+ if (l < 0 || written + l > buf_len) { \
+ clear_buf(buf, written); \
+ } else { \
+ written += l; \
+ } \
+} while(0)
+
+// Outputs the problem descriptor string (minibatch/channel/spatial dims) for a data tensor.
+void format_prb_desc_str(char *str, int len, const memory_desc_t *md) {
+ const auto dims = md->dims;
+ int written = 0;
+ if (md->ndims == 1)
+ DPRINT(str, len, written,
+ "x" DFMT, dims[0]);
+ else if (md->ndims == 2)
+ DPRINT(str, len, written,
+ "mb" DFMT "ic" DFMT, dims[0], dims[1]);
+ else if (md->ndims == 3)
+ DPRINT(str, len, written,
+ "mb" DFMT "ic" DFMT "iw" DFMT,
+ dims[0], dims[1], dims[2]);
+ else if (md->ndims == 4)
+ DPRINT(str, len, written,
+ "mb" DFMT "ic" DFMT "ih" DFMT "iw" DFMT,
+ dims[0], dims[1], dims[2], dims[3]);
+ else if (md->ndims == 5)
+ DPRINT(str, len, written,
+ "mb" DFMT "ic" DFMT "id" DFMT "ih" DFMT "iw" DFMT,
+ dims[0], dims[1], dims[2], dims[3], dims[4]);
+ else
+ mkldnn_md2dim_str(str, len, md);
+}
+
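+/* Assembles the final verbose line: primitive kind, implementation name,
+ * propagation kind, and the data/aux/problem strings, comma-separated. */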
+void verbose_templ(char *buffer, mkldnn_primitive_kind_t prim_kind,
+ const char *impl_str, mkldnn_prop_kind_t prop_kind,
+ const char *data_str, const char *aux_str, const char *prb_str) {
+ MAYBE_UNUSED(verbose_templ);
+ int written = 0;
+ DPRINT(buffer, MKLDNN_VERBOSE_BUF_LEN, written, "%s,%s,%s,%s,%s,%s",
+ mkldnn_prim_kind2str(prim_kind), impl_str,
+ mkldnn_prop_kind2str(prop_kind), data_str, aux_str, prb_str);
+}
+
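+/* The init_info_* templates below fill the per-primitive data (dat_str),
+ * auxiliary (aux_str) and problem-size (prb_str) strings and pass them to
+ * verbose_templ(). */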
+template <typename pd_t> static void init_info_bnorm(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // data
+ auto md = s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // diff data
+ auto md = s->diff_src_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "flags:%u", s->desc()->flags);
+
+ format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_conv(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src
+ auto md = s->desc()->prop_kind == prop_kind::backward_data
+ ? s->diff_src_md() : s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // wei
+ auto md = s->desc()->prop_kind == prop_kind::backward_weights
+ ? s->diff_weights_md() : s->weights_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // bia
+ auto md = s->desc()->prop_kind == prop_kind::backward_weights
+ ? s->diff_weights_md(1) : s->weights_md(1);
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+ if (1) { // dst
+ auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
+
+ if (s->ndims() == 5) {
+ if (s->with_groups())
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT
+ "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT
+ "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
+ "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
+ s->MB(), s->G(), s->IC(), s->OC(),
+ s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
+ else
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "_ic" DFMT "oc" DFMT
+ "_id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "dd" DFMT "pd" DFMT
+ "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
+ "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
+ s->MB(), s->IC(), s->OC(),
+ s->ID(), s->OD(), s->KD(), s->KSD(), s->KDD(), s->padFront(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
+ } else {
+ if (s->with_groups())
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "_g" DFMT "ic" DFMT "oc" DFMT
+ "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
+ "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
+ s->MB(), s->G(), s->IC(), s->OC(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
+ else
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "_ic" DFMT "oc" DFMT
+ "_ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "dh" DFMT "ph" DFMT
+ "_iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "dw" DFMT "pw" DFMT,
+ s->MB(), s->IC(), s->OC(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->KDH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->KDW(), s->padL());
+ }
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_shuffle(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ auto md = s->is_fwd() ? s->src_md() : s->diff_dst_md();
+
+ if (1) { // data
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "axis:%d group_size:" DFMT, s->axis(), s->group_size());
+
+ mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, md);
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_eltwise(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // data
+ auto md = s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // diff data
+ auto md = s->diff_src_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
+
+ mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_iprod(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src
+ auto md = s->desc()->prop_kind == prop_kind::backward_data
+ ? s->diff_src_md() : s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // wei
+ auto md = s->desc()->prop_kind == prop_kind::backward_weights
+ ? s->diff_weights_md() : s->weights_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // bia
+ auto md = s->desc()->prop_kind == prop_kind::backward_weights
+ ? s->diff_weights_md(1) : s->weights_md(1);
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bia_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+ if (1) { // dst
+ auto md = !s->is_fwd() ? s->diff_dst_md() : s->dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "ic" DFMT "oc" DFMT, s->MB(), s->IC_total(), s->OC());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_lrn(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // data
+ auto md = s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // diff data
+ auto md = s->diff_src_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
+
+ format_prb_desc_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->src_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_mem(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src
+ auto md = s->src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // dst
+ auto md = s->dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "num:%d", s->n_inputs());
+
+ mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), prop_kind::undef, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_pool(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src
+ auto md = s->is_fwd() ? s->src_md() : s->diff_src_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // dst
+ auto md = s->is_fwd() ? s->dst_md() : s->diff_dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // ws
+ auto md = s->workspace_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " ws_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s", mkldnn_alg_kind2str(s->desc()->alg_kind));
+
+ if (s->is_3d()) {
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "ic" DFMT "_"
+ "id" DFMT "od" DFMT "kd" DFMT "sd" DFMT "pd" DFMT "_"
+ "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_"
+ "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT "",
+ s->MB(), s->C(),
+ s->ID(), s->OD(), s->KD(), s->KSD(), s->padFront(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->padL());
+ } else {
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "mb" DFMT "ic" DFMT "_"
+ "ih" DFMT "oh" DFMT "kh" DFMT "sh" DFMT "ph" DFMT "_"
+ "iw" DFMT "ow" DFMT "kw" DFMT "sw" DFMT "pw" DFMT,
+ s->MB(), s->C(),
+ s->IH(), s->OH(), s->KH(), s->KSH(), s->padT(),
+ s->IW(), s->OW(), s->KW(), s->KSW(), s->padL());
+ }
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_softmax(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // data
+ auto md = s->dst_md();
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "data_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // diff data
+ auto md = s->diff_src_md();
+ if (md) {
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " diff_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ }
+
+ mkldnn_md2dim_str(prb_str, MKLDNN_VERBOSE_PRB_LEN, s->dst_md());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+template <typename pd_t> static void init_info_rnn(pd_t *s, char *buffer) {
+ DECL_DAT_AUX_PRB_STRS();
+
+ if (1) { // src layer
+ auto md = s->is_fwd() ? s->src_md(0) : s->diff_src_md(0);
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, "src_layer_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // src iter
+ auto md = s->is_fwd() ? s->src_md(1) : s->diff_src_md(1);
+        DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " src_iter_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // wei_layer
+ auto md = s->is_fwd() ? s->weights_md(0) : s->diff_weights_md(0);
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_layer_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // wei_iter
+ auto md = s->is_fwd() ? s->weights_md(1) : s->diff_weights_md(1);
+        DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " wei_iter_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // bias
+ auto md = s->is_fwd() ? s->weights_md(2) : s->diff_weights_md(2);
+ DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " bias_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // dst layer
+ auto md = s->is_fwd() ? s->dst_md(0) : s->diff_dst_md(0);
+        DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_layer_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+ if (1) { // dst iter
+ auto md = s->is_fwd() ? s->dst_md(1) : s->diff_dst_md(1);
+        DPRINT(dat_str, MKLDNN_VERBOSE_DAT_LEN, dat_written, " dst_iter_");
+ int l = mkldnn_md2fmt_str(dat_str + dat_written,
+ MKLDNN_VERBOSE_DAT_LEN - dat_written, md);
+ if (l >= 0) dat_written += l; else clear_buf(dat_str, dat_written);
+ }
+
+ alg_kind_t alg_kind = s->cell_kind();
+ rnn_direction_t rnn_dir = s->direction();
+ DPRINT(aux_str, MKLDNN_VERBOSE_AUX_LEN, aux_written,
+ "alg:%s_%s", mkldnn_alg_kind2str(alg_kind),
+ mkldnn_rnn_direction2str(rnn_dir));
+
+ DPRINT(prb_str, MKLDNN_VERBOSE_PRB_LEN, prb_written,
+ "l" DFMT "t" DFMT "mb" DFMT
+ "sic" DFMT "slc" DFMT "dic" DFMT "dlc" DFMT,
+ s->L(), s->T(), s->MB(),
+ s->SIC(), s->SLC(), s->DIC(), s->DLC());
+
+ verbose_templ(buffer, s->kind(), s->name(), s->desc()->prop_kind, dat_str,
+ aux_str, prb_str);
+}
+
+#undef DPRINT
+
+#else // !defined(DISABLE_VERBOSE)
+
+#define DEFINE_STUB(name) \
+ template <typename pd_t> \
+ static void CONCAT2(init_info_, name)(pd_t *s, char *buffer) \
+ { UNUSED(s); UNUSED(buffer); }
+
+DEFINE_STUB(bnorm);
+DEFINE_STUB(conv);
+DEFINE_STUB(eltwise);
+DEFINE_STUB(iprod);
+DEFINE_STUB(lrn);
+DEFINE_STUB(mem);
+DEFINE_STUB(pool);
+DEFINE_STUB(softmax);
+DEFINE_STUB(rnn);
+DEFINE_STUB(shuffle);
+#undef DEFINE_STUB
+
+#endif // !defined(DISABLE_VERBOSE)
+}
+
+void init_info(batch_normalization_pd_t *s, char *b)
+{ init_info_bnorm(s, b); }
+void init_info(concat_pd_t *s, char *b)
+{ init_info_mem(s, b); }
+void init_info(convolution_pd_t *s, char *b)
+{ init_info_conv(s, b); }
+void init_info(deconvolution_pd_t *s, char *b)
+{ init_info_conv(s, b); }
+void init_info(eltwise_pd_t *s, char *b)
+{ init_info_eltwise(s, b); }
+void init_info(inner_product_pd_t *s, char *b)
+{ init_info_iprod(s, b); }
+void init_info(lrn_pd_t *s, char *b)
+{ init_info_lrn(s, b); }
+void init_info(pooling_pd_t *s, char *b)
+{ init_info_pool(s, b); }
+void init_info(reorder_pd_t *s, char *b)
+{ init_info_mem(s, b); }
+void init_info(rnn_pd_t *s, char *b)
+{ init_info_rnn(s, b); }
+void init_info(shuffle_pd_t *s, char *b)
+{ init_info_shuffle(s, b); }
+void init_info(softmax_pd_t *s, char *b)
+{ init_info_softmax(s, b); }
+void init_info(sum_pd_t *s, char *b)
+{ init_info_mem(s, b); }
+
+}
+}
+
+mkldnn_status_t mkldnn_set_verbose(int level) {
+ using namespace mkldnn::impl::status;
+ if (level < 0 || level > 2) return invalid_arguments;
+ mkldnn::impl::verbose.level = level;
+ mkldnn::impl::initialized = true;
+ return success;
+}
+
+const mkldnn_version_t *mkldnn_version() {
+ static mkldnn_version_t ver = {
+ MKLDNN_VERSION_MAJOR,
+ MKLDNN_VERSION_MINOR,
+ MKLDNN_VERSION_PATCH,
+ MKLDNN_VERSION_HASH};
+ return &ver;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp
new file mode 100644
index 0000000000..e3049750cb
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/verbose.hpp
@@ -0,0 +1,62 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef VERBOSE_HPP
+#define VERBOSE_HPP
+
+#include <stdio.h>
+#include <cinttypes>
+
+#include "mkldnn_debug.h"
+#include "c_types_map.hpp"
+#include "utils.hpp"
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct verbose_t {
+ int level;
+};
+
+const verbose_t *mkldnn_verbose();
+double get_msec();
+const char *get_isa_info();
+
+#if !defined(DISABLE_VERBOSE)
+#define MKLDNN_VERBOSE_BUF_LEN 1024
+#else
+#define MKLDNN_VERBOSE_BUF_LEN 1
+#endif
+
+void init_info(batch_normalization_pd_t *s, char *buffer);
+void init_info(concat_pd_t *s, char *buffer);
+void init_info(convolution_pd_t *s, char *buffer);
+void init_info(deconvolution_pd_t *s, char *buffer);
+void init_info(eltwise_pd_t *s, char *buffer);
+void init_info(inner_product_pd_t *s, char *buffer);
+void init_info(lrn_pd_t *s, char *buffer);
+void init_info(pooling_pd_t *s, char *buffer);
+void init_info(reorder_pd_t *s, char *buffer);
+void init_info(rnn_pd_t *s, char *buffer);
+void init_info(shuffle_pd_t *s, char *buffer);
+void init_info(softmax_pd_t *s, char *buffer);
+void init_info(sum_pd_t *s, char *buffer);
+
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp
new file mode 100644
index 0000000000..520bd4710b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/common/z_magic.hpp
@@ -0,0 +1,46 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef Z_MAGIC_HPP
+#define Z_MAGIC_HPP
+
+#define CHAIn2(a,b) a b
+#define CHAIN2(a,b) CHAIn2(a,b)
+
+#define CONCAt2(a,b) a ## b
+#define CONCAT2(a,b) CONCAt2(a,b)
+
+#define STRINGIFy(s) #s
+#define STRINGIFY(s) STRINGIFy(s)
+
+#ifdef _MSC_VER
+# define PRAGMA_MACRo(x) __pragma(x)
+# define PRAGMA_MACRO(x) PRAGMA_MACRo(x)
+#else
+# define PRAGMA_MACRo(x) _Pragma(#x)
+# define PRAGMA_MACRO(x) PRAGMA_MACRo(x)
+#endif
+
+#define UNUSED(x) ((void)x)
+#define MAYBE_UNUSED(x) UNUSED(x)
+
+#if defined(_WIN32) && !defined(__GNUC__)
+#define __PRETTY_FUNCTION__ __FUNCSIG__
+#endif
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp
new file mode 100644
index 0000000000..7cf7822d90
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.cpp
@@ -0,0 +1,112 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "cpu_barrier.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace simple_barrier {
+
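+/* Emits a centralized sense-reversing barrier: each thread atomically
+ * increments the counter; the last arrival resets it and flips the sense
+ * word, which the remaining threads spin on. */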
+void generate(jit_generator &code, Xbyak::Reg64 reg_ctx,
+ Xbyak::Reg64 reg_nthr) {
+# define BAR_CTR_OFF offsetof(ctx_t, ctr)
+# define BAR_SENSE_OFF offsetof(ctx_t, sense)
+ using namespace Xbyak;
+
+ Xbyak::Reg64 reg_tmp = [&]() {
+ /* returns register which is neither reg_ctx nor reg_nthr */
+ Xbyak::Reg64 regs[] = { util::rax, util::rbx, util::rcx };
+ for (size_t i = 0; i < sizeof(regs) / sizeof(regs[0]); ++i)
+ if (!utils::one_of(regs[i], reg_ctx, reg_nthr))
+ return regs[i];
+ return regs[0]; /* should not happen */
+ }();
+
+ Label barrier_exit_label, barrier_exit_restore_label, spin_label;
+
+ code.cmp(reg_nthr, 1);
+ code.jbe(barrier_exit_label);
+
+ code.push(reg_tmp);
+
+ /* take and save current sense */
+ code.mov(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]);
+ code.push(reg_tmp);
+ code.mov(reg_tmp, 1);
+
+ if (mayiuse(avx512_mic)) {
+ code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]);
+ code.prefetchwt1(code.ptr[reg_ctx + BAR_CTR_OFF]);
+ }
+
+ code.lock(); code.xadd(code.ptr[reg_ctx + BAR_CTR_OFF], reg_tmp);
+ code.add(reg_tmp, 1);
+ code.cmp(reg_tmp, reg_nthr);
+ code.pop(reg_tmp); /* restore previous sense */
+ code.jne(spin_label);
+
+ /* the last thread {{{ */
+    code.mov(code.qword[reg_ctx + BAR_CTR_OFF], 0); // reset the barrier counter
+
+ // notify waiting threads
+ code.not_(reg_tmp);
+ code.mov(code.ptr[reg_ctx + BAR_SENSE_OFF], reg_tmp);
+ code.jmp(barrier_exit_restore_label);
+ /* }}} the last thread */
+
+ code.CodeGenerator::L(spin_label);
+ code.pause();
+ code.cmp(reg_tmp, code.ptr[reg_ctx + BAR_SENSE_OFF]);
+ code.je(spin_label);
+
+ code.CodeGenerator::L(barrier_exit_restore_label);
+ code.pop(reg_tmp);
+
+ code.CodeGenerator::L(barrier_exit_label);
+# undef BAR_CTR_OFF
+# undef BAR_SENSE_OFF
+}
+
+/** jit barrier generator */
+struct jit_t: public jit_generator {
+ void (*barrier)(ctx_t *ctx, size_t nthr);
+
+ jit_t() {
+ generate(*this, abi_param1, abi_param2);
+ ret();
+ barrier = reinterpret_cast<decltype(barrier)>(const_cast<uint8_t*>(
+ this->getCode()));
+ }
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_t)
+};
+
+void barrier(ctx_t *ctx, int nthr) {
+ static jit_t j; /* XXX: constructed on load ... */
+ j.barrier(ctx, nthr);
+}
+
+}
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp
new file mode 100644
index 0000000000..0f55e33aa8
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_barrier.hpp
@@ -0,0 +1,60 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_BARRIER_HPP
+#define CPU_BARRIER_HPP
+
+#include <assert.h>
+
+#include "jit_generator.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace simple_barrier {
+
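+/* Barrier context: the counter and the sense word are padded onto separate
+ * cache lines to avoid false sharing between arriving and spinning threads. */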
+STRUCT_ALIGN(64,
+struct ctx_t {
+ enum { CACHE_LINE_SIZE = 64 };
+ volatile size_t ctr;
+ char pad1[CACHE_LINE_SIZE - 1 * sizeof(size_t)];
+ volatile size_t sense;
+ char pad2[CACHE_LINE_SIZE - 1 * sizeof(size_t)];
+});
+
+inline void ctx_init(ctx_t *ctx) { *ctx = utils::zero<ctx_t>(); }
+void barrier(ctx_t *ctx, int nthr);
+
+/** injects the actual barrier implementation into other jitted code
+ * @params:
+ *   code      -- jit_generator object where the barrier is to be injected
+ *   reg_ctx   -- read-only register with a pointer to the barrier context
+ *   reg_nthr  -- read-only register with the number of synchronizing threads
+ */
+void generate(jit_generator &code, Xbyak::Reg64 reg_ctx,
+ Xbyak::Reg64 reg_nthr);
+
+}
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp
new file mode 100644
index 0000000000..1ed5ad57b9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_pd.hpp
@@ -0,0 +1,40 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_BATCH_NORMALIZATION_PD_HPP
+#define CPU_BATCH_NORMALIZATION_PD_HPP
+
+#include "batch_normalization_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_batch_normalization_fwd_pd_t: public batch_normalization_fwd_pd_t {
+ using batch_normalization_fwd_pd_t::batch_normalization_fwd_pd_t;
+};
+
+struct cpu_batch_normalization_bwd_pd_t: public batch_normalization_bwd_pd_t {
+ using batch_normalization_bwd_pd_t::batch_normalization_bwd_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp
new file mode 100644
index 0000000000..b8d5c4fcaf
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.cpp
@@ -0,0 +1,140 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "utils.hpp"
+
+#include "jit_generator.hpp"
+
+#include "cpu_batch_normalization_utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+namespace bnorm_utils {
+
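+/* Picks how many channel blocks to process per iteration so that the working
+ * set roughly fits into a share of the L3 cache, and derives the resulting
+ * iteration count. */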
+void cache_balance(size_t working_set_size, dim_t C_blks,
+ dim_t &C_blks_per_iter, int64_t &iters) {
+ int nthrs = mkldnn_get_max_threads();
+ int l3_size = get_cache_size(3, true) * nthrs / 2;
+
+ C_blks_per_iter = l3_size / working_set_size;
+
+ if (C_blks_per_iter == 0)
+ C_blks_per_iter = 1;
+ if (C_blks_per_iter > C_blks)
+ C_blks_per_iter = C_blks;
+
+ iters = (C_blks + C_blks_per_iter - 1) / C_blks_per_iter;
+}
+
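+/* Splits nthr threads across the channel-block (C), minibatch (N) and spatial
+ * (S) dimensions and computes this thread's [start, end) ranges; returns
+ * whether spatial threading remains allowed for subsequent calls. */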
+bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr,
+ int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr,
+ dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s,
+ dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e) {
+ if (nthr <= C_blks || !mkldnn_thr_syncable()) {
+ C_ithr = ithr; C_nthr = nthr;
+ N_ithr = 0; N_nthr = 1;
+ S_ithr = 0; S_nthr = 1;
+ N_s = 0; N_e = N; S_s = 0; S_e = SP;
+ balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e);
+ } else {
+ if (do_blocking) {
+ N_nthr = (int)nstl::min<dim_t>(N, nthr);
+ C_nthr = (int)nstl::min<dim_t>(C_blks, nthr / N_nthr);
+ S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr));
+ } else {
+ C_nthr = (int)math::gcd((dim_t)nthr, C_blks);
+ N_nthr = (int)nstl::min<dim_t>(N, nthr / C_nthr);
+ S_nthr = (int)nstl::min<dim_t>(SP, nthr / (C_nthr * N_nthr));
+ }
+
+ if (!spatial_thr_allowed)
+ S_nthr = 1;
+
+ if (S_nthr < 1) S_nthr = 1;
+ if (ithr < C_nthr * N_nthr * S_nthr) {
+ N_ithr = (ithr / S_nthr) % N_nthr ;
+ C_ithr = ithr / (N_nthr * S_nthr);
+ S_ithr = ithr % S_nthr;
+ balance211(C_blks, C_nthr, C_ithr, C_blk_s, C_blk_e);
+ balance211(N, N_nthr, N_ithr, N_s, N_e);
+ balance211(SP, S_nthr, S_ithr, S_s, S_e);
+ } else {
+ S_ithr = N_ithr = C_ithr = -ithr;
+ S_s = S_e = N_s = N_e = C_blk_s = C_blk_e = -1;
+ }
+ }
+
+    // spatial_thr_allowed is meant to help maintain consistent decisions
+    // about spatial threading between multiple invocations of this routine.
+    // It is the caller's responsibility to check the return value and pass
+    // it as a flag to the next call if needed.
+ if (S_nthr == 1)
+ spatial_thr_allowed = false;
+
+ return spatial_thr_allowed;
+}
+
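+/* Predicts, using the same blocking and partitioning logic as above, whether
+ * the spatial dimension would end up split across threads (S_nthr > 1). */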
+bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w,
+ int data_size) {
+ if (!mkldnn_thr_syncable()) return false;
+
+ dim_t nthr = mkldnn_get_max_threads();
+ dim_t SP = bdesc->W() * bdesc->D() * bdesc->H();
+ dim_t C_PADDED = memory_desc_wrapper(bdesc->src_md())
+ .padded_dims()[1];
+ assert(C_PADDED % simd_w == 0);
+
+ size_t data = bdesc->MB() * C_PADDED * SP * data_size;
+ size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
+ bool do_blocking = (data >= l3_size_ / 2 && l3_size_ > 0);
+ dim_t C_blks_per_iter{ 1 }, iters{ 1 };
+ dim_t C_blks = C_PADDED / simd_w;
+
+ if (do_blocking) {
+ int num_tensors = bdesc->is_fwd() ? 1 : 2;
+ size_t working_set_size
+ = (bdesc->MB() * SP * simd_w * data_size) * num_tensors;
+ cache_balance(working_set_size, C_blks, C_blks_per_iter, iters);
+ }
+
+ // Spatial threading decision made in this function shall be consistent
+ // with thread_balance() behavior.
+ C_blks = do_blocking ? C_blks_per_iter : C_blks;
+
+ if (nthr <= C_blks) return false;
+
+ dim_t S_nthr = 1;
+ if (do_blocking) {
+ dim_t N_nthr = nstl::min(bdesc->MB(), nthr);
+ dim_t C_nthr = nstl::min(C_blks, nthr / N_nthr);
+ S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));
+ } else {
+ dim_t C_nthr = math::gcd(nthr, C_blks);
+ dim_t N_nthr = nstl::min(bdesc->MB(), nthr / C_nthr);
+ S_nthr = nstl::min(SP, nthr / (C_nthr * N_nthr));
+ }
+
+ return S_nthr > 1;
+}
+
+}
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp
new file mode 100644
index 0000000000..0daef0716c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_batch_normalization_utils.hpp
@@ -0,0 +1,43 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_BATCH_NORMALIZATION_UTILS_HPP
+#define CPU_BATCH_NORMALIZATION_UTILS_HPP
+
+#include "batch_normalization_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+namespace bnorm_utils {
+
+void cache_balance(size_t working_set_size, dim_t C_blks,
+ dim_t &C_blks_per_iter, int64_t &iters);
+
+bool thread_balance(bool do_blocking, bool spatial_thr_allowed, int ithr,
+ int nthr, dim_t N, dim_t C_blks, dim_t SP, int &C_ithr, int &C_nthr,
+ dim_t &C_blk_s, dim_t &C_blk_e, int &N_ithr, int &N_nthr, dim_t &N_s,
+ dim_t &N_e, int &S_ithr, int &S_nthr, dim_t &S_s, dim_t &S_e);
+
+bool is_spatial_thr(const batch_normalization_pd_t *bdesc, int simd_w,
+ int data_size);
+
+}
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp
new file mode 100644
index 0000000000..b926491202
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat.cpp
@@ -0,0 +1,51 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "cpu_engine.hpp"
+
+/*
+#include "cpu/ref_concat.hpp"
+#include "cpu/simple_concat.hpp"
+*/
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using cpd_create_f = mkldnn::impl::engine_t::concat_primitive_desc_create_f;
+
+namespace {
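+// All concat implementations are commented out here, so the list below
+// contains only its nullptr terminator.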
+#define INSTANCE(...) __VA_ARGS__::pd_t::create
+static const cpd_create_f cpu_concat_impl_list[] = {
+ /*
+ INSTANCE(simple_concat_t<data_type::f32>),
+ INSTANCE(simple_concat_t<data_type::u8>),
+ INSTANCE(simple_concat_t<data_type::s8>),
+ INSTANCE(simple_concat_t<data_type::s32>),
+ INSTANCE(ref_concat_t),
+ */
+ nullptr,
+};
+#undef INSTANCE
+}
+
+const cpd_create_f *cpu_engine_t::get_concat_implementation_list() const {
+ return cpu_concat_impl_list;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp
new file mode 100644
index 0000000000..0b01bcf163
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_concat_pd.hpp
@@ -0,0 +1,41 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_CONCAT_PD_HPP
+#define CPU_CONCAT_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "concat_pd.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_concat_pd_t: public concat_pd_t {
+ using concat_pd_t::concat_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp
new file mode 100644
index 0000000000..52a38a2294
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_convolution_pd.hpp
@@ -0,0 +1,74 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_CONVOLUTION_PD_HPP
+#define CPU_CONVOLUTION_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "convolution_pd.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_convolution_fwd_pd_t: public convolution_fwd_pd_t {
+ using convolution_fwd_pd_t::convolution_fwd_pd_t;
+
+ bool has_padded_dst() const {
+ memory_desc_wrapper dst_d(&dst_md_);
+ return OC() != dst_d.padded_dims()[1];
+ }
+
+ bool wants_padded_bias() const {
+ if (!with_bias()) return false;
+ return has_padded_dst();
+ }
+
+ bool wants_zero_pad_dst(bool jit_impl = true) const {
+ if (!has_padded_dst()) return false;
+ const auto &po = attr()->post_ops_;
+ int idx;
+ if ((idx = po.find(primitive_kind::eltwise)) == -1) return false;
+ return !math::eltwise_fwd_preserves_zero(po.entry_[idx].eltwise.alg,
+ jit_impl);
+ }
+};
+
+struct cpu_convolution_bwd_data_pd_t: public convolution_bwd_data_pd_t {
+ using convolution_bwd_data_pd_t::convolution_bwd_data_pd_t;
+};
+
+struct cpu_convolution_bwd_weights_pd_t: public convolution_bwd_weights_pd_t {
+ using convolution_bwd_weights_pd_t::convolution_bwd_weights_pd_t;
+
+ bool wants_padded_bias() const {
+ if (!with_bias()) return false;
+ memory_desc_wrapper diff_dst_d(&diff_dst_md_);
+ return OC() != diff_dst_d.padded_dims()[1];
+ }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp
new file mode 100644
index 0000000000..164c8601d7
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_deconvolution_pd.hpp
@@ -0,0 +1,46 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_DECONVOLUTION_PD_HPP
+#define CPU_DECONVOLUTION_PD_HPP
+
+#include <assert.h>
+
+#include "deconvolution_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_deconvolution_fwd_pd_t: public deconvolution_fwd_pd_t {
+ using deconvolution_fwd_pd_t::deconvolution_fwd_pd_t;
+};
+
+struct cpu_deconvolution_bwd_data_pd_t: public deconvolution_bwd_data_pd_t {
+ using deconvolution_bwd_data_pd_t::deconvolution_bwd_data_pd_t;
+};
+
+struct cpu_deconvolution_bwd_weights_pd_t: public deconvolution_bwd_weights_pd_t {
+ using deconvolution_bwd_weights_pd_t::deconvolution_bwd_weights_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp
new file mode 100644
index 0000000000..c52f00026e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_eltwise_pd.hpp
@@ -0,0 +1,45 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_ELTWISE_PD_HPP
+#define CPU_ELTWISE_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "eltwise_pd.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_eltwise_fwd_pd_t: public eltwise_fwd_pd_t {
+ using eltwise_fwd_pd_t::eltwise_fwd_pd_t;
+};
+
+struct cpu_eltwise_bwd_pd_t: public eltwise_bwd_pd_t {
+ using eltwise_bwd_pd_t::eltwise_bwd_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp
new file mode 100644
index 0000000000..ce0a3667ad
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.cpp
@@ -0,0 +1,324 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "type_helpers.hpp"
+#include "verbose.hpp"
+
+#include "cpu_engine.hpp"
+#include "cpu_memory.hpp"
+
+//#include "cpu/rnn/ref_rnn.hpp"
+
+//#include "cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
+//#include "cpu/jit_avx512_common_1x1_convolution.hpp"
+#include "cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp"
+#include "cpu/jit_avx512_common_convolution_winograd.hpp"
+//#include "cpu/jit_avx512_core_x8s8s32x_convolution.hpp"
+#include "cpu/jit_avx512_common_convolution.hpp"
+//#include "cpu/jit_avx2_1x1_convolution.hpp"
+//#include "cpu/jit_sse42_1x1_convolution.hpp"
+#include "cpu/jit_avx2_convolution.hpp"
+#include "cpu/jit_sse42_convolution.hpp"
+//#include "cpu/gemm_convolution.hpp"
+//#include "cpu/gemm_x8s8s32x_convolution.hpp"
+//#include "cpu/ref_convolution.hpp"
+//#include "cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp"
+//#include "cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp"
+//#include "cpu/ref_deconvolution.hpp"
+//#include "cpu/ref_shuffle.hpp"
+//#include "cpu/jit_uni_eltwise.hpp"
+//#include "cpu/ref_eltwise.hpp"
+//#include "cpu/ref_softmax.hpp"
+#include "cpu/jit_uni_pooling.hpp"
+//#include "cpu/jit_uni_i8i8_pooling.hpp"
+//#include "cpu/ref_pooling.hpp"
+//#include "cpu/nchw_pooling.hpp"
+//#include "cpu/nhwc_pooling.hpp"
+//#include "cpu/jit_avx512_common_lrn.hpp"
+//#include "cpu/jit_uni_lrn.hpp"
+//#include "cpu/ref_lrn.hpp"
+//#include "cpu/jit_uni_batch_normalization.hpp"
+//#include "cpu/ref_batch_normalization.hpp"
+//#include "cpu/ncsp_batch_normalization.hpp"
+//#include "cpu/nspc_batch_normalization.hpp"
+//#include "cpu/ref_inner_product.hpp"
+//#include "cpu/gemm_inner_product.hpp"
+//#include "cpu/gemm_x8s8s32x_inner_product.hpp"
+//#include "cpu/jit_uni_dw_convolution.hpp"
+//#include "cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp"
+#include "cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+status_t cpu_engine_t::memory_create(memory_t **memory,
+ const memory_desc_t *md, void *handle) {
+ auto _memory = new cpu_memory_t(this, md, handle);
+ if (_memory == nullptr)
+ return status::out_of_memory;
+
+ status_t status = _memory->init();
+ if (status != status::success) {
+ delete _memory;
+ return status;
+ }
+
+ return safe_ptr_assign<memory_t>(*memory, _memory);
+}
+
+using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
+
+namespace {
+using namespace mkldnn::impl::data_type;
+
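+// Candidate implementations in priority order; the primitive descriptor
+// iterator picks the first entry that succeeds. Unused entries are commented
+// out and the list stays nullptr-terminated.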
+#define INSTANCE(...) &primitive_desc_t::create<__VA_ARGS__::pd_t>
+static const pd_create_f cpu_impl_list[] = {
+ /* RNN */
+ /*
+ INSTANCE(ref_rnn_fwd_f32_t),
+ INSTANCE(ref_rnn_fwd_u8s8_t),
+ INSTANCE(ref_rnn_bwd_f32_t),
+ */
+ /* conv */
+ /*
+ INSTANCE(jit_avx512_common_dw_convolution_fwd_t),
+ INSTANCE(jit_avx512_common_dw_convolution_bwd_data_t),
+ INSTANCE(jit_avx512_common_dw_convolution_bwd_weights_t),
+ INSTANCE(jit_avx512_common_1x1_convolution_fwd_f32_t),
+ INSTANCE(jit_avx512_common_1x1_convolution_bwd_data_f32_t),
+ INSTANCE(jit_avx512_common_1x1_convolution_bwd_weights_t),
+ */
+ INSTANCE(jit_avx512_core_fp32_wino_conv_2x3_fwd_t),
+ INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_fwd_t),
+ //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t),
+ //INSTANCE(jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t),
+ INSTANCE(jit_avx512_common_convolution_winograd_fwd_t),
+ //INSTANCE(jit_avx512_common_convolution_winograd_bwd_data_t),
+ //INSTANCE(jit_avx512_common_convolution_winograd_bwd_weights_t),
+ INSTANCE(jit_avx512_common_convolution_fwd_t<f32>),
+ //INSTANCE(jit_avx512_common_convolution_bwd_data_t<f32>),
+ //INSTANCE(jit_avx512_common_convolution_bwd_weights_t<f32>),
+ /*
+ INSTANCE(jit_avx2_dw_convolution_fwd_t),
+ INSTANCE(jit_avx2_dw_convolution_bwd_data_t),
+ INSTANCE(jit_avx2_dw_convolution_bwd_weights_t),
+ INSTANCE(jit_avx2_1x1_convolution_fwd_t),
+ INSTANCE(jit_avx2_1x1_convolution_bwd_data_t),
+ INSTANCE(jit_avx2_1x1_convolution_bwd_weights_t),
+ INSTANCE(jit_sse42_dw_convolution_fwd_t),
+ INSTANCE(jit_sse42_dw_convolution_bwd_data_t),
+ INSTANCE(jit_sse42_dw_convolution_bwd_weights_t),
+ INSTANCE(jit_sse42_1x1_convolution_fwd_t),
+ */
+ INSTANCE(jit_avx2_convolution_fwd_t),
+ //INSTANCE(jit_avx2_convolution_bwd_data_t),
+ //INSTANCE(jit_avx2_convolution_bwd_weights_t),
+ INSTANCE(jit_sse42_convolution_fwd_t),
+ /*
+ INSTANCE(gemm_convolution_fwd_t),
+ INSTANCE(gemm_convolution_bwd_data_t),
+ INSTANCE(gemm_convolution_bwd_weights_t),
+ INSTANCE(ref_convolution_fwd_t<f32>),
+ INSTANCE(ref_convolution_bwd_data_t<f32, f32, f32, f32>),
+ INSTANCE(ref_convolution_bwd_weights_t<f32, f32, f32, f32>),
+ */
+ /* conv (int) */
+ /*
+ INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<f32>),
+ INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s32>),
+ INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<s8>),
+ INSTANCE(jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<u8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,f32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,u8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8,s8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,f32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,u8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8,s8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,f32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,u8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<u8,s8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,f32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,u8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_convolution_fwd_t<s8,s8>),
+ INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s32>),
+ INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, u8>),
+ INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, s8>),
+ INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<u8, f32>),
+ INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s32>),
+ INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, u8>),
+ INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, s8>),
+ INSTANCE(_gemm_x8s8s32x_convolution_fwd_t<s8, f32>),
+ INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s32>),
+ INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<u8>),
+ INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<s8>),
+ INSTANCE(_gemm_u8s8s32x_convolution_bwd_data_t<f32>),
+ INSTANCE(ref_convolution_fwd_t<u8, s8, f32, s32>),
+ INSTANCE(ref_convolution_fwd_t<u8, s8, s32, s32>),
+ INSTANCE(ref_convolution_fwd_t<u8, s8, s8, s32>),
+ INSTANCE(ref_convolution_fwd_t<u8, s8, u8, s32>),
+ INSTANCE(ref_convolution_bwd_data_t<f32, s8, u8, s32>),
+ INSTANCE(ref_convolution_bwd_data_t<s32, s8, u8, s32>),
+ INSTANCE(ref_convolution_bwd_data_t<s8, s8, u8, s32>),
+ INSTANCE(ref_convolution_bwd_data_t<u8, s8, u8, s32>),
+ */
+ /* deconv */
+ /*
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,f32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,u8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<u8,s8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,f32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s32>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,u8>),
+ INSTANCE(jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<s8,s8>),
+ INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s32>),
+ INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,u8>),
+ INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,s8>),
+ INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<u8,f32>),
+ INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s32>),
+ INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,u8>),
+ INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,s8>),
+ INSTANCE(_jit_avx512_core_x8s8s32x_deconvolution_fwd_t<s8,f32>),
+ INSTANCE(ref_deconvolution_bwd_weights_t),
+ INSTANCE(ref_deconvolution_bwd_data_t),
+ INSTANCE(ref_deconvolution_fwd_t),
+ */
+ /* shuffle */
+ /*
+ INSTANCE(ref_shuffle_t<4>), // f32 or s32
+ INSTANCE(ref_shuffle_t<1>), // s8 or u8
+ */
+ /* eltwise */
+ /*
+ INSTANCE(jit_uni_eltwise_fwd_t<avx512_common>),
+ INSTANCE(jit_uni_eltwise_bwd_t<avx512_common>),
+ INSTANCE(jit_uni_eltwise_fwd_t<avx2>),
+ INSTANCE(jit_uni_eltwise_bwd_t<avx2>),
+ INSTANCE(jit_uni_eltwise_fwd_t<sse42>),
+ INSTANCE(jit_uni_eltwise_bwd_t<sse42>),
+ INSTANCE(ref_eltwise_fwd_t<f32>),
+ INSTANCE(ref_eltwise_bwd_t<f32>),
+ */
+ /* eltwise (int) */
+ /*
+ INSTANCE(ref_eltwise_fwd_t<s32>),
+ INSTANCE(ref_eltwise_fwd_t<s8>),
+ INSTANCE(ref_eltwise_fwd_t<u8>),
+ INSTANCE(ref_eltwise_bwd_t<s32>),
+ */
+ /* softmax */
+ /*
+ INSTANCE(ref_softmax_fwd_t<f32>),
+ INSTANCE(ref_softmax_bwd_t<f32>),
+ */
+ /* pool */
+ INSTANCE(jit_uni_pooling_fwd_t<avx512_common>),
+ //INSTANCE(jit_uni_pooling_bwd_t<avx512_common>),
+ INSTANCE(jit_uni_pooling_fwd_t<avx>),
+ //INSTANCE(jit_uni_pooling_bwd_t<avx>),
+ INSTANCE(jit_uni_pooling_fwd_t<sse42>),
+ //INSTANCE(jit_uni_pooling_bwd_t<sse42>),
+ /*
+ INSTANCE(nchw_pooling_fwd_t<f32>),
+ INSTANCE(nchw_pooling_bwd_t<f32>),
+ INSTANCE(nhwc_pooling_fwd_t<f32>),
+ INSTANCE(nhwc_pooling_bwd_t<f32>),
+ INSTANCE(ref_pooling_fwd_t<f32>),
+ INSTANCE(ref_pooling_bwd_t<f32>),
+ */
+ /* pool (int) */
+ /*
+ INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx512_core>),
+ INSTANCE(jit_uni_i8i8_pooling_fwd_t<avx2>),
+ INSTANCE(ref_pooling_fwd_t<s32>),
+ INSTANCE(ref_pooling_fwd_t<s8, s32>),
+ INSTANCE(ref_pooling_fwd_t<u8, s32>),
+ INSTANCE(ref_pooling_bwd_t<s32>),
+ */
+ /* lrn */
+ /*
+ INSTANCE(jit_avx512_common_lrn_fwd_t),
+ INSTANCE(jit_avx512_common_lrn_bwd_t),
+ INSTANCE(jit_uni_lrn_fwd_t<avx2>),
+ INSTANCE(jit_uni_lrn_bwd_t<avx2>),
+ INSTANCE(jit_uni_lrn_fwd_t<sse42>),
+ INSTANCE(ref_lrn_fwd_t<f32>),
+ INSTANCE(ref_lrn_bwd_t<f32>),
+ */
+ /* batch normalization */
+ /*
+ INSTANCE(jit_uni_batch_normalization_fwd_t<avx512_common>),
+ INSTANCE(jit_uni_batch_normalization_bwd_t<avx512_common>),
+ INSTANCE(jit_uni_batch_normalization_fwd_t<avx2>),
+ INSTANCE(jit_uni_batch_normalization_bwd_t<avx2>),
+ INSTANCE(jit_uni_batch_normalization_fwd_t<sse42>),
+ INSTANCE(jit_uni_batch_normalization_bwd_t<sse42>),
+ INSTANCE(ncsp_batch_normalization_fwd_t),
+ INSTANCE(ncsp_batch_normalization_bwd_t),
+ INSTANCE(nspc_batch_normalization_fwd_t),
+ INSTANCE(nspc_batch_normalization_bwd_t),
+ INSTANCE(ref_batch_normalization_fwd_t<f32>),
+ INSTANCE(ref_batch_normalization_bwd_t<f32>),
+ INSTANCE(ref_batch_normalization_fwd_t<s8>),
+ */
+ /* inner product */
+ /*
+ INSTANCE(gemm_inner_product_fwd_t<f32>),
+ INSTANCE(gemm_inner_product_bwd_data_t<f32>),
+ INSTANCE(gemm_inner_product_bwd_weights_t<f32>),
+ INSTANCE(ref_inner_product_fwd_t<f32>),
+ INSTANCE(ref_inner_product_bwd_data_t<f32, f32, f32, f32>),
+ INSTANCE(ref_inner_product_bwd_weights_t<f32>),
+ */
+ /* inner product (int) */
+ /*
+ INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, u8>),
+ INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s8>),
+ INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, s32>),
+ INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<u8, f32>),
+ INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, u8>),
+ INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s8>),
+ INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, s32>),
+ INSTANCE(gemm_x8s8s32x_inner_product_fwd_t<s8, f32>),
+ INSTANCE(ref_inner_product_fwd_t<u8, s8, u8, s32>),
+ INSTANCE(ref_inner_product_fwd_t<u8, s8, s8, s32>),
+ INSTANCE(ref_inner_product_fwd_t<u8, s8, s32, s32>),
+ INSTANCE(ref_inner_product_fwd_t<u8, s8, f32, s32>),
+ */
+ /* eol */
+ nullptr,
+};
+#undef INSTANCE
+}
+
+const pd_create_f* cpu_engine_t::get_implementation_list() const {
+ return cpu_impl_list;
+}
+
+cpu_engine_factory_t engine_factory;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp
new file mode 100644
index 0000000000..e4c877ee05
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_engine.hpp
@@ -0,0 +1,70 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_ENGINE_HPP
+#define CPU_ENGINE_HPP
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "../common/engine.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+class cpu_engine_t: public engine_t {
+public:
+ cpu_engine_t(): engine_t(engine_kind::cpu) {}
+
+ /* implementation part */
+
+ virtual status_t memory_create(memory_t **memory,
+ const memory_desc_t *md, void *handle) override;
+
+ virtual const concat_primitive_desc_create_f*
+ get_concat_implementation_list() const override;
+ virtual const reorder_primitive_desc_create_f*
+ get_reorder_implementation_list() const override;
+ virtual const sum_primitive_desc_create_f*
+ get_sum_implementation_list() const override;
+ virtual const primitive_desc_create_f*
+ get_implementation_list() const override;
+};
+
+class cpu_engine_factory_t: public engine_factory_t {
+public:
+ virtual size_t count() const override { return 1; }
+ virtual engine_kind_t kind() const override { return engine_kind::cpu; }
+ virtual status_t engine_create(engine_t **engine,
+ size_t index) const override {
+ assert(index == 0);
+ *engine = new cpu_engine_t();
+ return status::success;
+ };
+};
+
+extern cpu_engine_factory_t engine_factory;
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp
new file mode 100644
index 0000000000..5880d3450c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_inner_product_pd.hpp
@@ -0,0 +1,84 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_INNER_PRODUCT_PD_HPP
+#define CPU_INNER_PRODUCT_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "inner_product_pd.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace {
+inline bool dense_gemm_consitency_check(const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &wei_d, const memory_desc_wrapper &dst_d) {
+ using namespace utils;
+
+ auto strides_compatible = [&]() {
+ bool ok = true;
+ auto w_str = wei_d.blocking_desc().strides;
+ auto d_str = src_d.blocking_desc().strides;
+ for (int i = 1; i < src_d.ndims() - 1; i++) {
+ ok = ok && w_str[i] / d_str[i] == w_str[i + 1] / d_str[i + 1];
+ }
+ return ok && one_of(w_str[1] / d_str[1], 1, wei_d.padded_dims()[0]);
+ };
+ return true && src_d.is_blocking_desc() && wei_d.is_blocking_desc()
+ && src_d.ndims() == wei_d.ndims()
+ && src_d.blocking_desc().inner_nblks
+ == wei_d.blocking_desc().inner_nblks
+ && utils::one_of(src_d.blocking_desc().inner_nblks, 0, 1)
+ && array_cmp(src_d.blocking_desc().inner_blks,
+ wei_d.blocking_desc().inner_blks,
+ wei_d.blocking_desc().inner_nblks)
+ && array_cmp(src_d.blocking_desc().inner_idxs,
+ wei_d.blocking_desc().inner_idxs,
+ wei_d.blocking_desc().inner_nblks)
+ && strides_compatible()
+ && dst_d.matches_tag(format_tag::nc)
+ && src_d.only_padded_dim(1)
+ && wei_d.only_padded_dim(1)
+ && src_d.padded_dims()[1] == wei_d.padded_dims()[1]
+ && src_d.is_dense(true)
+ && dst_d.is_dense()
+ && wei_d.is_dense(true);
+}
+}
+
+struct cpu_inner_product_fwd_pd_t: public inner_product_fwd_pd_t {
+ using inner_product_fwd_pd_t::inner_product_fwd_pd_t;
+};
+
+struct cpu_inner_product_bwd_data_pd_t: public inner_product_bwd_data_pd_t {
+ using inner_product_bwd_data_pd_t::inner_product_bwd_data_pd_t;
+};
+
+struct cpu_inner_product_bwd_weights_pd_t: public inner_product_bwd_weights_pd_t {
+ using inner_product_bwd_weights_pd_t::inner_product_bwd_weights_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp
new file mode 100644
index 0000000000..da6e9dac8e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_isa_traits.hpp
@@ -0,0 +1,151 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_ISA_TRAITS_HPP
+#define CPU_ISA_TRAITS_HPP
+
+#include <type_traits>
+
+#define XBYAK64
+#define XBYAK_NO_OP_NAMES
+/* In order to make SELinux happy, memory that would be marked with the X-bit
+ * should be obtained with mmap */
+#define XBYAK_USE_MMAP_ALLOCATOR
+#if defined(_MSC_VER) && !defined(__INTEL_COMPILER)
+/* Turn off the `size_t to other-type implicit casting` warning.
+ * Currently we have a lot of jit-generated instructions that
+ * take uint32_t, but we pass size_t (e.g. due to using sizeof).
+ * FIXME: replace size_t parameters with the appropriate ones */
+#pragma warning (disable: 4267)
+#endif
+#include "xbyak/xbyak.h"
+#include "xbyak/xbyak_util.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+typedef enum {
+ isa_any,
+ sse41,
+ sse42,
+ avx,
+ avx2,
+ avx512_common,
+ avx512_core,
+ avx512_core_vnni,
+ avx512_mic,
+ avx512_mic_4ops,
+} cpu_isa_t;
+
+template <cpu_isa_t> struct cpu_isa_traits {}; /* ::vlen -> 32 (for avx2) */
+
+template <> struct cpu_isa_traits<sse42> {
+ typedef Xbyak::Xmm Vmm;
+ static constexpr int vlen_shift = 4;
+ static constexpr int vlen = 16;
+ static constexpr int n_vregs = 16;
+};
+template <> struct cpu_isa_traits<avx> {
+ typedef Xbyak::Ymm Vmm;
+ static constexpr int vlen_shift = 5;
+ static constexpr int vlen = 32;
+ static constexpr int n_vregs = 16;
+};
+template <> struct cpu_isa_traits<avx2>:
+ public cpu_isa_traits<avx> {};
+
+template <> struct cpu_isa_traits<avx512_common> {
+ typedef Xbyak::Zmm Vmm;
+ static constexpr int vlen_shift = 6;
+ static constexpr int vlen = 64;
+ static constexpr int n_vregs = 32;
+};
+template <> struct cpu_isa_traits<avx512_core>:
+ public cpu_isa_traits<avx512_common> {};
+
+template <> struct cpu_isa_traits<avx512_mic>:
+ public cpu_isa_traits<avx512_common> {};
+
+template <> struct cpu_isa_traits<avx512_mic_4ops>:
+ public cpu_isa_traits<avx512_common> {};
+
+namespace {
+
+static Xbyak::util::Cpu cpu;
+static inline bool mayiuse(const cpu_isa_t cpu_isa) {
+ using namespace Xbyak::util;
+
+ switch (cpu_isa) {
+ case sse41:
+ case sse42:
+ // FIXME: SSE4.2 is actually NOT required
+ //return cpu.has(Cpu::tSSE42);
+ return cpu.has(Cpu::tSSE41);
+ case avx:
+ return cpu.has(Cpu::tAVX);
+ case avx2:
+ return cpu.has(Cpu::tAVX2);
+ case avx512_common:
+ return cpu.has(Cpu::tAVX512F);
+ case avx512_core:
+ return true
+ && cpu.has(Cpu::tAVX512F)
+ && cpu.has(Cpu::tAVX512BW)
+ && cpu.has(Cpu::tAVX512VL)
+ && cpu.has(Cpu::tAVX512DQ);
+ case avx512_core_vnni:
+ return true
+ && cpu.has(Cpu::tAVX512F)
+ && cpu.has(Cpu::tAVX512BW)
+ && cpu.has(Cpu::tAVX512VL)
+ && cpu.has(Cpu::tAVX512DQ)
+ && cpu.has(Cpu::tAVX512_VNNI);
+ case avx512_mic:
+ return true
+ && cpu.has(Cpu::tAVX512F)
+ && cpu.has(Cpu::tAVX512CD)
+ && cpu.has(Cpu::tAVX512ER)
+ && cpu.has(Cpu::tAVX512PF);
+ case avx512_mic_4ops:
+ return true
+ && mayiuse(avx512_mic)
+ && cpu.has(Cpu::tAVX512_4FMAPS)
+ && cpu.has(Cpu::tAVX512_4VNNIW);
+ case isa_any:
+ return true;
+ }
+ return false;
+}
+}
+
+/* helpers required to generate implementation-name string literals */
+#include "z_magic.hpp"
+#define JIT_IMPL_NAME_HELPER(prefix, isa, suffix_if_any) \
+ (isa == sse42 ? prefix STRINGIFY(sse42) : \
+ (isa == avx ? prefix STRINGIFY(avx) : \
+ (isa == avx2 ? prefix STRINGIFY(avx2) : \
+ (isa == avx512_common ? prefix STRINGIFY(avx512_common) : \
+ (isa == avx512_core ? prefix STRINGIFY(avx512_core) : \
+ (isa == avx512_mic ? prefix STRINGIFY(avx512_mic) : \
+ (isa == avx512_mic_4ops ? prefix STRINGIFY(avx512_mic_4ops) : \
+ prefix suffix_if_any)))))))
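+/* For example (hypothetical arguments): JIT_IMPL_NAME_HELPER("jit:", isa, "any")
+ * yields "jit:avx2" when isa == avx2 and falls back to "jit:any" for an isa
+ * not listed above. */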
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp
new file mode 100644
index 0000000000..49988f4c2d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_lrn_pd.hpp
@@ -0,0 +1,42 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_LRN_PD_HPP
+#define CPU_LRN_PD_HPP
+
+#include <assert.h>
+
+#include "lrn_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_lrn_fwd_pd_t: public lrn_fwd_pd_t {
+ using lrn_fwd_pd_t::lrn_fwd_pd_t;
+};
+
+struct cpu_lrn_bwd_pd_t: public lrn_bwd_pd_t {
+ using lrn_bwd_pd_t::lrn_bwd_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp
new file mode 100644
index 0000000000..3c0624cf46
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.cpp
@@ -0,0 +1,277 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn_traits.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_memory.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::data_type;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::format_tag;
+
+enum blk_kind_t { a, b, c, ab, ba, bc, cb };
+
+template <data_type_t dt, blk_kind_t blk_kind, int blksize>
+void typed_zero_pad_blk(
+ const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
+ using data_t = typename prec_traits<dt>::type;
+ const auto &dims = m_d.dims();
+ const auto &pdims = m_d.padded_dims();
+ const auto &blk = m_d.blocking_desc();
+ auto dim_is_blocked = [&](int dim) {
+ for (int i = 0; i < blk.inner_nblks; i++)
+ if (blk.inner_idxs[i] == dim)
+ return true;
+ return false;
+ };
+ bool A_blocked = dim_is_blocked(0), B_blocked = dim_is_blocked(1),
+ C_blocked = dim_is_blocked(2);
+
+ assert(blk.inner_nblks < 4);
+ assert((A_blocked || B_blocked || C_blocked) || (A_blocked && B_blocked)
+ || (C_blocked && B_blocked));
+
+ const int a_tail_s = A_blocked ? dims[0] % blksize : 0;
+ const int b_tail_s = B_blocked ? dims[1] % blksize : 0;
+ const int c_tail_s = C_blocked ? dims[2] % blksize : 0;
+ assert(a_tail_s || b_tail_s || c_tail_s);
+
+ const int A = A_blocked ? pdims[0] / blksize : dims[0];
+ const int B = B_blocked ? pdims[1] / blksize : dims[1];
+ const int C = C_blocked ? pdims[2] / blksize : dims[2];
+ const int D = m_d.ndims() > 3 ? dims[3] : 1;
+ const int E = m_d.ndims() > 4 ? dims[4] : 1;
+ const int F = m_d.ndims() > 5 ? dims[5] : 1;
+ const int inner_blk = blk.inner_nblks == 3 ? blk.inner_blks[2] : 1;
+
+ auto zeroize_tail = [&](data_t *d, const int tail_s) {
+ for (int b = tail_s; b < blksize; ++b)
+ d[b] = 0;
+ };
+ auto zeroize_tail_inner = [&](data_t *d, const int tail_s) {
+ for (int b1 = 0; b1 < blksize; ++b1)
+ for (int b2 = tail_s; b2 < blksize; ++b2)
+ d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
+ + b1 % inner_blk]
+ = 0;
+ };
+ auto zeroize_tail_outer = [&](data_t *d, const int tail_s) {
+ for (int b1 = tail_s; b1 < blksize; ++b1)
+ for (int b2 = 0; b2 < blksize; ++b2)
+ d[(b1 / inner_blk) * blksize * inner_blk + inner_blk * b2
+ + b1 % inner_blk]
+ = 0;
+ };
+
+ if (c_tail_s) {
+ parallel_nd(A, B, D, E, F, [&](int a, int b, int d, int e, int f) {
+ auto x = &data[m_d.blk_off(a, b, C - 1, d, e, f)];
+ if (blk_kind == c)
+ zeroize_tail(x, c_tail_s);
+ else if (blk_kind == bc)
+ zeroize_tail_inner(x, c_tail_s);
+ else if (blk_kind == cb)
+ zeroize_tail_outer(x, c_tail_s);
+ });
+ }
+
+ if (b_tail_s) {
+ parallel_nd(A, C, D, E, F, [&](int a, int c, int d, int e, int f) {
+ auto x = &data[m_d.blk_off(a, B - 1, c, d, e, f)];
+ if (blk_kind == b)
+ zeroize_tail(x, b_tail_s);
+ else if (blk_kind == ab || blk_kind == cb)
+ zeroize_tail_inner(x, b_tail_s);
+ else if (blk_kind == ba || blk_kind == bc)
+ zeroize_tail_outer(x, b_tail_s);
+ });
+ }
+
+ if (a_tail_s) {
+ parallel_nd(B, C, D, E, F, [&](int b, int c, int d, int e, int f) {
+ auto x = &data[m_d.blk_off(A - 1, b, c, d, e, f)];
+ if (blk_kind == a)
+ zeroize_tail(x, a_tail_s);
+ else if (blk_kind == ba)
+ zeroize_tail_inner(x, a_tail_s);
+ else if (blk_kind == ab)
+ zeroize_tail_outer(x, a_tail_s);
+ });
+ }
+}
+
+/*
+ * generic zero-padding that handles all blocked formats
+ */
+template <data_type_t dt>
+void typed_zero_pad_generic_blocked(
+ const memory_desc_wrapper &m_d, typename prec_traits<dt>::type *data) {
+ const int ndims = m_d.ndims();
+ const auto &dims = m_d.dims();
+ const auto &pdims = m_d.padded_dims();
+
+ const ptrdiff_t nelems = (ptrdiff_t)m_d.nelems(true);
+
+ /* [D_0] .. [D_k][D_k+1] .. [D_ndim - 1]
+ * | \ /
+ * | ---------------------
+ * has contiguous
+ * padding
+ *
+ * step <-- D_k+1 * ... * D_ndims-1
+ * step_dim <-- k
+ */
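+ /* For illustration: with dims = {2, 3, 5} and pdims = {2, 4, 5} only
+ * dimension 1 is padded, so the loop below yields step_dim = 1 and
+ * step = 5 (the contiguous tail D_2). */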
+
+ ptrdiff_t step = 1;
+ int step_dim = ndims - 1;
+ for (; step_dim >= 0; --step_dim) {
+ if (dims[step_dim] != pdims[step_dim])
+ break;
+ step *= dims[step_dim];
+ }
+
+ assert(step_dim >= 0 && "no zero padding is required");
+ if (step_dim < 0)
+ return;
+
+ parallel_nd(nelems / step, [&](ptrdiff_t e1) {
+ bool need_zero = false;
+
+ ptrdiff_t idx = e1;
+ for (int d = step_dim; d >= 0; --d) {
+ if (idx % pdims[d] >= dims[d]) {
+ need_zero = true;
+ break;
+ }
+ idx /= pdims[d];
+ }
+
+ if (need_zero) {
+ for (ptrdiff_t e0 = 0; e0 < step; ++e0)
+ data[m_d.off_l(e1 * step + e0, true)] = 0;
+ }
+ });
+}
+
+template <data_type_t dt>
+status_t cpu_memory_t::typed_zero_pad() const {
+ const memory_desc_wrapper mdw(md());
+
+ if (mdw.format_kind() != format_kind::blocked)
+ return unimplemented;
+
+ if (mdw.nelems(false) == mdw.nelems(true))
+ return success;
+
+ auto *data = (typename prec_traits<dt>::type *)data_;
+ auto blk = mdw.blocking_desc();
+
+ auto get_blksize = [&](int ind) {
+ int blksize = 1;
+ for (int i = 0; i < blk.inner_nblks; i++) {
+ if (blk.inner_idxs[i] == ind)
+ blksize *= blk.inner_blks[i];
+ }
+ return blksize;
+ };
+ const int blksize = get_blksize(blk.inner_idxs[0]);
+
+# define CASE(blksize_, blk_kind) \
+ do { \
+ if (blksize == blksize_) { \
+ typed_zero_pad_blk<dt, blk_kind, blksize_>(mdw, data); \
+ return success; \
+ } \
+ } while(0)
+
+ switch (blk.inner_nblks) {
+ case 1:
+ if (blk.inner_idxs[0] == 0) {
+ CASE(4, a);
+ CASE(8, a);
+ CASE(16, a);
+ } else if (blk.inner_idxs[0] == 1) {
+ CASE(4, b);
+ CASE(8, b);
+ CASE(16, b);
+ }
+ break;
+ case 2:
+ case 3:
+ if (!IMPLICATION(blk.inner_nblks == 3,
+ blk.inner_idxs[0] == blk.inner_idxs[2]))
+ break;
+
+ if (blk.inner_idxs[0] == 0 && blk.inner_idxs[1] == 1) {
+ CASE(4, ab);
+ CASE(8, ab);
+ CASE(16, ab);
+ } else if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 0) {
+ CASE(4, ba);
+ CASE(8, ba);
+ CASE(16, ba);
+ }
+ if (blk.inner_idxs[0] == 1 && blk.inner_idxs[1] == 2) {
+ CASE(4, bc);
+ CASE(8, bc);
+ CASE(16, bc);
+ } else if (blk.inner_idxs[0] == 2 && blk.inner_idxs[1] == 1) {
+ CASE(4, cb);
+ CASE(8, cb);
+ CASE(16, cb);
+ }
+ break;
+ default: break;
+ }
+
+# undef CASE
+
+ // the last line of defence
+ typed_zero_pad_generic_blocked<dt>(mdw, data);
+ return success;
+}
+
+status_t cpu_memory_t::zero_pad() const {
+ memory_desc_wrapper mdw(md());
+ const bool skip_zeroing = false
+ || data_ == nullptr
+ || mdw.is_zero()
+ || !mdw.is_blocking_desc();
+ if (skip_zeroing) return success;
+
+ switch (mdw.data_type()) {
+ case f32: return typed_zero_pad<f32>();
+ case s32: return typed_zero_pad<s32>();
+ case s8: return typed_zero_pad<s8>();
+ case u8: return typed_zero_pad<u8>();
+ default: assert(!"memory is undefined"); return unimplemented;
+ }
+ return unimplemented;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp
new file mode 100644
index 0000000000..2c01bcc6af
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_memory.hpp
@@ -0,0 +1,89 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_MEMORY_HPP
+#define CPU_MEMORY_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory.hpp"
+#include "memory_desc_wrapper.hpp"
+
+#include "cpu_engine.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_memory_t: public memory_t {
+ cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md, void *handle)
+ : memory_t(engine, md)
+ , own_data_(handle == MKLDNN_NATIVE_HANDLE_ALLOCATE)
+ , data_((char *)handle) {}
+
+ cpu_memory_t(cpu_engine_t *engine, const memory_desc_t *md)
+ : cpu_memory_t(engine, md, nullptr) {}
+
+ ~cpu_memory_t() { if (own_data_) free(data_); }
+
+ virtual status_t init() override {
+ if (own_data_) {
+ data_ = nullptr;
+ const size_t size = memory_desc_wrapper(this->md()).size();
+ if (size) {
+ data_ = (char *)malloc(size, 64);
+ if (data_ == nullptr)
+ return status::out_of_memory;
+ }
+ }
+ return zero_pad();
+ }
+
+ cpu_engine_t *engine() const { return (cpu_engine_t *)memory_t::engine(); }
+
+ virtual status_t get_data_handle(void **handle) const override {
+ *handle = static_cast<void *>(data_);
+ return status::success;
+ }
+
+ virtual mkldnn::impl::status_t set_data_handle(void *handle) override {
+ if (own_data_) { free(data_); own_data_ = false; }
+ data_ = static_cast<char *>(handle);
+ return zero_pad();
+ }
+
+ virtual mkldnn::impl::status_t zero_pad() const override;
+
+private:
+ bool own_data_;
+ char *data_;
+
+ template <mkldnn::impl::data_type_t>
+ mkldnn::impl::status_t typed_zero_pad() const;
+
+ cpu_memory_t(const cpu_memory_t &) = delete;
+ cpu_memory_t &operator=(const cpu_memory_t &) = delete;
+ cpu_memory_t &operator=(cpu_memory_t &&) = delete;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp
new file mode 100644
index 0000000000..ac2daa415e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_pooling_pd.hpp
@@ -0,0 +1,40 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_POOLING_PD_HPP
+#define CPU_POOLING_PD_HPP
+
+#include "pooling_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_pooling_fwd_pd_t: public pooling_fwd_pd_t {
+ using pooling_fwd_pd_t::pooling_fwd_pd_t;
+};
+
+struct cpu_pooling_bwd_pd_t: public pooling_bwd_pd_t {
+ using pooling_bwd_pd_t::pooling_bwd_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp
new file mode 100644
index 0000000000..56127f36c2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_primitive.hpp
@@ -0,0 +1,83 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_PRIMITIVE_HPP
+#define CPU_PRIMITIVE_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "primitive.hpp"
+#include "scratchpad.hpp"
+
+#define CTX_IN_MEM(type, arg) static_cast<type>(ctx.input(arg))
+#define CTX_OUT_MEM(type, arg) static_cast<type>(ctx.output(arg))
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_memory_t;
+
+struct cpu_primitive_t: public primitive_t {
+ cpu_primitive_t(const primitive_desc_t *pd,
+ bool use_global_scratchpad = false)
+ : primitive_t(pd)
+ , scratchpad_buffer_(nullptr)
+ , global_scratchpad_(nullptr)
+ {
+ const size_t scratchpad_size =
+ this->pd()->scratchpad_size(scratchpad_mode::library);
+
+ if (scratchpad_size) {
+ if (use_global_scratchpad)
+ global_scratchpad_ = create_scratchpad(scratchpad_size);
+ else
+ scratchpad_buffer_ = malloc(scratchpad_size, 64);
+ }
+ }
+
+ virtual ~cpu_primitive_t() {
+ delete global_scratchpad_;
+ free(scratchpad_buffer_);
+ }
+
+protected:
+ memory_tracking::grantor_t scratchpad(const exec_ctx_t &ctx) const {
+ void *ptr = nullptr;
+ if (pd()->attr()->scratchpad_mode_ == scratchpad_mode::user) {
+ ptr = CTX_OUT_MEM(void *, MKLDNN_ARG_SCRATCHPAD);
+ } else {
+ ptr = global_scratchpad_
+ ? global_scratchpad_->get() : scratchpad_buffer_;
+ }
+
+ return pd()->scratchpad_registry().grantor(ptr);
+ }
+
+private:
+ void *scratchpad_buffer_;
+ scratchpad_t *global_scratchpad_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp
new file mode 100644
index 0000000000..1d41ac5cea
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.cpp
@@ -0,0 +1,544 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn_thread.hpp"
+#include "mkldnn_types.h"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "cpu_reducer.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace memory_tracking::names;
+
+void reduce_balancer_t::balance() {
+ using namespace nstl;
+ using namespace utils;
+
+ assert(nthr_ > 0 && job_size_ > 0 && njobs_ > 0 && reduction_size_ > 0);
+
+ const int job_complexity = 1;
+
+ const int min_njobs_per_group = max(1, njobs_ / nthr_);
+ const int max_njobs_per_group = max(1,
+ static_cast<int>(max_buffer_size_ / (nthr_ * job_size_)));
+
+ /* initial guess */
+ int ngroups = min(njobs_ / min_njobs_per_group, nthr_);
+ int nthr_per_group = syncable_ ? min(nthr_ / ngroups, reduction_size_) : 1;
+ int njobs_per_group_ub = div_up(njobs_, ngroups);
+
+ /* rough upper-bound estimation, will be fixed during brute force */
+ size_t thread_complexity_ub = njobs_ * job_size_ * reduction_size_;
+
+ /* brute force parameters for the best balance... */
+ for (int c_njobs_per_group = min_njobs_per_group;
+ c_njobs_per_group < njobs_; ++c_njobs_per_group) {
+ /* current assumption */
+ int c_ngroups = min(njobs_ / c_njobs_per_group, nthr_);
+ int c_nthr_per_group = syncable_
+ ? min(nthr_ / c_ngroups, reduction_size_) : 1;
+ int c_njobs_per_group_ub = div_up(njobs_, c_ngroups);
+
+ if (c_nthr_per_group > 1 && c_njobs_per_group_ub > max_njobs_per_group)
+ continue;
+
+ int c_thread_reduction_ub = div_up(reduction_size_, c_nthr_per_group);
+ size_t c_group_size_ub = job_size_ * c_njobs_per_group_ub;
+ size_t c_thread_complexity_ub = c_group_size_ub * (
+ job_complexity * c_thread_reduction_ub
+ + (c_nthr_per_group != 1));
+
+ if (c_thread_complexity_ub < thread_complexity_ub) {
+ ngroups = c_ngroups;
+ nthr_per_group = c_nthr_per_group;
+ njobs_per_group_ub = c_njobs_per_group_ub;
+ thread_complexity_ub = c_thread_complexity_ub;
+ }
+ }
+
+ assert(njobs_per_group_ub <= max_njobs_per_group || nthr_per_group == 1);
+ assert(ngroups * nthr_per_group <= nthr_);
+ assert((size_t)njobs_per_group_ub * job_size_ * nthr_ <= max_buffer_size_
+ || nthr_per_group == 1); /* no reduction buffer overflow */
+ assert(IMPLICATION(!syncable_, nthr_per_group == 1));
+
+ ngroups_ = ngroups;
+ nthr_per_group_ = nthr_per_group;
+ njobs_per_group_ub_ = njobs_per_group_ub;
+}
+
+/* reducer jit-ted driver */
+
+using namespace Xbyak;
+
+template <impl::data_type_t data_type>
+struct reducer_2d_driver_t: public c_compatible {
+ typedef typename prec_traits<data_type>::type data_t;
+
+ reducer_2d_driver_t(int n_src, size_t src_ld,
+ size_t src_step, size_t dst_step, bool nullify_dst)
+ : n_src_(n_src), src_ld_(src_ld), src_step_(src_step)
+ , dst_step_(dst_step), nullify_dst_(nullify_dst), ker_(nullptr) {}
+ virtual ~reducer_2d_driver_t() {}
+ void operator()(data_t *dst, const data_t *srcs, size_t ny, size_t nx)
+ { assert(ker_); ker_(dst, srcs, ny, nx); }
+
+protected:
+ int n_src_;
+ size_t src_ld_, src_step_, dst_step_;
+ bool nullify_dst_;
+ void (*ker_)(data_t *dst, const data_t *srcs, size_t ny, size_t nx);
+};
+
+template <impl::data_type_t data_type, cpu_isa_t isa>
+struct reducer_2d_driver_f_s_32_t: public reducer_2d_driver_t<data_type>,
+ public jit_generator
+{
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(reducer_2d_driver_f_s_32_t)
+
+ /* cpu specific part */
+ using Vmm = typename utils::conditional<isa == avx2, Ymm, Zmm>::type;
+ const AddressFrame &vmmword = (isa == avx2) ? yword : zword;
+ void uni_vadd(const Xmm& x1, const Xmm& x2, const Operand& op)
+ { if (data_type == data_type::f32) vaddps(x1, x2, op);
+ else vpaddd(x1, x2, op); }
+ void uni_add(const Xmm& x1, const Operand& op)
+ { if (data_type == data_type::f32) addss(x1, op); else paddd(x1, op); }
+
+ const int vlen = cpu_isa_traits<isa>::vlen;
+ const int typesize
+ = sizeof(typename mkldnn::impl::prec_traits<data_type>::type);
+ Xbyak::Reg64 reg_dst = abi_param1;
+ Xbyak::Reg64 reg_src = abi_param2;
+ Xbyak::Reg64 reg_ny = abi_param3;
+ Xbyak::Reg64 reg_nx = abi_param4;
+
+ Xbyak::Reg64 reg_x = rax;
+ Xbyak::Reg64 reg_src_id = r10;
+
+ reducer_2d_driver_f_s_32_t(int n_src, size_t src_ld, size_t src_step,
+ size_t dst_step, bool nullify_dst)
+ : reducer_2d_driver_t<data_type>(n_src, src_ld, src_step,
+ dst_step, nullify_dst)
+ { generate(); }
+
+ void nullify_dst(int nloads, int load_len) {
+ UNUSED(load_len);
+ for (int i = 0; i < nloads; ++i)
+ uni_vpxor(Vmm(i), Vmm(i), Vmm(i));
+ /* prefetches[dst] ? */
+ }
+
+ void load_dst(int nloads, int load_len) {
+ for (int i = 0; i < nloads; ++i) {
+ if (load_len == typesize)
+ movd(Xmm(i), ptr[reg_dst + i * load_len]);
+ else if (load_len == vlen)
+ vmovups(Vmm(i), ptr[reg_dst + i * load_len]);
+ else
+ assert(!"unsupported");
+ }
+ }
+
+ void store_dst(int nloads, int load_len) {
+ for (int i = 0; i < nloads; ++i) {
+ if (load_len == typesize)
+ movd(ptr[reg_dst + i * load_len], Xmm(i));
+ else if (load_len == vlen)
+ vmovups(ptr[reg_dst + i * load_len], Vmm(i));
+ else
+ assert(!"unsupported");
+ }
+ }
+
+ void accumulate(int nloads, int load_len, size_t base_off) {
+ for (int i = 0; i < nloads; ++i) {
+ size_t off = base_off + i * load_len;
+
+ if (load_len == typesize)
+ uni_add(Xmm(i), ptr[reg_src + off]);
+ else if (load_len == vlen)
+ uni_vadd(Vmm(i), Vmm(i), vmmword[reg_src + off]);
+ else
+ assert(!"unsupported");
+ }
+ }
+
+ void loop_x() {
+ const int nloads[] = {cpu_isa_traits<isa>::n_vregs, 1, 1};
+ const int nbranches = sizeof(nloads) / sizeof(nloads[0]);
+
+ const int load_len[nbranches] = {vlen, vlen, typesize};
+ Label loop_x_label[nbranches + 1];
+
+ mov(reg_x, reg_nx);
+
+ for (int id = 0; id < nbranches; ++id) {
+ L(loop_x_label[id]);
+
+ cmp(reg_x, nloads[id] * load_len[id]);
+ jl(loop_x_label[id + 1], T_NEAR);
+
+ if (this->nullify_dst_)
+ nullify_dst(nloads[id], load_len[id]);
+ else
+ load_dst(nloads[id], load_len[id]);
+
+ if (nloads[id] > 1) {
+ Label loop_srcs;
+ mov(reg_src_id, this->n_src_);
+ L(loop_srcs);
+
+ accumulate(nloads[id], load_len[id], 0);
+ add(reg_src, this->src_ld_ * typesize);
+
+ dec(reg_src_id);
+ jnz(loop_srcs, T_NEAR);
+
+ sub(reg_src, this->n_src_ * this->src_ld_ * typesize);
+ } else {
+ for (int src_id = 0; src_id < this->n_src_; ++src_id) {
+ const size_t base_off = src_id * this->src_ld_ * typesize;
+ accumulate(nloads[id], load_len[id], base_off);
+ }
+ }
+
+ store_dst(nloads[id], load_len[id]);
+
+ add(reg_src, nloads[id] * load_len[id]);
+ add(reg_dst, nloads[id] * load_len[id]);
+
+ sub(reg_x, nloads[id] * load_len[id]);
+
+ jmp(loop_x_label[id], T_NEAR);
+ }
+
+ L(loop_x_label[nbranches]);
+
+ /* restore address registers */
+ sub(reg_src, reg_nx);
+ sub(reg_dst, reg_nx);
+ }
+
+ void generate() {
+ assert(isa == avx2 || isa == avx512_common || isa == avx512_mic);
+
+ preamble();
+
+ shl(reg_nx, 2);
+
+ Label ny_loop;
+ L(ny_loop);
+
+ loop_x();
+
+ add(reg_dst, this->dst_step_ * typesize);
+ add(reg_src, this->src_step_ * typesize);
+
+ dec(reg_ny);
+ jnz(ny_loop, T_NEAR);
+
+ postamble();
+ this->ker_ = reinterpret_cast<decltype(this->ker_)>(
+ const_cast<uint8_t*>(this->getCode()));
+ }
+};
+
+template <impl::data_type_t data_type>
+inline reducer_2d_driver_t<data_type> *create_reduce_2d_drv(int n_src,
+ size_t src_ld, size_t src_step, size_t dst_step, bool nullify_dst) {
+ if (mayiuse(avx512_common))
+ return new reducer_2d_driver_f_s_32_t<data_type, avx512_common>(n_src,
+ src_ld, src_step, dst_step, nullify_dst);
+ else if (mayiuse(avx2))
+ return new reducer_2d_driver_f_s_32_t<data_type, avx2>(n_src, src_ld,
+ src_step, dst_step, nullify_dst);
+ assert(!"unimplemented");
+ return nullptr;
+}
+
+/* cpu_reducer_t */
+
+template <impl::data_type_t data_type>
+void cpu_reducer_t<data_type>::conf_t::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad) const {
+ if (balancer_.nthr_per_group_ == 1) return;
+
+ const size_t space_size = balancer_.ngroups_
+ * (balancer_.nthr_per_group_ - 1)
+ * cpu_reducer_t<data_type>::space_per_thread(balancer_);
+ scratchpad.book(key_reducer_space, sizeof(data_t) * space_size, PAGE_4K);
+ scratchpad.book(key_reducer_space_bctx,
+ sizeof(simple_barrier::ctx_t) * balancer_.ngroups_);
+}
+
+template <impl::data_type_t data_type>
+cpu_reducer_t<data_type>::cpu_reducer_t(const conf_t &conf)
+ : conf_(conf), drv_(nullptr)
+{
+ if (balancer().nthr_per_group_ == 1) return;
+
+ drv_ = create_reduce_2d_drv<data_type>(balancer().nthr_per_group_ - 1,
+ space_per_thread(balancer()), 0, 0, false);
+}
+
+template <impl::data_type_t data_type>
+cpu_reducer_t<data_type>::~cpu_reducer_t() { delete drv_; }
+
+template <impl::data_type_t data_type>
+typename cpu_reducer_t<data_type>::data_t *
+cpu_reducer_t<data_type>::get_local_ptr(int ithr, data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const int id_in_grp = balancer().id_in_group(ithr);
+
+ /* thread 0 of each group writes directly to the destination */
+ if (id_in_grp == 0)
+ return dst + balancer().ithr_job_off(ithr) * balancer().job_size_;
+
+ const int grp_id = balancer().group_id(ithr);
+ const int offset_factor = grp_id * (balancer().nthr_per_group_ - 1)
+ + (id_in_grp - 1);
+
+ auto space = scratchpad.template get<data_t>(key_reducer_space);
+ return space + offset_factor * space_per_thread(balancer());
+}
+
+template <impl::data_type_t data_type>
+void cpu_reducer_t<data_type>::reduce_nolock(int ithr, data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ bool redundant_reduction = balancer().nthr_per_group_ == 1
+ || balancer().idle(ithr);
+ if (redundant_reduction) return;
+
+#ifdef SIMPLE_IMPL
+ if (balancer().id_in_group(ithr) != 0)
+ return; /* only threads 0 do the reduction */
+
+ const int njobs_in_grp = balancer().ithr_njobs(ithr);
+ data_t *d = get_local_ptr(ithr, dst, scratchpad);
+ for (int id_in_grp = 1; id_in_grp < balancer_.nthr_per_group_; ++id_in_grp)
+ {
+ const data_t *space = get_local_ptr(ithr + id_in_grp, dst, scratchpad);
+ for (size_t i = 0; i < (size_t)njobs_in_grp * balancer().job_size_; ++i)
+ d[i] += space[i];
+ }
+#else
+ using namespace utils;
+
+ const int id_in_grp = balancer().id_in_group(ithr);
+ const int njobs_in_grp = balancer().ithr_njobs(ithr);
+ const size_t cl = 64 / sizeof(data_t);
+
+ const size_t reduction_size = njobs_in_grp * balancer().job_size_;
+ size_t start{0}, end{0};
+ balance211(div_up(reduction_size, cl), balancer().nthr_per_group_,
+ id_in_grp, start, end);
+
+ if (start == end) return;
+
+ data_t *d = get_local_ptr(ithr - id_in_grp, dst, scratchpad) + start * cl;
+ const data_t *space = get_local_ptr(ithr - id_in_grp + 1, dst, scratchpad)
+ + start * cl;
+ const size_t len = nstl::min(end * cl, reduction_size) - start * cl;
+
+ (*drv_)(d, space, 1, len);
+#endif
+}
+
+template struct cpu_reducer_t<data_type::f32>;
+template struct cpu_reducer_t<data_type::s32>;
+
+/* cpu_reducer_2d_t */
+
+template <impl::data_type_t data_type>
+void cpu_reducer_2d_t<data_type>::conf_t::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad) const {
+ if (balancer_.nthr_per_group_ == 1) return;
+
+ const size_t space_size = balancer_.ngroups_ * balancer_.nthr_per_group_
+ * cpu_reducer_2d_t<data_type>::space_per_thread(balancer_);
+ scratchpad.book(key_reducer_space, sizeof(data_t) * space_size);
+ scratchpad.book(key_reducer_space_bctx,
+ sizeof(simple_barrier::ctx_t) * balancer_.ngroups_);
+}
+
+template <impl::data_type_t data_type>
+cpu_reducer_2d_t<data_type>::cpu_reducer_2d_t(const conf_t &conf)
+ : conf_(conf), drv_(nullptr)
+{
+ if (balancer().nthr_per_group_ == 1) return;
+
+ drv_ = create_reduce_2d_drv<data_type>(balancer().nthr_per_group_,
+ space_per_thread(balancer()), conf_.job_size_x_, conf_.dst_x_,
+ true);
+}
+
+template <impl::data_type_t data_type>
+cpu_reducer_2d_t<data_type>::~cpu_reducer_2d_t() { delete drv_; }
+
+template <impl::data_type_t data_type>
+typename cpu_reducer_2d_t<data_type>::data_t *cpu_reducer_2d_t<data_type>::
+get_local_ptr(int ithr, const memory_tracking::grantor_t &scratchpad) const {
+ const int id_in_grp = balancer().id_in_group(ithr);
+ const int grp_id = balancer().group_id(ithr);
+ const int offset_factor = grp_id * balancer().nthr_per_group_ + id_in_grp;
+ auto space = scratchpad.template get<data_t>(key_reducer_space);
+ return space + offset_factor * space_per_thread(balancer());
+}
+
+template <impl::data_type_t data_type>
+int cpu_reducer_2d_t<data_type>::choose_x_blocking(int nx, int ny,
+ int nthr_per_grp) const {
+ // find an x_blocking that balances the reduction work between threads more evenly
+ assert(conf_.x_block_ > 0 && nx > conf_.x_block_
+ && nx % conf_.x_block_ == 0);
+ int x_blocking = nx / conf_.x_block_;
+ int min_x_blocking =
+ utils::div_up(x_blocking, nstl::max(1, nthr_per_grp / ny));
+ while (true) {
+ if (x_blocking % 2 == 0 && x_blocking >= min_x_blocking * 2)
+ x_blocking /= 2;
+ else if (x_blocking % 3 == 0 && x_blocking >= min_x_blocking * 3)
+ x_blocking /= 3;
+ else
+ break;
+ }
+ if (x_blocking >= min_x_blocking * 4) x_blocking = 1;
+ x_blocking *= conf_.x_block_;
+ return x_blocking;
+}
+
+template <impl::data_type_t data_type>
+void cpu_reducer_2d_t<data_type>::reduce_block(const data_t* space_base,
+ data_t *dst, int job, int start_y, int start_x,
+ int ny_start, int nx_start, int ny_step, int nx_step) const {
+ data_t *d = dst + (start_y + ny_start) * conf_.dst_x_
+ + start_x + nx_start;
+ const data_t *space = space_base + job * balancer().job_size_
+ + ny_start * conf_.job_size_x_ + nx_start;
+#ifdef SIMPLE_IMPL
+ for (int idg = 0; idg < balancer().nthr_per_group_; ++idg) {
+ const data_t *w = &space[idg * space_per_thread(balancer())];
+ for (int y = 0; y < ny_step; ++y)
+ for (int x = 0; x < nx_step; ++x) {
+ d[y * conf_.dst_x_ + x]
+ = (idg == 0 ? 0 : d[y * conf_.dst_x_ + x])
+ + w[y * conf_.job_size_x_ + x];
+ }
+ }
+#else
+ (*drv_)(d, space, ny_step, nx_step);
+#endif
+}
+
+template <impl::data_type_t data_type>
+void cpu_reducer_2d_t<data_type>::reduce_nolock(int ithr, data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ bool redundant_reduction = balancer().nthr_per_group_ == 1
+ || balancer().idle(ithr);
+ if (redundant_reduction) return;
+
+ const int id_in_grp = balancer().id_in_group(ithr);
+ const int njobs_in_grp = balancer().ithr_njobs(ithr);
+ const int njobs_x = utils::div_up(conf_.dst_x_, conf_.job_size_x_);
+ const int global_job_start = balancer().ithr_job_off(ithr);
+
+ const data_t *space_base = get_local_ptr(ithr - id_in_grp, scratchpad);
+
+ const int pr_grps = nstl::min(njobs_in_grp, balancer().nthr_per_group_);
+ const int pr_nthr_per_grp = balancer().nthr_per_group_ / pr_grps;
+
+ if (id_in_grp >= pr_grps * pr_nthr_per_grp)
+ return; /* idle */
+
+ const int pr_my_grp = id_in_grp / pr_nthr_per_grp;
+ const int pr_my_id = id_in_grp % pr_nthr_per_grp;
+
+ int pr_job_start{0}, pr_job_end{0};
+ balance211(njobs_in_grp, pr_grps, pr_my_grp, pr_job_start, pr_job_end);
+
+ for (int j = pr_job_start; j < pr_job_end; ++j) {
+ const int global_job = global_job_start + j;
+ const int j_y = global_job / njobs_x;
+ const int j_x = global_job % njobs_x;
+ const int start_y = j_y * conf_.job_size_y_;
+ const int start_x = j_x * conf_.job_size_x_;
+ const int ny = nstl::min(conf_.dst_y_ - start_y, conf_.job_size_y_);
+ const int nx = nstl::min(conf_.dst_x_ - start_x, conf_.job_size_x_);
+ int x_blocking = choose_x_blocking(nx, ny, pr_nthr_per_grp);
+
+ int nxy_start{0}, nxy_end{0};
+ balance211(ny * nx / x_blocking, pr_nthr_per_grp, pr_my_id,
+ nxy_start, nxy_end);
+ if (nxy_start == nxy_end) continue;
+ nxy_start *= x_blocking;
+ nxy_end *= x_blocking;
+
+ int nxy = nxy_start;
+ if (nxy % nx != 0) {
+ int nx_step = nstl::min(nx - nxy % nx, nxy_end - nxy);
+ reduce_block(space_base, dst, j, start_y, start_x,
+ nxy / nx, nxy % nx, 1, nx_step);
+ nxy += nx_step;
+ }
+ if ((nxy_end - nxy) > nx) {
+ int ny_step = (nxy_end - nxy) / nx;
+ reduce_block(space_base, dst, j, start_y, start_x,
+ nxy / nx, nxy % nx, ny_step, nx);
+ nxy += nx * ny_step;
+ }
+ if ((nxy_end - nxy) > 0) {
+ reduce_block(space_base, dst, j, start_y, start_x,
+ nxy / nx, nxy % nx, 1, nxy_end - nxy);
+ }
+ }
+}
+
+template struct cpu_reducer_2d_t<data_type::f32>;
+template struct cpu_reducer_2d_t<data_type::s32>;
+
+/* accumulator section */
+
+template <impl::data_type_t data_type>
+cpu_accumulator_1d_t<data_type>::cpu_accumulator_1d_t(): drv_(nullptr) {
+ drv_ = create_reduce_2d_drv<data_type>(1, 0, 0, 0, false);
+}
+
+template <impl::data_type_t data_type>
+cpu_accumulator_1d_t<data_type>::~cpu_accumulator_1d_t() {
+ delete drv_;
+}
+
+template <impl::data_type_t data_type>
+void cpu_accumulator_1d_t<data_type>::accumulate(data_t *dst,
+ const data_t *src, size_t size) {
+ (*drv_)(dst, src, 1, size);
+}
+
+template struct cpu_accumulator_1d_t<data_type::f32>;
+template struct cpu_accumulator_1d_t<data_type::s32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp
new file mode 100644
index 0000000000..27f5939cd2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reducer.hpp
@@ -0,0 +1,334 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REDUCER_HPP
+#define CPU_REDUCER_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "mkldnn_types.h"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+
+#include "cpu_barrier.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+/** class to perform balancing over a 3D array
+ *
+ * Conceptually the reduction happens according to the picture below:
+ *
+ * <--job_size->
+ * +-----------+ +-----------+ +-----------+ ^
+ * | | | | | | |
+ * | | | | | | |
+ * | 1 | | 2 | . . . | njobs | | reduction_size
+ * | | | | | | |
+ * | | | | | | |
+ * +-----------+ +-----------+ +-----------+ v
+ *
+ * | | | | | | | | |
+ * v v v v v v v v v
+ * ===================================================== vertical reduction
+ *
+ * +-----------+ +-----------+ . . . +-----------+ result
+ *
+ * In a simple case the result must be contiguous in memory.
+ * @class cpu_reducer_t is an implementation.
+ *
+ * Threads are divided into groups. The groups are independent of each other.
+ * Each group may work on several jobs (the distribution is not uniform, since
+ * njobs might not be a multiple of the number of groups). Threads within a
+ * group work on different parts of the reduction dimension. Thread 0 in each
+ * group is called the master (@sa reduce_balancer_t::master()).
+ *
+ * If the threading driver does not allow synchronization within a sub-group
+ * of threads (e.g. Intel(R) TBB), the number of threads per group is forced
+ * to be 1. (See the usage sketch after this struct.)
+ */
+struct reduce_balancer_t {
+ reduce_balancer_t() { init(1, 1, 1, 1, 0); } /* trivial balance */
+ reduce_balancer_t(int nthr, int job_size, int njobs, int reduction_size,
+ size_t max_buffer_size)
+ { init(nthr, job_size, njobs, reduction_size, max_buffer_size); }
+
+ reduce_balancer_t &init(int nthr, int job_size, int njobs,
+ int reduction_size, size_t max_buffer_size)
+ {
+ syncable_ = mkldnn_thr_syncable();
+ nthr_ = nthr;
+ job_size_ = job_size;
+ njobs_ = njobs;
+ reduction_size_ = reduction_size;
+ max_buffer_size_ = max_buffer_size;
+ balance();
+ return *this;
+ }
+
+ bool syncable_;
+ int nthr_;
+ int job_size_, njobs_, reduction_size_;
+
+ int ngroups_; /** number of independent work (thread) groups */
+ int nthr_per_group_; /** number of threads within a single work group */
+ int njobs_per_group_ub_; /** the max # of jobs within a work group */
+
+ bool master(int ithr) const { return id_in_group(ithr) == 0; }
+ bool idle(int ithr) const { return ithr >= nthr_per_group_ * ngroups_; }
+
+ int group_id(int ithr) const { return ithr / nthr_per_group_; }
+ int id_in_group(int ithr) const { return ithr % nthr_per_group_; }
+
+ int grp_njobs(int grp) const {
+ if (grp >= ngroups_) return 0;
+ return njobs_ / ngroups_ + (grp < njobs_ % ngroups_);
+ }
+ int grp_job_off(int grp) const {
+ if (grp >= ngroups_) return njobs_;
+ return njobs_ / ngroups_ * grp + nstl::min(grp, njobs_ % ngroups_);
+ }
+
+ int ithr_njobs(int ithr) const { return grp_njobs(group_id(ithr)); }
+ int ithr_job_off(int ithr) const { return grp_job_off(group_id(ithr)); }
+
+private:
+ size_t max_buffer_size_;
+ void balance();
+};
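+/* A minimal usage sketch (illustrative only; the numbers are hypothetical):
+ * the balancer is configured once, and its ngroups_, nthr_per_group_ and
+ * njobs_per_group_ub_ outputs tell callers how to split threads into groups.
+ *
+ * reduce_balancer_t balancer(8, 1024, 16, 4, 1024 * 1024);
+ * // 8 threads, 16 jobs of 1024 elements, reduction depth 4, 1 MiB buffer cap
+ * int groups = balancer.ngroups_; // number of independent thread groups
+ * int per_group = balancer.nthr_per_group_; // threads cooperating per group
+ * bool is_master = balancer.master(ithr); // true for thread 0 of its group
+ */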
+
+/** forward declaration of reduce driver */
+template <impl::data_type_t data_type> struct reducer_2d_driver_t;
+
+/** class to perform a reduction over a 3D array
+ *
+ * Balancing is based on @class reduce_balancer_t.
+ * Restrictions: the result of the reduction must be contiguous in memory.
+ * The reduction happens according to the picture below (once more):
+ *
+ * <--job_size->
+ * +-----------+ +-----------+ +-----------+ ^
+ * | | | | | | |
+ * | | | | | | |
+ * | 1 | | 2 | . . . | njobs | | reduction_size
+ * | | | | | | |
+ * | | | | | | |
+ * +-----------+ +-----------+ +-----------+ v
+ *
+ * | | | | | | | | |
+ * v v v v v v v v v
+ * ===================================================== vertical reduction
+ *
+ * +-----------+ +-----------+ . . . +-----------+ (contiguous) result
+ *
+ * An example of how work might be shared is shown below.
+ *
+ * In this example group 0 owns 2 (independent) jobs -- the 2 big squares.
+ * The number of threads per group is also 2 (thread 0 of group 0 and thread 1
+ * of group 0). Master threads (i.e. threads with id 0 in the corresponding
+ * group) put their partial result directly into the destination memory, while
+ * all the other threads within the group use the workspace (in the picture,
+ * only thread 1). Once the intermediate results are obtained, each group
+ * reduces its own jobs into the destination memory (a call-order sketch
+ * follows the class definition below).
+ *
+ * <------- group 0 ------->
+ *
+ * +-----------+ +-----------+ ^
+ * | | | | | thread 0 of reduces to the dest-memory
+ * | | | | | group 0 +-----------+ +-----------+
+ * |- - - - - -| |- - - - - -| X
+ * | | | | | thread 1 of reduces to workspace[tid=1]:
+ * | | | | | group 0 +-----------+ +-----------+
+ * +-----------+ +-----------+ v
+ * | | | | | |
+ * v v v v v v
+ * ((barrier)) =============================
+ *
+ * dest-memory: +-----------+ +-----------+
+ */
+template <impl::data_type_t data_type>
+struct cpu_reducer_t {
+ typedef typename prec_traits<data_type>::type data_t;
+
+ struct conf_t {
+ conf_t() = default;
+ conf_t &init(const reduce_balancer_t &balancer)
+ { balancer_ = balancer; return *this; }
+
+ void init_scratchpad(memory_tracking::registrar_t &scratchpad) const;
+
+ reduce_balancer_t balancer_;
+ };
+
+ cpu_reducer_t(const conf_t &conf);
+ ~cpu_reducer_t();
+
+ /** initializes reducer.
+ * Must be called from a single thread prior to actual usage */
+ void init(const memory_tracking::grantor_t &scratchpad) const {
+ if (balancer().nthr_per_group_ == 1) return;
+
+ auto bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ memory_tracking::names::key_reducer_space_bctx);
+ for (int i = 0; i < balancer().ngroups_; ++i)
+ simple_barrier::ctx_init(&bctx[i]);
+ }
+
+ /** for a given thread returns the pointer where to put partial results.
+ * The reduction destination @p dst must be provided as well (the master
+ * thread of each group uses it for its partial result to reduce memory
+ * pressure).
+ *
+ * @note: the job offset is already applied by get_local_ptr(), which
+ * means all threads should start writing from the very beginning of
+ * the returned address.
+ */
+ data_t *get_local_ptr(int ithr, data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const;
+
+ /** performs the reduction with built-in synchronization. */
+ void reduce(int ithr, data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ bool redundant_reduction = balancer().nthr_per_group_ == 1
+ || balancer().idle(ithr);
+ if (redundant_reduction) return;
+
+ auto bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ memory_tracking::names::key_reducer_space_bctx);
+ simple_barrier::barrier(&bctx[balancer().group_id(ithr)],
+ balancer().nthr_per_group_);
+
+ reduce_nolock(ithr, dst, scratchpad);
+ }
+
+ const reduce_balancer_t &balancer() const { return conf_.balancer_; }
+
+private:
+ static size_t space_per_thread(const reduce_balancer_t &balancer)
+ { return balancer.njobs_per_group_ub_ * balancer.job_size_; }
+
+ /* The scratchpad is organized as follows:
+ *
+ * data_t space[nthr_][njobs_per_group_ub_][job_size_];
+ * simple_barrier::ctx_t barriers[groups_]; */
+
+ const conf_t conf_;
+ reducer_2d_driver_t<data_type> *drv_;
+
+ void reduce_nolock(int ithr, data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const;
+};
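+/* A rough call-order sketch (illustrative only; `ithr` is the caller's thread
+ * id and `scratchpad` its memory_tracking grantor):
+ *
+ * cpu_reducer_t<data_type::f32> reducer(conf); // conf wraps a reduce_balancer_t
+ * reducer.init(scratchpad); // once, from a single thread
+ * float *buf = reducer.get_local_ptr(ithr, dst, scratchpad);
+ * // ... each thread writes its partial result to buf ...
+ * reducer.reduce(ithr, dst, scratchpad); // barrier + final reduction
+ */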
+
+template <impl::data_type_t data_type>
+struct cpu_reducer_2d_t {
+ typedef typename prec_traits<data_type>::type data_t;
+
+ struct conf_t {
+ conf_t() = default;
+ conf_t &init(const reduce_balancer_t &balancer, int job_size_x,
+ int job_size_y, int x_block, int dst_x, int dst_y) {
+ balancer_ = balancer;
+ job_size_x_ = job_size_x;
+ job_size_y_ = job_size_y;
+ x_block_ = x_block;
+ dst_x_ = dst_x;
+ dst_y_ = dst_y;
+ return *this;
+ }
+
+ void init_scratchpad(memory_tracking::registrar_t &scratchpad) const;
+
+ reduce_balancer_t balancer_;
+ int job_size_x_, job_size_y_, x_block_, dst_x_, dst_y_;
+ };
+
+ cpu_reducer_2d_t(const conf_t &conf);
+ ~cpu_reducer_2d_t();
+
+ /** initializes reducer.
+ * Must be called from a single thread prior to actual usage */
+ void init(const memory_tracking::grantor_t &scratchpad) const {
+ if (balancer().nthr_per_group_ == 1) return;
+
+ auto bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ memory_tracking::names::key_reducer_space_bctx);
+ for (int i = 0; i < balancer().ngroups_; ++i)
+ simple_barrier::ctx_init(&bctx[i]);
+ }
+
+ /** for a given thread returns the pointer where to put partial results */
+ data_t *get_local_ptr(int ithr,
+ const memory_tracking::grantor_t &scratchpad) const;
+
+ /** performs the reduction with built-in synchronization. */
+ void reduce(int ithr, data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ bool redundant_reduction = balancer().nthr_per_group_ == 1
+ || balancer().idle(ithr);
+ if (redundant_reduction) return;
+
+ auto bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ memory_tracking::names::key_reducer_space_bctx);
+ simple_barrier::barrier(&bctx[balancer().group_id(ithr)],
+ balancer().nthr_per_group_);
+
+ reduce_nolock(ithr, dst, scratchpad);
+ }
+
+ const reduce_balancer_t &balancer() const { return conf_.balancer_; }
+
+private:
+ static size_t space_per_thread(const reduce_balancer_t &balancer)
+ { return balancer.njobs_per_group_ub_ * balancer.job_size_; }
+
+ /* The scratchpad is organized as follows:
+ *
+ * data_t space[nthr_][njobs_per_group_ub_][job_size_];
+ * simple_barrier::ctx_t barriers[groups_]; */
+
+ const conf_t conf_;
+ reducer_2d_driver_t<data_type> *drv_;
+
+ int choose_x_blocking(int nx, int ny, int nthr_per_grp) const;
+ void reduce_block(const data_t* space_base, data_t *dst,
+ int job, int start_y, int start_x,
+ int ny_start, int nx_start, int ny_step, int nx_step) const;
+ void reduce_nolock(int ithr, data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const;
+};
+
+/** simple 1d accumulator: y[:] += x[:] */
+template <impl::data_type_t data_type>
+struct cpu_accumulator_1d_t {
+ typedef typename prec_traits<data_type>::type data_t;
+
+ cpu_accumulator_1d_t();
+ ~cpu_accumulator_1d_t();
+ void accumulate(data_t *dst, const data_t *src, size_t size);
+
+ reducer_2d_driver_t<data_type> *drv_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp
new file mode 100644
index 0000000000..82be70353d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder.cpp
@@ -0,0 +1,262 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "cpu_engine.hpp"
+#include "cpu_primitive.hpp"
+#include "cpu_reorder_pd.hpp"
+#include "cpu_memory.hpp"
+#include "type_helpers.hpp"
+
+#include "cpu/jit_uni_reorder.hpp"
+#include "cpu/simple_reorder.hpp"
+#include "cpu/wino_reorder.hpp"
+#include "cpu/rnn/rnn_reorders.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using rpd_create_f = mkldnn::impl::engine_t::reorder_primitive_desc_create_f;
+
+namespace {
+using namespace mkldnn::impl::data_type;
+using namespace mkldnn::impl::format_tag;
+
+#define REG_SR(idt, ifmt, odt, ofmt, ...) \
+ simple_reorder_t<idt, ifmt, odt, ofmt, __VA_ARGS__>::pd_t::create
+
+#define REG_SR_BIDIR(idt, ifmt, odt, ofmt) \
+ REG_SR(idt, ifmt, odt, ofmt, fmt_order::keep), \
+ REG_SR(idt, ifmt, odt, ofmt, fmt_order::reverse)
+
+#define REG_SR_DIRECT_COPY(idt, odt) \
+ REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy), \
+ REG_SR(idt, any, odt, any, fmt_order::any, spec::direct_copy_except_dim_0)
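+
+/* For reference, each registration line expands to a plain function pointer;
+ * e.g. REG_SR(f32, nChw8c, f32, nChw16c, fmt_order::keep) becomes
+ *
+ *   simple_reorder_t<f32, nChw8c, f32, nChw16c, fmt_order::keep>::pd_t::create
+ *
+ * and REG_SR_BIDIR registers the same pair a second time with
+ * fmt_order::reverse, so both directions of the reorder are covered. */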
+
+static const rpd_create_f cpu_reorder_impl_list[] = {
+ /* winograd */
+ wino_reorder_t<f32, f32>::pd_t::create,
+ //wino_reorder_t<f32, s8>::pd_t::create,
+
+ /* rnn reorders */
+ rnn_data_reorder_t<f32, u8>::pd_t::create,
+ rnn_weights_reorder_t<f32, f32>::pd_t::create,
+ rnn_weights_reorder_t<f32, s8>::pd_t::create,
+
+ /* conv reorders w/ compensation */
+ REG_SR(f32, any, s8, hwio, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(f32, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, any, s8, hwio, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, any, s8, hwigo, fmt_order::keep, spec::conv_s8s8),
+
+ REG_SR(f32, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(f32, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, oiw, s8, OIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, goiw, s8, gOIw4i16o4i, fmt_order::keep, spec::conv_s8s8),
+
+ REG_SR(f32, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(f32, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, oihw, s8, OIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, goihw, s8, gOIhw4i16o4i, fmt_order::keep, spec::conv_s8s8),
+
+ REG_SR(f32, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, goihw, s8, gOIhw2i8o4i, fmt_order::keep, spec::conv_s8s8),
+
+ REG_SR(f32, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, goihw, s8, gOIhw4o4i, fmt_order::keep, spec::conv_s8s8),
+
+ REG_SR(f32, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, goiw, s8, Goiw16g, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(f32, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8),
+ REG_SR(s8, goihw, s8, Goihw16g, fmt_order::keep, spec::conv_s8s8),
+
+ /* regular reorders */
+
+#if defined(__INTEL_COMPILER) || (defined(__GNUC__) && !defined(__clang__))
+    /* Direct copy for icc, which is faster than jitted code;
+     * direct copy for gcc, which might or might not be faster than jitted
+     * code, but is still worth it because it doesn't require jitting, i.e.
+     * much faster creation time. This is a tentative solution and should be
+     * removed later (once we cache jitted code?...). */
+ REG_SR_DIRECT_COPY(f32, f32),
+#endif
+
+#ifdef __INTEL_COMPILER
+ /* direct copy for icc, which is faster than jitted code */
+ /*
+ REG_SR_DIRECT_COPY(f32, s32),
+ REG_SR_DIRECT_COPY(f32, s8),
+ REG_SR_DIRECT_COPY(f32, u8),
+ REG_SR_DIRECT_COPY(s32, f32),
+ REG_SR_DIRECT_COPY(s32, s32),
+ REG_SR_DIRECT_COPY(s32, s8),
+ REG_SR_DIRECT_COPY(s32, u8),
+ REG_SR_DIRECT_COPY(s8, f32),
+ REG_SR_DIRECT_COPY(s8, s32),
+ REG_SR_DIRECT_COPY(s8, s8),
+ REG_SR_DIRECT_COPY(s8, u8),
+ REG_SR_DIRECT_COPY(u8, f32),
+ REG_SR_DIRECT_COPY(u8, s32),
+ REG_SR_DIRECT_COPY(u8, s8),
+ REG_SR_DIRECT_COPY(u8, u8),
+ */
+#endif
+
+ /* jit */
+ jit_uni_reorder_create,
+
+ /* fp32: flat <-> blocked with tail */
+ /*
+ REG_SR_BIDIR(f32, any, f32, nCw4c),
+ REG_SR_BIDIR(f32, any, f32, nCw8c),
+ REG_SR_BIDIR(f32, any, f32, OIw4i4o),
+ REG_SR_BIDIR(f32, any, f32, OIw8i8o),
+ REG_SR_BIDIR(f32, any, f32, OIw8o8i),
+ REG_SR_BIDIR(f32, any, f32, gOIw4i4o),
+ REG_SR_BIDIR(f32, any, f32, gOIw8i8o),
+ REG_SR_BIDIR(f32, any, f32, gOIw8o8i),
+
+ REG_SR_BIDIR(f32, any, f32, nCw16c),
+ REG_SR_BIDIR(f32, any, f32, OIw16o16i),
+ REG_SR_BIDIR(f32, any, f32, OIw16i16o),
+ REG_SR_BIDIR(f32, any, f32, IOw16o16i),
+ REG_SR_BIDIR(f32, any, f32, gOIw16o16i),
+ REG_SR_BIDIR(f32, any, f32, gOIw16i16o),
+ REG_SR_BIDIR(f32, any, f32, gIOw16o16i),
+
+ REG_SR_BIDIR(f32, any, f32, nChw4c),
+ REG_SR_BIDIR(f32, any, f32, nChw8c),
+ REG_SR_BIDIR(f32, any, f32, OIhw4i4o),
+ REG_SR_BIDIR(f32, any, f32, Ohwi8o),
+
+ REG_SR_BIDIR(f32, any, f32, OIhw8i8o),
+ REG_SR_BIDIR(f32, any, f32, OIhw8o8i),
+ REG_SR_BIDIR(f32, any, f32, gOIhw4i4o),
+ REG_SR_BIDIR(f32, any, f32, gOIhw4o4i),
+ REG_SR_BIDIR(f32, any, f32, gOhwi8o),
+ REG_SR_BIDIR(f32, any, f32, gOIhw8i8o),
+ REG_SR_BIDIR(f32, any, f32, gOIhw8o8i),
+
+ REG_SR_BIDIR(f32, any, f32, nChw16c),
+ REG_SR_BIDIR(f32, any, f32, Oihw4o),
+ REG_SR_BIDIR(f32, any, f32, Oihw16o),
+ REG_SR_BIDIR(f32, any, f32, Ohwi4o),
+ REG_SR_BIDIR(f32, any, f32, Ohwi16o),
+ REG_SR_BIDIR(f32, any, f32, OIhw16o16i),
+ REG_SR_BIDIR(f32, any, f32, OIhw16i16o),
+ REG_SR_BIDIR(f32, any, f32, IOhw16o16i),
+ REG_SR_BIDIR(f32, any, f32, gOihw4o),
+ REG_SR_BIDIR(f32, any, f32, gOihw16o),
+ REG_SR_BIDIR(f32, any, f32, gOhwi4o),
+ REG_SR_BIDIR(f32, any, f32, gOhwi16o),
+ REG_SR_BIDIR(f32, any, f32, gOIhw16o16i),
+ REG_SR_BIDIR(f32, any, f32, gOIhw16i16o),
+ REG_SR_BIDIR(f32, any, f32, gIOhw16o16i),
+
+ REG_SR_BIDIR(f32, any, f32, nCdhw4c),
+ REG_SR_BIDIR(f32, any, f32, nCdhw8c),
+ REG_SR_BIDIR(f32, any, f32, OIdhw4i4o),
+ REG_SR_BIDIR(f32, any, f32, Odhwi8o),
+ REG_SR_BIDIR(f32, any, f32, OIdhw8i8o),
+ REG_SR_BIDIR(f32, any, f32, OIdhw8o8i),
+ REG_SR_BIDIR(f32, any, f32, gOIdhw4i4o),
+ REG_SR_BIDIR(f32, any, f32, gOdhwi8o),
+ REG_SR_BIDIR(f32, any, f32, gOIdhw8i8o),
+ REG_SR_BIDIR(f32, any, f32, gOIdhw8o8i),
+
+ REG_SR_BIDIR(f32, any, f32, nCdhw16c),
+ REG_SR_BIDIR(f32, any, f32, Oidhw4o),
+ REG_SR_BIDIR(f32, any, f32, Oidhw16o),
+ REG_SR_BIDIR(f32, any, f32, Odhwi16o),
+ REG_SR_BIDIR(f32, any, f32, OIdhw16o16i),
+ REG_SR_BIDIR(f32, any, f32, OIdhw16i16o),
+ REG_SR_BIDIR(f32, any, f32, gOidhw4o),
+ REG_SR_BIDIR(f32, any, f32, gOidhw16o),
+ REG_SR_BIDIR(f32, any, f32, gOdhwi16o),
+ REG_SR_BIDIR(f32, any, f32, gOIdhw16o16i),
+ REG_SR_BIDIR(f32, any, f32, gOIdhw16i16o),
+ */
+
+ /* fp32: blocked <-> blocked with tail */
+ REG_SR_BIDIR(f32, nCw8c, f32, nCw16c),
+ REG_SR_BIDIR(f32, nChw8c, f32, nChw16c),
+ REG_SR_BIDIR(f32, nCdhw8c, f32, nCdhw16c),
+
+ /* int: flat <-> blocked with tail */
+ /*
+ REG_SR_BIDIR(f32, any, s32, nChw16c),
+ REG_SR_BIDIR(f32, any, s8, nChw16c),
+ REG_SR_BIDIR(f32, any, u8, nChw16c),
+ REG_SR_BIDIR(s32, any, f32, nChw16c),
+ REG_SR_BIDIR(s32, any, s32, nChw16c),
+ REG_SR_BIDIR(s32, any, s8, nChw16c),
+ REG_SR_BIDIR(s32, any, u8, nChw16c),
+ REG_SR_BIDIR(s8, any, f32, nChw16c),
+ REG_SR_BIDIR(s8, any, s32, nChw16c),
+ REG_SR_BIDIR(s8, any, s8, nChw16c),
+ REG_SR_BIDIR(s8, any, u8, nChw16c),
+ REG_SR_BIDIR(u8, any, f32, nChw16c),
+ REG_SR_BIDIR(u8, any, s32, nChw16c),
+ REG_SR_BIDIR(u8, any, s8, nChw16c),
+ REG_SR_BIDIR(u8, any, u8, nChw16c),
+
+ REG_SR_BIDIR(f32, any, f32, OIhw4i16o4i),
+ REG_SR_BIDIR(f32, any, s8, OIhw4i16o4i),
+ REG_SR_BIDIR(s8, any, f32, OIhw4i16o4i),
+ REG_SR_BIDIR(s8, any, s8, OIhw4i16o4i),
+ REG_SR_BIDIR(f32, any, s8, gOIhw4i16o4i),
+ REG_SR_BIDIR(s8, any, f32, gOIhw4i16o4i),
+ REG_SR_BIDIR(f32, any, f32, gOIhw4i16o4i),
+ REG_SR_BIDIR(s8, any, s8, gOIhw4i16o4i),
+ */
+
+ /* reference: the last line of defence */
+ /*
+ REG_SR(f32, any, f32, any, fmt_order::any, spec::reference),
+ REG_SR(f32, any, s32, any, fmt_order::any, spec::reference),
+ REG_SR(f32, any, s8, any, fmt_order::any, spec::reference),
+ REG_SR(f32, any, u8, any, fmt_order::any, spec::reference),
+
+ REG_SR(s32, any, f32, any, fmt_order::any, spec::reference),
+ REG_SR(s32, any, s32, any, fmt_order::any, spec::reference),
+ REG_SR(s32, any, s8, any, fmt_order::any, spec::reference),
+ REG_SR(s32, any, u8, any, fmt_order::any, spec::reference),
+
+ REG_SR(s8, any, f32, any, fmt_order::any, spec::reference),
+ REG_SR(s8, any, s32, any, fmt_order::any, spec::reference),
+ REG_SR(s8, any, s8, any, fmt_order::any, spec::reference),
+ REG_SR(s8, any, u8, any, fmt_order::any, spec::reference),
+
+ REG_SR(u8, any, f32, any, fmt_order::any, spec::reference),
+ REG_SR(u8, any, s32, any, fmt_order::any, spec::reference),
+ REG_SR(u8, any, u8, any, fmt_order::any, spec::reference),
+ REG_SR(u8, any, s8, any, fmt_order::any, spec::reference),
+ */
+
+ /* eol */
+ nullptr,
+};
+}
+
+const rpd_create_f *cpu_engine_t::get_reorder_implementation_list() const {
+ return cpu_reorder_impl_list;
+}
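+
+/* Note (assumption; the consuming code is not part of this file): reorder
+ * creation is expected to walk this null-terminated array in order and use
+ * the first entry that succeeds, which is why the specialized conv/rnn
+ * reorders appear before the generic jit and reference ones. */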
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp
new file mode 100644
index 0000000000..1622eb6849
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_reorder_pd.hpp
@@ -0,0 +1,48 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REORDER_PD_HPP
+#define CPU_REORDER_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "reorder_pd.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_reorder_pd_t: public reorder_pd_t {
+ using reorder_pd_t::reorder_pd_t;
+
+ status_t init() {
+ const auto &post_ops = attr()->post_ops_;
+ bool args_ok = IMPLICATION(post_ops.len_ != 0, post_ops.len_ == 1
+ && post_ops.entry_[0].kind == primitive_kind::sum);
+ scratchpad_engine_ = src_engine_;
+ return args_ok ? status::success : status::unimplemented;
+ }
+};
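+
+/* Note: IMPLICATION(a, b) is the utils.hpp shorthand for (!a || b), so init()
+ * accepts either an empty post-ops chain or exactly one post-op of kind sum;
+ * any other attribute combination yields status::unimplemented. */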
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp
new file mode 100644
index 0000000000..f16587b99f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_shuffle_pd.hpp
@@ -0,0 +1,41 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_SHUFFLE_PD_HPP
+#define CPU_SHUFFLE_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "shuffle_pd.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_shuffle_pd_t: public shuffle_pd_t {
+ using shuffle_pd_t::shuffle_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp
new file mode 100644
index 0000000000..3a39eab974
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_softmax_pd.hpp
@@ -0,0 +1,45 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_SOFTMAX_PD_HPP
+#define CPU_SOFTMAX_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "softmax_pd.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_softmax_fwd_pd_t: public softmax_fwd_pd_t {
+ using softmax_fwd_pd_t::softmax_fwd_pd_t;
+};
+
+struct cpu_softmax_bwd_pd_t: public softmax_bwd_pd_t {
+ using softmax_bwd_pd_t::softmax_bwd_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp
new file mode 100644
index 0000000000..1ab5d9f174
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum.cpp
@@ -0,0 +1,48 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "cpu_engine.hpp"
+
+/*
+#include "cpu/ref_sum.hpp"
+#include "cpu/simple_sum.hpp"
+*/
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using spd_create_f = mkldnn::impl::engine_t::sum_primitive_desc_create_f;
+
+namespace {
+#define INSTANCE(...) __VA_ARGS__::pd_t::create
+static const spd_create_f cpu_sum_impl_list[] = {
+ /*
+ INSTANCE(simple_sum_t<data_type::f32>),
+ INSTANCE(ref_sum_t),
+ */
+ nullptr,
+};
+#undef INSTANCE
+}
+
+const spd_create_f *cpu_engine_t::get_sum_implementation_list() const {
+ return cpu_sum_impl_list;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp
new file mode 100644
index 0000000000..0965129f9b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/cpu_sum_pd.hpp
@@ -0,0 +1,39 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_SUM_PD_HPP
+#define CPU_SUM_PD_HPP
+
+#include "c_types_map.hpp"
+#include "sum_pd.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_sum_pd_t: public sum_pd_t {
+ using sum_pd_t::sum_pd_t;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp
new file mode 100644
index 0000000000..a9810dec28
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.cpp
@@ -0,0 +1,372 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+#include <cmath>
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+#include "gemm_utils_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+namespace gemm_utils {
+#define BM_NOCOPY_AVX 64
+#define BN_NOCOPY_AVX 48
+#define BK_NOCOPY_AVX 384
+#define BN_LARGE_NOCOPY_AVX 192
+#define BM_SMALL_NOCOPY_AVX 16
+#define BN_SMALL_NOCOPY_AVX 1
+#define BK_SMALL_NOCOPY_AVX 4
+// Determine number of threads for each dimension of a 3-D partitioning
+// algorithm based on input parameters
+// m/n/k - First/second/third parameter for GEMM
+// nthrs - total available number of threads
+// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
+// BM/BN/BK - blocking values
+void calc_nthr_nocopy_avx(int m, int n, int k,
+ int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
+ int *BK)
+{
+ int nthr, nthr_m, nthr_n, nthr_k;
+ int MB, NB, KB;
+
+ nthr = nthrs;
+ nthr_m = (m + BM_NOCOPY_AVX - 1) / BM_NOCOPY_AVX;
+ nthr_n = (n + BN_NOCOPY_AVX - 1) / BN_NOCOPY_AVX;
+ nthr_k = 1;
+
+ // Partition along K dimension
+ // - if threading allows having barriers (e.g. OMP)
+ // - if there is not enough parallelism along M or N
+ if (mkldnn_thr_syncable()) {
+ int nthr_other = nthr_k = 1;
+ while ((nthr_m * nthr_n * nthr_other < nthr)
+ && (k / (nthr_other + 1) > BK_NOCOPY_AVX)) {
+ nthr_other++;
+ if ((nthr / nthr_other) * nthr_other > 0.9 * nthr)
+ nthr_k = nthr_other;
+ }
+ }
+ nthr /= nthr_k;
+
+ if (nthr_m == 1)
+ nthr_n = nthr;
+ if (nthr_n == 1)
+ nthr_m = nthr;
+
+ // Simple partition reduction
+ while (nthr_m * nthr_n > nthr)
+ if (nthr_m > nthr_n)
+ nthr_m--;
+ else
+ nthr_n--;
+ while (nthr_m * nthr_n < nthr)
+ if (nthr_m < nthr_n)
+ nthr_m++;
+ else
+ nthr_n++;
+
+ if ((nthr_m * nthr_n > nthr) && (nthr_m > 1) && (nthr_n > 1)) {
+
+ if (nthr_m <= nthr_n) {
+ nthr_m = (int)sqrt((double)nthr);
+ if (nthr_m > (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX)
+ nthr_m = (m + BM_SMALL_NOCOPY_AVX - 1) / BM_SMALL_NOCOPY_AVX;
+ nthr_n = nthr / nthr_m;
+
+ while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
+ nthr_m--;
+ nthr_n = nthr / nthr_m;
+ }
+ } else {
+ nthr_n = (int)sqrt((double)nthr);
+ if (nthr_n > (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX)
+ nthr_n = (n + BN_SMALL_NOCOPY_AVX - 1) / BN_SMALL_NOCOPY_AVX;
+ nthr_m = nthr / nthr_n;
+
+ while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
+ nthr_n--;
+ nthr_m = nthr / nthr_n;
+ }
+ }
+ }
+
+ MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX - 1;
+ MB -= MB % BM_SMALL_NOCOPY_AVX;
+ NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX - 1;
+ NB -= NB % BN_SMALL_NOCOPY_AVX;
+ KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX - 1;
+ KB -= KB % BK_SMALL_NOCOPY_AVX;
+
+ if (MB * nthr_m > m)
+ nthr_m = (m + MB - 1) / MB;
+ if (NB * nthr_n > n)
+ nthr_n = (n + NB - 1) / NB;
+ if (KB * nthr_k > k)
+ nthr_k = (k + KB - 1) / KB;
+
+ *nthrs_m = nthr_m;
+ *nthrs_n = nthr_n;
+ *nthrs_k = nthr_k;
+
+ *BM = MB;
+ *BN = NB;
+ *BK = KB;
+}
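+
+/* Illustrative sketch of how the partitioning output is typically consumed;
+ * the variable names below are assumptions, not taken from this file:
+ *
+ *   int nthr_m, nthr_n, nthr_k, BM, BN, BK;
+ *   calc_nthr_nocopy_avx(m, n, k, mkldnn_get_max_threads(),
+ *           &nthr_m, &nthr_n, &nthr_k, &BM, &BN, &BK);
+ *   // thread (im, in, ik) then owns the block
+ *   //   rows    [im * BM, min(m, (im + 1) * BM))
+ *   //   columns [in * BN, min(n, (in + 1) * BN))
+ *   //   k-range [ik * BK, min(k, (ik + 1) * BK))
+ */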
+#undef BM_NOCOPY_AVX
+#undef BN_NOCOPY_AVX
+#undef BK_NOCOPY_AVX
+#undef BN_LARGE_NOCOPY_AVX
+#undef BM_SMALL_NOCOPY_AVX
+#undef BN_SMALL_NOCOPY_AVX
+#undef BK_SMALL_NOCOPY_AVX
+
+#define BM_NOCOPY_AVX512_COMMON 32
+#define BN_NOCOPY_AVX512_COMMON 64
+#define BK_NOCOPY_AVX512_COMMON 192
+#define BN_LARGE_NOCOPY_AVX512_COMMON 192
+#define BM_SMALL_NOCOPY_AVX512_COMMON 16
+#define BN_SMALL_NOCOPY_AVX512_COMMON 1
+#define BK_SMALL_NOCOPY_AVX512_COMMON 4
+// Determine number of threads for each dimension of a 3-D partitioning
+// algorithm based on input parameters
+// m/n/k - First/second/third parameter for GEMM
+// nthrs - total available number of threads
+// nthrs_m/nthrs_n/nthrs_k - number of threads to use in each dimension
+// BM/BN/BK - blocking values
+void calc_nthr_nocopy_avx512_common(int m,
+ int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
+ int *BM, int *BN, int *BK)
+{
+ int nthr, nthr_m, nthr_n, nthr_k = 1;
+ int MB, NB, KB;
+ nthr = nthrs;
+
+ int counter = 0;
+ float ratio_float = 1.;
+ int ratio = 1;
+ nthr = nthrs;
+ int nthr_m_gt_n;
+
+ // Partition along K dimension
+ // - if threading allows having barriers (e.g. OMP)
+ // - if there is not enough parallelism along M or N
+ if (mkldnn_thr_syncable()) {
+ if (n <= 2 * BN_NOCOPY_AVX512_COMMON &&
+ m <= 2 * BM_NOCOPY_AVX512_COMMON * nthr) {
+ nthr_k = k / BK_NOCOPY_AVX512_COMMON;
+ if (nthr_k > nthr / 4)
+ nthr_k = nthr / 4;
+ if (nthr_k < 1)
+ nthr_k = 1;
+
+ while ((nthr_k > 1) && (nthr % nthr_k)) {
+ nthr_k--;
+ }
+ nthr /= nthr_k;
+ } else {
+ nthr_k = 1;
+ }
+ }
+ nthr_m = (m + BM_NOCOPY_AVX512_COMMON - 1) / BM_NOCOPY_AVX512_COMMON;
+ nthr_n = (n + BN_NOCOPY_AVX512_COMMON - 1) / BN_NOCOPY_AVX512_COMMON;
+
+ if (nthr_m < 1)
+ nthr_m = 1;
+ if (nthr_n < 1)
+ nthr_n = 1;
+
+ nthr_m_gt_n = nthr_m > nthr_n ? 1 : 0;
+ ratio_float = (float)nthr_m / nthr_n;
+
+ if (nthr_m_gt_n)
+ ratio = (int)ratio_float;
+ else
+ ratio = (int)(1. / ratio_float);
+
+ // scale down nthr_m and nthr_n if they are too large
+ while (nthr_m * nthr_n > 4 * nthr) {
+ nthr_m /= 2;
+ nthr_n /= 2;
+ }
+
+ if (nthr_m < 1)
+ nthr_m = 1;
+ if (nthr_n < 1)
+ nthr_n = 1;
+
+ // Simple partition reduction
+ counter = 0;
+ while (nthr_m * nthr_n > nthr) {
+ if (nthr_m > nthr_n) {
+ if (counter < ratio)
+ nthr_m--;
+ else {
+ nthr_n--;
+ counter = -1;
+ }
+ } else {
+ if (counter < ratio)
+ nthr_n--;
+ else {
+ nthr_m--;
+ counter = -1;
+ }
+ }
+ counter++;
+ }
+
+ // Simple partition increment
+ counter = 0;
+ while (nthr_m * nthr_n < 0.95 * nthr) {
+ if (nthr_m > nthr_n) {
+ if (counter < ratio)
+ nthr_m++;
+ else {
+ nthr_n++;
+ counter = -1;
+ }
+ } else {
+ if (counter < ratio)
+ nthr_n++;
+ else {
+ nthr_m++;
+ counter = -1;
+ }
+ }
+ counter++;
+ }
+
+ // if nothing works out, then this should work
+ if ((nthr_m * nthr_n > nthr)) {
+
+ if (nthr_m <= nthr_n) {
+ nthr_m = (int)sqrt((double)nthr);
+ if (nthr_m > (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
+ / BM_SMALL_NOCOPY_AVX512_COMMON)
+ nthr_m = (m + BM_SMALL_NOCOPY_AVX512_COMMON - 1)
+ / BM_SMALL_NOCOPY_AVX512_COMMON;
+ nthr_n = nthr / nthr_m;
+
+ while ((nthr_m > 1) && (nthr_m * nthr_n != nthr)) {
+ nthr_m--;
+ nthr_n = nthr / nthr_m;
+ }
+ } else {
+ nthr_n = (int)sqrt((double)nthr);
+ if (nthr_n > (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
+ / BN_SMALL_NOCOPY_AVX512_COMMON)
+ nthr_n = (n + BN_SMALL_NOCOPY_AVX512_COMMON - 1)
+ / BN_SMALL_NOCOPY_AVX512_COMMON;
+ nthr_m = nthr / nthr_n;
+
+ while ((nthr_n > 1) && (nthr_m * nthr_n != nthr)) {
+ nthr_n--;
+ nthr_m = nthr / nthr_n;
+ }
+ }
+ }
+
+ MB = (m + nthr_m - 1) / nthr_m + BM_SMALL_NOCOPY_AVX512_COMMON - 1;
+ MB -= MB % BM_SMALL_NOCOPY_AVX512_COMMON;
+ NB = (n + nthr_n - 1) / nthr_n + BN_SMALL_NOCOPY_AVX512_COMMON - 1;
+ NB -= NB % BN_SMALL_NOCOPY_AVX512_COMMON;
+ KB = (k + nthr_k - 1) / nthr_k + BK_SMALL_NOCOPY_AVX512_COMMON - 1;
+ KB -= KB % BK_SMALL_NOCOPY_AVX512_COMMON;
+
+ if (MB * nthr_m > m)
+ nthr_m = (m + MB - 1) / MB;
+ if (NB * nthr_n > n)
+ nthr_n = (n + NB - 1) / NB;
+ if (KB * nthr_k > k)
+ nthr_k = (k + KB - 1) / KB;
+
+ *nthrs_m = nthr_m;
+ *nthrs_n = nthr_n;
+ *nthrs_k = nthr_k;
+
+ *BM = MB;
+ *BN = NB;
+ *BK = KB;
+}
+#undef BM_NOCOPY_AVX512_COMMON
+#undef BN_NOCOPY_AVX512_COMMON
+#undef BK_NOCOPY_AVX512_COMMON
+#undef BN_LARGE_NOCOPY_AVX512_COMMON
+#undef BM_SMALL_NOCOPY_AVX512_COMMON
+#undef BN_SMALL_NOCOPY_AVX512_COMMON
+#undef BK_SMALL_NOCOPY_AVX512_COMMON
+
+// Partition n values as equally as possible among nthr threads
+// and set the offset (t_offset) and number of values (t_block) for ithr
+// Assumption: 0 <= ithr < nthr
+void partition_unit_diff(
+ int ithr, int nthr, int n, int *t_offset, int *t_block)
+{
+ int band = n / nthr;
+ if (band == 0)
+ band = 1;
+ int tail = n - band * nthr;
+ if (tail < 0)
+ tail = 0;
+
+ if (ithr < tail) {
+ band++;
+ *t_offset = band * ithr;
+ *t_block = band;
+ } else {
+ *t_offset = band * ithr + tail;
+ *t_block = band;
+ }
+
+ if (*t_offset >= n) {
+ *t_offset = 0;
+ *t_block = 0;
+ }
+
+ if (*t_offset + *t_block > n) {
+ *t_block = n - *t_offset;
+ }
+}
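+
+/* Worked example: n = 10, nthr = 4 gives band = 2, tail = 2, so threads 0..3
+ * receive offset/block = {0,3}, {3,3}, {6,2}, {8,2}; the first `tail` threads
+ * get one extra value and the ranges exactly cover [0, n). */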
+
+// Sum the m*n values from p_src into p_dst, assuming the two-dimensional
+// arrays have leading dimensions ld_src and ld_dst, respectively
+template<typename data_t>
+void sum_two_matrices(int m, int n,
+ data_t * __restrict p_src, dim_t ld_src,
+ data_t * __restrict p_dst, dim_t ld_dst)
+{
+ int i, j;
+ for (j = 0; j < n; j++) {
+ for (i = 0; i < m; i++) {
+ p_dst[i + j * ld_dst] += p_src[i + j * ld_src];
+ }
+ }
+}
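+
+/* Note: the indexing i + j * ld is column-major, i.e. p_src and p_dst are
+ * m x n panels with leading dimensions ld_src and ld_dst; a fully contiguous
+ * panel simply passes ld == m. */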
+
+template
+void sum_two_matrices<float>(int m, int n,
+ float * __restrict p_src, dim_t ld_src,
+ float * __restrict p_dst, dim_t ld_dst);
+
+template
+void sum_two_matrices<double>(int m, int n,
+ double * __restrict p_src, dim_t ld_src,
+ double * __restrict p_dst, dim_t ld_dst);
+}
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp
new file mode 100644
index 0000000000..3352298b4a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/gemm_utils_f32.hpp
@@ -0,0 +1,72 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GEMM_UTILS_HPP
+#define GEMM_UTILS_HPP
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace gemm_utils {
+// Alias for any dimension-related variable.
+typedef ptrdiff_t dim_t;
+
+template <typename T, bool isTransA, bool isTransB>
+struct gemm_traits {};
+
+template <bool isTransA, bool isTransB>
+struct gemm_traits<double, isTransA, isTransB> {
+ static constexpr int m = 8;
+ static constexpr int n = 6;
+ static constexpr int BM = 4032;
+ static constexpr int BN = isTransA ? 96 : 192;
+ static constexpr int BK = isTransB ? 96 : 512;
+};
+
+template <bool isTransA, bool isTransB>
+struct gemm_traits<float, isTransA, isTransB> {
+ static constexpr int m = 16;
+ static constexpr int n = 6;
+ static constexpr int BM = 4032;
+ static constexpr int BN = isTransA ? 96 : 48;
+ static constexpr int BK = isTransB ? 96 : 256;
+};
+
+template <typename T>
+using unroll_factor = gemm_traits<T, false, false>;
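+
+// Illustrative sketch: the traits are read at compile time, e.g.
+//
+//   using t = gemm_traits<float, /*isTransA=*/false, /*isTransB=*/false>;
+//   static_assert(t::BM == 4032 && t::BN == 48 && t::BK == 256, "");
+//   // unroll_factor<float>::m == 16, unroll_factor<float>::n == 6
+//
+// i.e. the register-block (m x n) and cache-block (BM x BN x BK) sizes.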
+
+template <typename data_t>
+void sum_two_matrices(int m, int n,
+ data_t * __restrict p_src, dim_t ld_src,
+ data_t * __restrict p_dst, dim_t ld_dst);
+
+void calc_nthr_nocopy_avx512_common(int m,
+ int n, int k, int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k,
+ int *BM, int *BN, int *BK);
+
+void calc_nthr_nocopy_avx(int m, int n, int k,
+ int nthrs, int *nthrs_m, int *nthrs_n, int *nthrs_k, int *BM, int *BN,
+ int *BK);
+
+void partition_unit_diff(
+ int ithr, int nthr, int n, int *t_offset, int *t_block);
+};
+
+}
+}
+}
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp
new file mode 100644
index 0000000000..d7be43e392
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.cpp
@@ -0,0 +1,2131 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <cmath>
+#include <mutex>
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "ref_gemm_f32.hpp"
+#include "gemm_utils_f32.hpp"
+#include "jit_avx512_common_gemm_f32.hpp"
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+#define CACHE_LINE_SIZE 64
+
+#define STACKSIZE get_size_of_abi_save_regs()
+#ifdef _WIN32
+#define STACK_K_CAPACITY 32
+#else
+#define STACK_K_CAPACITY 2048
+#endif
+#define SIZE 4
+#define OFFSET 128
+#define BASE_SHIFT 2
+#define SECOND_FETCH unroll_n
+#define UNROLL_M 48
+#define UNROLL_N 8
+
+namespace avx512_common_gemm_f32 {
+using namespace gemm_utils;
+
+struct xbyak_gemm : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_gemm_f32_xbyak_gemm)
+
+ xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
+ void *code_ptr = nullptr,
+ size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
+ : jit_generator(code_ptr, code_size)
+ {
+ using namespace Xbyak;
+
+ enum { ver_avx512_core, ver_avx512_mic } ver =
+ mayiuse(avx512_core) ? ver_avx512_core : ver_avx512_mic;
+
+ bool isBeta0 = (beta == 0.0);
+ bool isBetaN = (!isBeta0 && beta != 1.0);
+
+ // various definitions for convenience
+ auto ARG_M = abi_param1;
+ auto ARG_N = abi_param2;
+ auto K = abi_param3;
+ auto ARG_ALPHA = abi_param4;
+#ifdef _WIN32
+ auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
+ auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
+ sizeof(float *) + STACKSIZE];
+ const auto stackOffset = OFFSET_SHADOWSPACE +
+ sizeof(float *) + STACKSIZE;
+ auto A = rsi;
+ auto LDA = rdi;
+#else
+ auto ARG_A = r8;
+ auto ARG_LDA = r9;
+ const auto stackOffset = STACKSIZE;
+ auto A = ARG_A;
+ auto LDA = ARG_LDA;
+#endif
+ auto ARG_B = ptr[rsp + 8 + stackOffset];
+ auto ARG_LDB = ptr[rsp + 16 + stackOffset];
+ auto ARG_BETA = ptr[rsp + 24 + stackOffset];
+ auto ARG_C = ptr[rsp + 32 + stackOffset];
+ auto ARG_LDC = ptr[rsp + 40 + stackOffset];
+ auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
+ auto ARG_WS = ptr[rsp + 56 + stackOffset];
+
+ auto B = r11;
+ auto LDB = rbx;
+ auto LDC = r13;
+ auto LL = rax;
+ auto AO1 = abi_param2;
+ auto BO1 = abi_param4;
+ auto BO2 = rbp;
+ auto CO1 = r14;
+ auto CO2 = r15;
+ auto LDB3 = r10;
+ auto LDA4 = abi_param1;
+ auto AA = r12;
+ auto BIAS1 = abi_param1;
+
+ auto M = qword[rsp + 0];
+ auto N = qword[rsp + 8];
+ auto FLAG = qword[rsp + 16];
+ auto I = qword[rsp + 24];
+ auto C = qword[rsp + 32];
+ auto BIAS = qword[rsp + 40];
+ auto ALPHA = qword[rsp + 48];
+ auto BETA = qword[rsp + 64];
+ auto ORIG_A = qword[rsp + 80];
+ auto ORIG_SP = qword[rsp + 120];
+
+ auto ZSTRIDE = zmm4;
+ auto VALPHA = zmm6;
+ auto VBETA = zmm7;
+ auto VBIAS1 = zmm1;
+ auto VBIAS2 = zmm2;
+ auto VBIAS3 = zmm3;
+
+ auto PREFETCHSIZEA = ver == ver_avx512_core ? 48 : 80;
+ auto PREFETCHSIZEB = 16;
+
+ Zmm regs[] = { zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15,
+ zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23, zmm24,
+ zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31 };
+
+ // Function for packing if needed
+ auto do_pack = [&](int unroll_m) {
+ Label pack2, pack3, pack4, pack10;
+
+ mov(BO1, A);
+ lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
+ mov(LL, K);
+ sar(LL, 2);
+ jle(pack3, T_NEAR);
+ align(16);
+
+ L(pack2);
+ if (!isTransA) {
+ for (int i = 0; i < 4; i++) {
+ vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
+ if (unroll_m > 16)
+ vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
+ if (unroll_m > 32)
+ vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
+ add(BO1, LDA);
+
+ vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE]
+ | k1,
+ zmm0);
+ if (unroll_m > 16)
+ vmovups(ptr[AO1
+ + (unroll_m * i + 1 * 16 - OFFSET)
+ * SIZE]
+ | k2,
+ zmm1);
+ if (unroll_m > 32)
+ vmovups(ptr[AO1
+ + (unroll_m * i + 2 * 16 - OFFSET)
+ * SIZE]
+ | k3,
+ zmm2);
+ }
+ } else {
+ for (int i = 0; i < 4; i++) {
+ kmovw(k4, k1);
+ vgatherqps(ymm5 | k4,
+ ptr[BO1 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO1 + LDA * 8]);
+ kshiftrw(k4, k1, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ vshuff64x2(zmm0, zmm5, zmm6, 0x44);
+
+ if (unroll_m > 16) {
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kmovw(k4, k2);
+ vgatherqps(ymm5 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kshiftrw(k4, k2, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ vshuff64x2(zmm1, zmm5, zmm6, 0x44);
+ }
+
+ if (unroll_m > 32) {
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kmovw(k4, k3);
+ vgatherqps(ymm5 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kshiftrw(k4, k3, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (i - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ vshuff64x2(zmm2, zmm5, zmm6, 0x44);
+ }
+
+ vmovups(ptr[AO1 + (unroll_m * i + 0 * 16 - OFFSET) * SIZE],
+ zmm0 | k1);
+ if (unroll_m > 16)
+ vmovups(ptr[AO1
+ + (unroll_m * i + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1 | k2);
+ if (unroll_m > 32)
+ vmovups(ptr[AO1
+ + (unroll_m * i + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2 | k3);
+ }
+ add(BO1, 4 * SIZE);
+ }
+ add(AO1, unroll_m * 4 * SIZE);
+
+ sub(LL, 1);
+ jg(pack2, T_NEAR);
+ align(16);
+
+ L(pack3);
+ mov(LL, K);
+ and_(LL, 3);
+ jle(pack10, T_NEAR);
+ align(16);
+
+ L(pack4);
+ if (!isTransA) {
+ vmovups(zmm0 | k1, ptr[BO1 + (0 * 16 - OFFSET) * SIZE]);
+ if (unroll_m > 16)
+ vmovups(zmm1 | k2, ptr[BO1 + (1 * 16 - OFFSET) * SIZE]);
+ if (unroll_m > 32)
+ vmovups(zmm2 | k3, ptr[BO1 + (2 * 16 - OFFSET) * SIZE]);
+ add(BO1, LDA);
+ } else {
+ kmovw(k4, k1);
+ vgatherqps(ymm5 | k4, ptr[BO1 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO1 + LDA * 8]);
+ kshiftrw(k4, k1, 8);
+ vgatherqps(ymm6 | k4, ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ vshuff64x2(zmm0, zmm5, zmm6, 0x44);
+
+ if (unroll_m > 16) {
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kmovw(k4, k2);
+ vgatherqps(ymm5 | k4,
+ ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kshiftrw(k4, k2, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ vshuff64x2(zmm1, zmm5, zmm6, 0x44);
+ }
+
+ if (unroll_m > 32) {
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kmovw(k4, k3);
+ vgatherqps(ymm5 | k4,
+ ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ kshiftrw(k4, k3, 8);
+ vgatherqps(ymm6 | k4,
+ ptr[BO2 + ZSTRIDE + (0 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 8]);
+ vshuff64x2(zmm2, zmm5, zmm6, 0x44);
+ }
+ add(BO1, SIZE);
+ }
+
+ vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
+ zmm0 | k1);
+ if (unroll_m > 16)
+ vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 16 - OFFSET) * SIZE],
+ zmm1 | k2);
+ if (unroll_m > 32)
+ vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 16 - OFFSET) * SIZE],
+ zmm2 | k3);
+
+ add(AO1, unroll_m * SIZE);
+ sub(LL, 1);
+ jg(pack4, T_NEAR);
+ align(16);
+
+ L(pack10);
+ };
+
+ // Function to update C, covering masking and other considerations
+ auto update = [&](Zmm reg, bool useCO1, int offset, int mask,
+ bool useScale = false) {
+ vmulps(reg, reg, VALPHA);
+ if (!isBeta0) {
+ if (!useScale) {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(zmm0, ptr[CO1 + offset * SIZE]);
+ else
+ vmovups(zmm0, ptr[CO2 + offset * SIZE]);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(zmm0 | k1 | T_z, ptr[CO1 + offset * SIZE]);
+ else
+ vmovups(zmm0 | k1 | T_z, ptr[CO2 + offset * SIZE]);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(zmm0 | k2 | T_z, ptr[CO1 + offset * SIZE]);
+ else
+ vmovups(zmm0 | k2 | T_z, ptr[CO2 + offset * SIZE]);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(zmm0 | k3 | T_z, ptr[CO1 + offset * SIZE]);
+ else
+ vmovups(zmm0 | k3 | T_z, ptr[CO2 + offset * SIZE]);
+ break;
+ }
+ } else {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(zmm0, ptr[CO1 + LDC + offset * SIZE]);
+ else
+ vmovups(zmm0, ptr[CO2 + LDC + offset * SIZE]);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(zmm0 | k1 | T_z,
+ ptr[CO1 + LDC + offset * SIZE]);
+ else
+ vmovups(zmm0 | k1 | T_z,
+ ptr[CO2 + LDC + offset * SIZE]);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(zmm0 | k2 | T_z,
+ ptr[CO1 + LDC + offset * SIZE]);
+ else
+ vmovups(zmm0 | k2 | T_z,
+ ptr[CO2 + LDC + offset * SIZE]);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(zmm0 | k3 | T_z,
+ ptr[CO1 + LDC + offset * SIZE]);
+ else
+ vmovups(zmm0 | k3 | T_z,
+ ptr[CO2 + LDC + offset * SIZE]);
+ break;
+ }
+ }
+ if (!isBetaN) {
+ vaddps(zmm0, reg, zmm0);
+ } else {
+ vfmadd132ps(zmm0, reg, VBETA);
+ }
+ if (!useScale) {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], zmm0);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], zmm0);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], zmm0 | k1);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], zmm0 | k1);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], zmm0 | k2);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], zmm0 | k2);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], zmm0 | k3);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], zmm0 | k3);
+ break;
+ }
+ } else {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k1);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k1);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k2);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k2);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], zmm0 | k3);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], zmm0 | k3);
+ break;
+ }
+ }
+ } else {
+ if (!useScale) {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], reg);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], reg);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], reg | k1);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], reg | k1);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], reg | k2);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], reg | k2);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(ptr[CO1 + offset * SIZE], reg | k3);
+ else
+ vmovups(ptr[CO2 + offset * SIZE], reg | k3);
+ break;
+ }
+ } else {
+ switch (mask) {
+ case 0:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], reg);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], reg);
+ break;
+ case 1:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k1);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k1);
+ break;
+ case 2:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k2);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k2);
+ break;
+ case 3:
+ if (useCO1)
+ vmovups(ptr[CO1 + LDC + offset * SIZE], reg | k3);
+ else
+ vmovups(ptr[CO2 + LDC + offset * SIZE], reg | k3);
+ break;
+ }
+ }
+ }
+ vpxorq(reg, reg, reg);
+ };
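+
+        // Note: per output element the lambda above computes
+        //   c = alpha * acc                 when beta == 0 (isBeta0)
+        //   c = alpha * acc + c             when beta == 1
+        //   c = alpha * acc + beta * c      otherwise (via vfmadd132ps)
+        // with the k1/k2/k3 masks selecting the valid lanes of a partial tile
+        // and useScale selecting the second C column at CO + LDC.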
+
+ // Loop with unroll_n - 2 FMAs; called by innerkernel
+ auto fmaloop = [&](int unroll_m, int unroll_n, int iteration) {
+ for (int i = 2; i < unroll_n; i++) {
+ if (ver == ver_avx512_core) {
+ if (!isTransB) {
+ switch (i) {
+ case 2:
+ vbroadcastss(
+ zmm3,
+ ptr[BO1 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 3:
+ vbroadcastss(
+ zmm3,
+ ptr[BO1 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 4:
+ vbroadcastss(zmm3,
+ ptr[BO2 + (iteration - OFFSET) * SIZE]);
+ break;
+ case 5:
+ vbroadcastss(
+ zmm3,
+ ptr[BO2 + LDB * 1
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 6:
+ vbroadcastss(
+ zmm3,
+ ptr[BO2 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 7:
+ vbroadcastss(
+ zmm3,
+ ptr[BO2 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ }
+ } else {
+ vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
+ }
+ vfmadd231ps(regs[i], zmm3, zmm0);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm3, zmm1);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm3, zmm2);
+ } else {
+ if (!isTransB) {
+ switch (i) {
+ case 2:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO1 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO1 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO1 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 3:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO1 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO1 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO1 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 4:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO2 + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO2 + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO2 + (iteration - OFFSET) * SIZE]);
+ break;
+ case 5:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO2 + LDB * 1
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO2 + LDB * 1
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO2 + LDB * 1
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 6:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO2 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO2 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO2 + LDB * 2
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ case 7:
+ vfmadd231ps(regs[i], zmm0,
+ zword_b[BO2 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO2 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO2 + LDB3
+ + (iteration - OFFSET) * SIZE]);
+ break;
+ }
+ } else {
+ vfmadd231ps(
+ regs[i], zmm0, zword_b[BO1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[i + 8], zmm1,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[i + 16], zmm2,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ }
+ }
+ }
+ };
+
+ // Innerkernel; called by kernel
+ auto innerkernel = [&](int unroll_m, int unroll_n, bool isDirect,
+ bool isCopy, bool doCPrefetch, bool isUnmasked = true) {
+ for (int i = 0; i < 8; i++) {
+ if (!isDirect) {
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + i * unroll_m + 0 * 16 - OFFSET)
+ * SIZE]);
+ if (unroll_m >= 32)
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + i * unroll_m + 1 * 16 - OFFSET)
+ * SIZE]);
+ if (unroll_m >= 48)
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + i * unroll_m + 2 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4 + (16 * 0 * SIZE)]);
+ if (unroll_m >= 32)
+ prefetcht0(ptr[AO1 + LDA4 + (16 * 1 * SIZE)]);
+ if (unroll_m >= 48)
+ prefetcht0(ptr[AO1 + LDA4 + (16 * 2 * SIZE)]);
+ }
+
+ if (!isDirect) {
+ if (i != 0) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0,
+ ptr[AO1
+ + (unroll_m * i + 0 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1
+ + (unroll_m * i + 0 * 16 - OFFSET)
+ * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1
+ + (unroll_m * i + 1 * 16
+ - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1
+ + (unroll_m * i + 1 * 16
+ - OFFSET)
+ * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1
+ + (unroll_m * i + 2 * 16
+ - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1
+ + (unroll_m * i + 2 * 16
+ - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+ } else {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ if (ver == ver_avx512_core) {
+ if (!isTransB) {
+ vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ vfmadd231ps(regs[0], zmm3, zmm0);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[0 + 8], zmm3, zmm1);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[0 + 16], zmm3, zmm2);
+ } else {
+ if (!isTransB) {
+ vfmadd231ps(regs[0], zmm0,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[0 + 8], zmm1,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[0 + 16], zmm2,
+ zword_b[BO1 + (i - OFFSET) * SIZE]);
+ } else {
+ vfmadd231ps(regs[0], zmm0,
+ zword_b[BO1 + (0 - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[0 + 8], zmm1,
+ zword_b[BO1 + (0 - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[0 + 16], zmm2,
+ zword_b[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ }
+
+ if (unroll_n >= i + 1) {
+ if (!isTransB) {
+ switch (i) {
+ case 0:
+ prefetcht0(
+ ptr[BO1 + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 1:
+ prefetcht0(ptr[BO1 + LDB
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 2:
+ prefetcht0(ptr[BO1 + LDB * 2
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 3:
+ prefetcht0(ptr[BO1 + LDB3
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 4:
+ prefetcht0(
+ ptr[BO2 + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 5:
+ prefetcht0(ptr[BO2 + LDB
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 6:
+ prefetcht0(ptr[BO2 + LDB * 2
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ case 7:
+ prefetcht0(ptr[BO2 + LDB3
+ + (PREFETCHSIZEB - OFFSET) * SIZE]);
+ break;
+ }
+ }
+ }
+
+ if (unroll_n >= 2) {
+ if (ver == ver_avx512_core) {
+ if (!isTransB) {
+ vbroadcastss(zmm3,
+ ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(zmm3, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ vfmadd231ps(regs[1], zmm3, zmm0);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[1 + 8], zmm3, zmm1);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[1 + 16], zmm3, zmm2);
+ } else {
+ if (!isTransB) {
+ vfmadd231ps(regs[1], zmm0,
+ zword_b[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[1 + 8], zmm1,
+ zword_b[BO1 + LDB * 1
+ + (i - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[1 + 16], zmm2,
+ zword_b[BO1 + LDB * 1
+ + (i - OFFSET) * SIZE]);
+ } else {
+ vfmadd231ps(regs[1], zmm0,
+ zword_b[BO1 + (1 - OFFSET) * SIZE]);
+ if (unroll_m >= 32)
+ vfmadd231ps(regs[1 + 8], zmm1,
+ zword_b[BO1 + (1 - OFFSET) * SIZE]);
+ if (unroll_m >= 48)
+ vfmadd231ps(regs[1 + 16], zmm2,
+ zword_b[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ }
+ }
+
+ if (isCopy) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 0 * 16 - OFFSET)
+ * SIZE],
+ zmm0);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 0 * 16 - OFFSET)
+ * SIZE],
+ zmm0 | k1);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1 | k2);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2 | k3);
+ }
+ }
+ if (i == 7)
+ sub(LDA4, -unroll_m * 8 * SIZE);
+ }
+ fmaloop(unroll_m, unroll_n, i);
+
+ if (i == 1) {
+ if (doCPrefetch) {
+ if (ver == ver_avx512_core)
+ prefetchw(ptr[CO2 + 0 * 16 * SIZE]);
+ else
+ prefetcht0(ptr[CO2 + 0 * 16 * SIZE]);
+ }
+ }
+ if (i == 3) {
+ if (doCPrefetch && unroll_m >= 32) {
+ if (ver == ver_avx512_core)
+ prefetchw(ptr[CO2 + 1 * 16 * SIZE]);
+ else
+ prefetcht0(ptr[CO2 + 1 * 16 * SIZE]);
+ }
+ if (!isTransA) {
+ if (ver == ver_avx512_core)
+ prefetcht0(ptr[AA + 16 * 0 * SIZE]);
+ else
+ prefetcht2(ptr[AA + 16 * 0 * SIZE]);
+ }
+ }
+ if (i == 5) {
+ if (doCPrefetch) {
+ if (unroll_m >= 48) {
+ if (ver == ver_avx512_core)
+ prefetchw(ptr[CO2 + 2 * 16 * SIZE]);
+ else
+ prefetcht0(ptr[CO2 + 2 * 16 * SIZE]);
+ }
+ add(CO2, LDC);
+ }
+ if (!isTransA) {
+ if (unroll_m >= 32) {
+ if (ver == ver_avx512_core)
+ prefetcht0(ptr[AA + 16 * 1 * SIZE]);
+ else
+ prefetcht2(ptr[AA + 16 * 1 * SIZE]);
+ }
+ }
+ }
+
+ if (isTransB) {
+ prefetcht0(ptr[BO1 + BO2]);
+ add(BO1, LDB);
+ }
+ } // end of for loop
+
+ if (!isTransB) {
+ sub(BO1, -8 * SIZE);
+ if (unroll_n >= 4)
+ sub(BO2, -8 * SIZE);
+ }
+ if (!isTransA) {
+ if (unroll_m >= 48) {
+ if (ver == ver_avx512_core)
+ prefetcht0(ptr[AA + 16 * 2 * SIZE]);
+ else
+ prefetcht2(ptr[AA + 16 * 2 * SIZE]);
+ }
+ lea(AA, ptr[AA + LDA]);
+ }
+
+ if (!isDirect) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0,
+ ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (unroll_m * 8 + 0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1
+ + (unroll_m * 8 + 1 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1
+ + (unroll_m * 8 + 1 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1
+ + (unroll_m * 8 + 2 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1
+ + (unroll_m * 8 + 2 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ sub(AO1, -unroll_m * 8 * SIZE);
+ }
+
+ sub(LL, 1);
+ };
+
+ // Main kernel; does prefetching and calls innerkernel
+ // After calculating results in registers, writes back to C matrix by
+ // calling update
+ auto kernel = [&](int unroll_m, int unroll_n, bool isDirect,
+ bool isCopy, bool isUnmasked = true) {
+ if (!isDirect) {
+ lea(AO1, ptr[rsp + 128 + OFFSET * SIZE]);
+ } else {
+ mov(AO1, A);
+ }
+
+ if (isCopy) {
+ lea(LDA4, ptr[rsp + 128 + OFFSET * SIZE]);
+ } else {
+ auto step = ver == ver_avx512_core ? 2 : 4;
+ lea(LDA4, ptr[LDA * step + (16 - 1 - OFFSET) * SIZE]);
+ }
+
+ if (isTransB) {
+ lea(BO2, ptr[LDB * 4 + (16 / 2 - 1 - OFFSET) * SIZE]);
+ }
+
+ if (!isDirect) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0,
+ ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1
+ + (unroll_m * 0 + 1 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1
+ + (unroll_m * 0 + 1 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1
+ + (unroll_m * 0 + 2 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1
+ + (unroll_m * 0 + 2 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+
+ Label kernel12, kernel13, kernel14, kernel15, kernel16, kernel18;
+
+ mov(LL, K);
+ sar(LL, 3);
+ sub(LL, SECOND_FETCH);
+ jle(kernel13, T_NEAR);
+ align(16);
+
+ L(kernel12);
+ innerkernel(
+ unroll_m, unroll_n, isDirect, isCopy, false, isUnmasked);
+ jg(kernel12, T_NEAR);
+ align(16);
+
+ L(kernel13);
+ lea(CO2, ptr[CO1 + (16 - 1) * SIZE]);
+ add(LL, unroll_n);
+ jle(kernel15, T_NEAR);
+ align(16);
+
+ L(kernel14);
+ innerkernel(unroll_m, unroll_n, isDirect, isCopy, true, isUnmasked);
+ jg(kernel14, T_NEAR);
+ align(16);
+
+ L(kernel15);
+ mov(LL, K);
+ and_(LL, 7);
+ jle(kernel18, T_NEAR);
+ align(16);
+
+ L(kernel16);
+ if (isDirect) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0, ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1 + (1 * 16 - OFFSET) * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1 + (2 * 16 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ for (int i = 0; i < unroll_n; i++) {
+ if (!isTransB) {
+ switch (i) {
+ case 0:
+ vbroadcastss(zmm3, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ break;
+ case 1:
+ vbroadcastss(
+ zmm3, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ break;
+ case 2:
+ vbroadcastss(
+ zmm3, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ break;
+ case 3:
+ vbroadcastss(
+ zmm3, ptr[BO1 + LDB3 + (0 - OFFSET) * SIZE]);
+ break;
+ case 4:
+ vbroadcastss(zmm3, ptr[BO2 + (0 - OFFSET) * SIZE]);
+ break;
+ case 5:
+ vbroadcastss(
+ zmm3, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ break;
+ case 6:
+ vbroadcastss(
+ zmm3, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ break;
+ case 7:
+ vbroadcastss(
+ zmm3, ptr[BO2 + LDB3 + (0 - OFFSET) * SIZE]);
+ break;
+ }
+ } else {
+ vbroadcastss(zmm3, ptr[BO1 + (i - OFFSET) * SIZE]);
+ }
+ vfmadd231ps(regs[i], zmm3, zmm0);
+ if (unroll_m >= 32) {
+ vfmadd231ps(regs[i + 8], zmm3, zmm1);
+ }
+ if (unroll_m >= 48) {
+ vfmadd231ps(regs[i + 16], zmm3, zmm2);
+ }
+ }
+
+ if (isCopy) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
+ zmm0);
+ } else {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 16 - OFFSET) * SIZE],
+ zmm0 | k1);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 1 * 16 - OFFSET)
+ * SIZE],
+ zmm1 | k2);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2);
+ } else {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 2 * 16 - OFFSET)
+ * SIZE],
+ zmm2 | k3);
+ }
+ }
+ sub(LDA4, -unroll_m * SIZE);
+ }
+
+ if (!isDirect) {
+ if (isUnmasked || unroll_m > 16) {
+ vmovups(zmm0,
+ ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
+ } else {
+ vmovups(zmm0 | k1 | T_z,
+ ptr[AO1 + (unroll_m * 1 + 0 * 16 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32) {
+ vmovups(zmm1, ptr[AO1
+ + (unroll_m * 1 + 1 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm1 | k2 | T_z,
+ ptr[AO1
+ + (unroll_m * 1 + 1 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked) {
+ vmovups(zmm2, ptr[AO1
+ + (unroll_m * 1 + 2 * 16 - OFFSET)
+ * SIZE]);
+ } else {
+ vmovups(zmm2 | k3 | T_z,
+ ptr[AO1
+ + (unroll_m * 1 + 2 * 16 - OFFSET)
+ * SIZE]);
+ }
+ }
+ sub(AO1, -unroll_m * SIZE);
+ }
+
+ if (!isTransB) {
+ sub(BO1, -SIZE);
+ if (unroll_n >= 4) {
+ sub(BO2, -SIZE);
+ }
+ } else {
+ add(BO1, LDB);
+ }
+
+ sub(LL, 1);
+ jg(kernel16, T_NEAR);
+ align(16);
+
+ L(kernel18);
+ vbroadcastss(VALPHA, ALPHA);
+
+ if (isBetaN) {
+ vbroadcastss(VBETA, BETA);
+ }
+
+ // Write back the results; all beta cases need to be handled
+ if (hasBias) {
+ mov(BIAS1, BIAS);
+ if (isUnmasked || unroll_m > 16)
+ vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
+ else
+ vmovups(VBIAS1 | k1 | T_z, ptr[BIAS1 + 0 * SIZE]);
+ if (unroll_m >= 32) {
+ if (isUnmasked || unroll_m > 32)
+ vmovups(VBIAS2, ptr[BIAS1 + 16 * SIZE]);
+ else
+ vmovups(VBIAS2 | k2 | T_z, ptr[BIAS1 + 16 * SIZE]);
+ }
+ if (unroll_m >= 48) {
+ if (isUnmasked)
+ vmovups(VBIAS3, ptr[BIAS1 + 32 * SIZE]);
+ else
+ vmovups(VBIAS3 | k3 | T_z, ptr[BIAS1 + 32 * SIZE]);
+ }
+ }
+
+ for (int i = 0; i < unroll_n; i++) {
+ bool useScale = i % 2 != 0;
+ bool useCO1 = i < 2;
+ if (i == 2)
+ lea(CO2, ptr[CO1 + LDC * 2]);
+ if (i == 4 || i == 6)
+ lea(CO2, ptr[CO2 + LDC * 2]);
+ if (hasBias)
+ vaddps(regs[i], VBIAS1, regs[i]);
+ if (isUnmasked || unroll_m > 16) {
+ update(regs[i], useCO1, 0, 0, useScale);
+ } else {
+ update(regs[i], useCO1, 0, 1, useScale);
+ }
+ if (unroll_m >= 32) {
+ if (hasBias)
+ vaddps(regs[i + 8], VBIAS2, regs[i + 8]);
+ if (isUnmasked || unroll_m > 32) {
+ update(regs[i + 8], useCO1, 16, 0, useScale);
+ } else {
+ update(regs[i + 8], useCO1, 16, 2, useScale);
+ }
+ }
+ if (unroll_m >= 48) {
+ if (hasBias)
+ vaddps(regs[i + 16], VBIAS3, regs[i + 16]);
+ if (isUnmasked) {
+ update(regs[i + 16], useCO1, 32, 0, useScale);
+ } else {
+ update(regs[i + 16], useCO1, 32, 3, useScale);
+ }
+ }
+ }
+
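+        // Advance CO1 by unroll_n * LDC to the next column panel (CO2 was
+        // positioned at the appropriate column during the write-back loop
+        // above).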
+ switch (unroll_n) {
+ case 1: add(CO1, LDC); break;
+ case 2: lea(CO1, ptr[CO1 + LDC * 2]); break;
+ case 3: lea(CO1, ptr[CO2 + LDC * 1]); break;
+ case 4: lea(CO1, ptr[CO2 + LDC * 2]); break;
+ case 5: lea(CO1, ptr[CO2 + LDC * 1]); break;
+ case 6: lea(CO1, ptr[CO2 + LDC * 2]); break;
+ case 7: lea(CO1, ptr[CO2 + LDC * 1]); break;
+ case 8: lea(CO1, ptr[CO2 + LDC * 2]); break;
+ }
+
+        // Advance B to the next panel of unroll_n columns: the inner loops
+        // moved BO1/BO2 forward by K elements (or K rows of B when
+        // transposed); rewind that advance, then step over unroll_n columns.
+ if (!isTransB) {
+ lea(rax, ptr[K * SIZE]);
+ switch (unroll_n) {
+ case 1:
+ add(BO1, LDB);
+ add(BO2, LDB);
+ break;
+ case 2:
+ lea(BO1, ptr[BO1 + LDB * 2]);
+ lea(BO2, ptr[BO2 + LDB * 2]);
+ break;
+ case 3:
+ lea(BO1, ptr[BO1 + LDB3]);
+ lea(BO2, ptr[BO2 + LDB3]);
+ break;
+ case 4:
+ lea(BO1, ptr[BO1 + LDB * 4]);
+ lea(BO2, ptr[BO2 + LDB * 4]);
+ break;
+ case 5:
+ lea(BO1, ptr[BO1 + LDB * 4]);
+ add(BO1, LDB);
+ lea(BO2, ptr[BO2 + LDB * 4]);
+ add(BO2, LDB);
+ break;
+ case 6:
+ lea(BO1, ptr[BO1 + LDB3 * 2]);
+ lea(BO2, ptr[BO2 + LDB3 * 2]);
+ break;
+ case 7:
+ lea(BO1, ptr[BO1 + LDB * 8]);
+ sub(BO1, LDB);
+ lea(BO2, ptr[BO2 + LDB * 8]);
+ sub(BO2, LDB);
+ break;
+ case 8:
+ lea(BO1, ptr[BO1 + LDB * 8]);
+ lea(BO2, ptr[BO2 + LDB * 8]);
+ break;
+ }
+ sub(BO1, rax);
+ sub(BO2, rax);
+ } else {
+ mov(rax, LDB);
+ imul(rax, K);
+ sub(BO1, rax);
+ add(BO1, unroll_n * SIZE);
+ }
+ };
+
+    // High-level subroutine: packs A if needed, then tiles the C matrix.
+    // Operates on chunks of 48 rows by 8 columns at a time, handling tail
+    // cases by dropping to 32 or 16 rows, applying masks, and/or using
+    // fewer columns.
+ auto subloop = [&](int unroll_m) {
+ Label l_subloop_20x[8], l_subloop_mask_20x[8];
+ Label l_subloop_30x[8], l_subloop_mask_30x[8];
+
+ Label subloop11, subloop11mask;
+ Label subloop30, subloop30mask;
+ Label subloop31, subloop31mask;
+ Label subloop96;
+ Label subloop98, subloop98mask;
+ Label subloop99;
+
+ // Create mask
+ mov(BO1, rcx);
+ mov(rcx, M);
+ sub(rcx, unroll_m - 16);
+ mov(CO1, 16);
+ cmp(rcx, 16);
+
+ cmovg(rcx, CO1);
+ mov(rax, 1);
+ sal(rax, cl);
+ sub(rax, 1);
+ mov(rcx, 0xffff);
+
+ if (unroll_m == 16) {
+ kmovw(k1, eax);
+ } else if (unroll_m == 32) {
+ kmovw(k1, ecx);
+ kmovw(k2, eax);
+ } else {
+ kmovw(k1, ecx);
+ kmovw(k2, ecx);
+ kmovw(k3, eax);
+ }
+ mov(rcx, BO1);
+
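+        // k1..k3 now hold the 16-lane row masks: full 16-row blocks get
+        // 0xffff and the final partial block gets
+        // (1 << min(M - (unroll_m - 16), 16)) - 1; e.g. unroll_m = 48,
+        // M = 40 yields k1 = k2 = 0xffff, k3 = 0x00ff. A full tail mask
+        // stays on the unmasked path; otherwise jump to the masked variant
+        // at subloop96.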
+ and_(rax, 0xffff);
+ cmp(rax, 0xffff);
+ jne(subloop96, T_NEAR);
+
+ if (isTransA) {
+ do_pack(unroll_m);
+ }
+
+ mov(CO1, C);
+ add(C, unroll_m * SIZE);
+
+ mov(BO1, B);
+ if (!isTransB) {
+ lea(BO2, ptr[B + LDB * 4]);
+ }
+
+ if (!isTransA) {
+ lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
+ cmp(M, UNROLL_M);
+ jg(subloop98, T_NEAR);
+
+ mov(AA, ORIG_A);
+ lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
+ L(subloop98);
+ }
+
+ mov(LL, N);
+ mov(I, LL);
+ if (!isTransA) {
+ // If N is too small, skip copy operation
+ cmp(LL, UNROLL_N * 3);
+ jle(subloop30, T_NEAR);
+
+ // If A is not aligned to cache line
+ cmp(FLAG, 0);
+ je(subloop30, T_NEAR);
+ } else {
+ cmp(LL, UNROLL_N);
+ jl(l_subloop_20x[1], T_NEAR);
+ }
+ align(16);
+
+ if (!isTransA) {
+ kernel(unroll_m, UNROLL_N, true, true);
+ } else {
+ kernel(unroll_m, UNROLL_N, false, false);
+ }
+
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jl(l_subloop_20x[1], T_NEAR);
+ align(16);
+
+ L(subloop11);
+ kernel(unroll_m, UNROLL_N, false, false);
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop11, T_NEAR);
+ align(16);
+
+ for (int i = 1; i <= 7; i++) {
+ L(l_subloop_20x[i]);
+ cmp(I, i);
+ if (i < 7) {
+ jne(l_subloop_20x[i + 1], T_NEAR);
+ } else {
+ jne(subloop99, T_NEAR);
+ }
+ kernel(unroll_m, i, false, false);
+ jmp(subloop99, T_NEAR);
+ align(16);
+ }
+
+ if (!isTransA) {
+ L(subloop30);
+ cmp(I, UNROLL_N);
+ jl(l_subloop_30x[1], T_NEAR);
+ align(16);
+
+ L(subloop31);
+ kernel(unroll_m, UNROLL_N, true, false);
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop31, T_NEAR);
+ align(16);
+
+ for (int i = 1; i <= 7; i++) {
+ L(l_subloop_30x[i]);
+ cmp(I, i);
+ if (i < 7) {
+ jne(l_subloop_30x[i + 1], T_NEAR);
+ } else {
+ jne(subloop99, T_NEAR);
+ }
+ kernel(unroll_m, i, true, false);
+ if (i < 7)
+ jmp(subloop99, T_NEAR);
+ align(16);
+ }
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop96);
+ if (isTransA) {
+ do_pack(unroll_m);
+ }
+
+ mov(CO1, C);
+ add(C, unroll_m * SIZE);
+ mov(BO1, B);
+ if (!isTransB) {
+ lea(BO2, ptr[B + LDB * 4]);
+ }
+
+ if (!isTransA) {
+ lea(AA, ptr[A + (unroll_m + 16 - 1 - OFFSET) * SIZE]);
+ cmp(M, UNROLL_M);
+ jg(subloop98mask, T_NEAR);
+ mov(AA, ORIG_A);
+ lea(AA, ptr[AA + (16 - 1 - OFFSET) * SIZE]);
+ L(subloop98mask);
+ }
+
+ mov(LL, N);
+ mov(I, LL);
+ if (!isTransA) {
+ // If N is too small, skip copy operation
+ cmp(LL, UNROLL_N * 3);
+ jle(subloop30mask, T_NEAR);
+
+ // If A is not aligned to cache line
+ cmp(FLAG, 0);
+ je(subloop30mask, T_NEAR);
+ } else {
+ cmp(LL, UNROLL_N);
+ jl(l_subloop_mask_20x[1], T_NEAR);
+ }
+ align(16);
+
+ if (!isTransA) {
+ kernel(unroll_m, UNROLL_N, true, true, false);
+ } else {
+ kernel(unroll_m, UNROLL_N, false, false, false);
+ }
+
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jl(l_subloop_mask_20x[1], T_NEAR);
+ align(16);
+
+ L(subloop11mask);
+ kernel(unroll_m, UNROLL_N, false, false, false);
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop11mask, T_NEAR);
+ align(16);
+
+ for (int i = 1; i <= 7; i++) {
+ L(l_subloop_mask_20x[i]);
+ cmp(I, i);
+ if (i < 7) {
+ jne(l_subloop_mask_20x[i + 1], T_NEAR);
+ } else {
+ jne(subloop99, T_NEAR);
+ }
+ kernel(unroll_m, i, false, false, false);
+ jmp(subloop99, T_NEAR);
+ align(16);
+ }
+
+ if (!isTransA) {
+ L(subloop30mask);
+ cmp(I, UNROLL_N);
+ jl(l_subloop_mask_30x[1], T_NEAR);
+ align(16);
+
+ L(subloop31mask);
+ kernel(unroll_m, UNROLL_N, true, false, false);
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop31mask, T_NEAR);
+ align(16);
+
+ for (int i = 1; i <= 7; i++) {
+ L(l_subloop_mask_30x[i]);
+ cmp(I, i);
+ if (i < 7) {
+ jne(l_subloop_mask_30x[i + 1], T_NEAR);
+ } else {
+ jne(subloop99, T_NEAR);
+ }
+ kernel(unroll_m, i, true, false, false);
+ if (i < 7)
+ jmp(subloop99, T_NEAR);
+ align(16);
+ }
+ }
+
+ L(subloop99);
+ // Compute address for A
+ if (!isTransA) {
+ add(A, unroll_m * SIZE);
+ } else {
+ mov(rax, LDA);
+ imul(rax, rax, unroll_m);
+ add(A, rax);
+ }
+
+ // Compute next address of BIAS
+ if (hasBias) {
+ add(BIAS, unroll_m * SIZE);
+ }
+ };
+
+ preamble();
+
+ Label buffer_in_ws, buffer_allocated;
+
+ // Get the registers
+ mov(B, ARG_B);
+ mov(LDB, ARG_LDB);
+ mov(r15, ARG_BETA);
+ mov(r12, ARG_C);
+ if (hasBias)
+ mov(r10, ARG_BIAS);
+ mov(LDC, ARG_LDC);
+ mov(rbp, rsp);
+
+ vmovss(xmm0, ptr[ARG_ALPHA]);
+ vmovss(xmm1, ptr[r15]);
+
+#if _WIN32
+ mov(A, ARG_A);
+ mov(LDA, ARG_LDA);
+#endif
+
+ cmp(K, STACK_K_CAPACITY);
+ jg(buffer_in_ws, T_NEAR);
+
+    // Carve an on-stack buffer for packing a 48 x K panel of A and align it
+    // down to a 4 KiB page
+ lea(rax, ptr[K * SIZE]);
+ imul(rax, rax, 0x30);
+ add(rax, 256);
+ sub(rsp, rax);
+ and_(rsp, -PAGE_4K);
+ jmp(buffer_allocated, T_NEAR);
+
+ L(buffer_in_ws);
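+    // K exceeds the on-stack capacity: use the caller-provided workspace
+    // instead of carving the packing buffer out of the stack.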
+ mov(rsp, ARG_WS);
+
+ L(buffer_allocated);
+
+ mov(ORIG_SP, rbp);
+ mov(M, ARG_M);
+ mov(N, ARG_N);
+ mov(C, r12);
+ if (hasBias)
+ mov(BIAS, r10);
+ vmovss(ALPHA, xmm0);
+ vmovss(BETA, xmm1);
+ sub(A, -OFFSET * SIZE);
+ sub(B, -OFFSET * SIZE);
+ mov(ORIG_A, A);
+ sal(LDA, BASE_SHIFT);
+ sal(LDB, BASE_SHIFT);
+ sal(LDC, BASE_SHIFT);
+ lea(LDB3, ptr[LDB + LDB * 2]);
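+    // LDA/LDB/LDC are converted from elements to bytes above; LDB3 = 3 * LDB
+    // is precomputed for column addressing and pointer advances.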
+
+ if (isTransA) {
+ vpbroadcastq(zmm2, LDA);
+ vpxorq(ZSTRIDE, ZSTRIDE, ZSTRIDE);
+ mov(rax, -2);
+ kmovw(k4, eax);
+
+ for (int i = 0; i < 6; i++) {
+ vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
+ kshiftlw(k4, k4, 1);
+ }
+ vpaddq(ZSTRIDE | k4, ZSTRIDE, zmm2);
+ }
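+    // ZSTRIDE now holds the byte offsets {0, LDA, 2*LDA, ..., 7*LDA} of
+    // eight consecutive rows of A; it is only needed on the transposed-A
+    // path.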
+
+ // Check A alignment and leading dimension; take copy-based path as
+ // needed
+ mov(rax, LDA);
+ or_(rax, A);
+ and_(rax, ver == ver_avx512_core ? 0x07 : 0x3f);
+ mov(FLAG, rax);
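+    // FLAG is nonzero when A or LDA is misaligned (8-byte granularity on
+    // avx512_core, cache-line granularity otherwise); subloop uses it to
+    // decide whether to pack A through the aligned buffer when N is large
+    // enough to amortize the copy.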
+
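+    // Clear the 24 accumulator registers zmm8..zmm31.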
+ for (int i = 8; i < 16; i++) {
+ for (int j = 0; j < 3; j++) {
+ vpxorq(Zmm(i + 8 * j), Zmm(i + 8 * j), Zmm(i + 8 * j));
+ }
+ }
+
+ Label main0, main1, main2, main999;
+
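+    // M loop: peel off 48-row blocks while more than 32 rows remain, then
+    // finish with a single 32- or 16-row block (masked as needed).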
+ cmp(M, 32);
+ jle(main0, T_NEAR);
+ align(16);
+
+ L(main1);
+ subloop(48);
+ sub(M, UNROLL_M);
+ cmp(M, 32);
+ jg(main1, T_NEAR);
+ align(16);
+
+ L(main0);
+ cmp(M, 16);
+ jle(main2, T_NEAR);
+
+ subloop(32);
+ jmp(main999, T_NEAR);
+ align(16);
+
+ L(main2);
+ cmp(M, 0);
+ jle(main999, T_NEAR);
+ subloop(16);
+ align(16);
+
+ L(main999);
+ // Restore original stack
+ mov(rsp, ORIG_SP);
+
+ vzeroupper();
+ postamble();
+
+ ker_ = this->getCode<ker_t>();
+ }
+
+ typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws);
+
+ void operator()(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws) const
+ {
+ ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
+ }
+
+private:
+ ker_t ker_;
+};
+
+const xbyak_gemm *get_xbyak_gemm(
+ bool isTransA, bool isTransB, float beta, bool hasBias) {
+ auto beta_idx = [](float beta) {
+ return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
+ };
+
+ // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
+ static xbyak_gemm *kernel_table[2][2][2][3];
+ static std::once_flag initialized;
+ std::call_once(initialized, [=]{
+ for (bool isTransA: {false, true})
+ for (bool isTransB: {false, true})
+ for (bool hasBias: {false, true})
+ for (float beta: {0.0f, 1.0f, 2.0f}) {
+ // nocopy sgemm with bias for beta != 0.0 is not supported
+ if (hasBias && beta != 0.0)
+ continue;
+ kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
+ new xbyak_gemm(isTransA, isTransB, beta, hasBias);
+ }
+ });
+
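+    // Generated kernels are cached in the static table above for the
+    // lifetime of the process.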
+ return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
+}
+
+void sgemm_nocopy_driver(const char *transa,
+ const char *transb, int m, int n, int k, const float *alpha,
+ const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
+ float *c, dim_t ldc, const float *bias, float *ws)
+{
+ bool isTransA = (*transa == 'T' || *transa == 't');
+ bool isTransB = (*transb == 'T' || *transb == 't');
+
+ int Bm, sizeM, Bn, sizeN, Bk, sizeK;
+
+ int i, j;
+
+ if ((m <= 0) || (n <= 0))
+ return;
+
+ if ((k <= 0) || (alpha[0] == 0.)) {
+
+ if (beta[0] == 0.) {
+ for (j = 0; j < n; j++)
+ for (i = 0; i < m; i++)
+ c[i + j * ldc] = 0.0;
+ } else if (beta[0] != 1.) {
+ for (j = 0; j < n; j++)
+ for (i = 0; i < m; i++)
+ c[i + j * ldc] *= beta[0];
+ }
+
+ return;
+ }
+
+ assert(IMPLICATION(bias != nullptr, *beta == 0.0));
+
+    // XXX: the kernel lookup below runs on every thread; code generation
+    // itself happens only once, guarded by std::call_once in get_xbyak_gemm
+ bool hasBias = (bias != nullptr);
+ auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
+ auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
+ auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
+ assert(ker_bn && ker_b1 && ker_b0);
+
+ int BM = 4032, BN, BK;
+ if (mayiuse(avx512_core)) {
+ BN = isTransA ? 384 : 64;
+ BK = 384;
+ } else {
+ BN = isTransA ? 96 : 64;
+ BK = isTransB ? 96 : 192;
+ if (!isTransA && !isTransB)
+ BK = 128;
+ }
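+    // Cache blocking: C is processed in BM x BN tiles with K consumed in
+    // BK-sized chunks; the remainder handling below splits near-double-sized
+    // tails roughly in half to keep block sizes balanced.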
+ const float *curA, *curB, *curBias = nullptr;
+ float *curC;
+
+ for (Bk = 0; Bk < k; Bk += sizeK) {
+ sizeK = k - Bk;
+ if (sizeK >= BK * 2)
+ sizeK = BK;
+ else {
+ if (sizeK > BK)
+ sizeK = (sizeK + 1) / 2;
+ }
+
+ for (Bm = 0; Bm < m; Bm += sizeM) {
+ sizeM = m - Bm;
+ if (sizeM >= BM * 2)
+ sizeM = BM;
+ else {
+ if (sizeM > BM + BM / 2)
+ sizeM = (sizeM + 1) / 2;
+ }
+
+ for (Bn = 0; Bn < n; Bn += sizeN) {
+ sizeN = n - Bn;
+ if (sizeN >= BN * 2)
+ sizeN = BN;
+ else {
+ if (sizeN > BN + BN / 2)
+ sizeN = (sizeN + 1) / 2;
+ }
+
+ if (!isTransA) {
+ curA = a + Bm + Bk * lda;
+ } else {
+ curA = a + Bk + Bm * lda;
+ }
+ if (!isTransB) {
+ curB = b + Bk + Bn * ldb;
+ } else {
+ curB = b + Bn + Bk * ldb;
+ }
+ curC = c + Bm + (size_t)Bn * ldc;
+ if (bias != nullptr) {
+ if (Bk == 0) {
+ curBias = bias + Bm;
+ } else {
+ curBias = nullptr;
+ }
+ }
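+                // The first K chunk applies the caller's beta (and bias);
+                // subsequent chunks accumulate with the beta == 1 kernel.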
+ if (Bk == 0) {
+ if (*beta == 0.0 && bias == nullptr)
+ (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ else
+ (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ } else {
+ (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ }
+ }
+ }
+ }
+}
+
+}
+
+mkldnn_status_t jit_avx512_common_gemm_f32(
+ const char *transa, const char *transb,
+ const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
+ const float *A, const int *p_lda, const float *B, const int *p_ldb,
+ const float *p_beta, float *C, const int *p_ldc, const float *bias)
+{
+ using namespace mkldnn::impl::utils;
+ using namespace avx512_common_gemm_f32;
+ using namespace gemm_utils;
+
+ if (*p_beta != 0 && bias)
+ return ref_gemm(transa, transb, p_m, p_n, p_k,
+                p_alpha, A, p_lda, B, p_ldb, p_beta, C, p_ldc, bias);
+
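+    // Run single-threaded if we are already inside a parallel region.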
+ int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+
+ int m = *p_m;
+ int n = *p_n;
+ int k = *p_k;
+ dim_t lda = *p_lda;
+ dim_t ldb = *p_ldb;
+ dim_t ldc = *p_ldc;
+ float beta = *p_beta;
+ int MB, NB, KB;
+
+ int nthr_m, nthr_n, nthr_k, nthr_mn;
+
+ // Determine threading partitioning
+ calc_nthr_nocopy_avx512_common(
+ m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
+
+    // The partitioner should not request more threads than nthr, but guard
+    // against it just in case
+ if (nthr < nthr_m * nthr_n * nthr_k)
+ nthr = nthr_m * nthr_n * nthr_k;
+
+ nthr_mn = nthr_m * nthr_n;
+
+ unsigned char * ompstatus_ = nullptr;
+ unsigned char volatile *ompstatus = nullptr;
+
+ float *c_buffers = nullptr;
+ float *ws_buffers = nullptr;
+
+ if (nthr_k > 1) {
+ ompstatus_ = (unsigned char *) malloc(
+ nthr * CACHE_LINE_SIZE,
+ CACHE_LINE_SIZE);
+ ompstatus = (unsigned char volatile *) ompstatus_;
+ assert(ompstatus);
+
+ for (int i = 0; i < nthr; i++)
+ ompstatus[i * CACHE_LINE_SIZE] = 0;
+
+ c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
+ * sizeof(float), PAGE_4K);
+ }
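+    // c_buffers holds one MB x NB partial-C tile per (m, n) block for each
+    // extra K-thread; ompstatus keeps one flag per thread (spaced a cache
+    // line apart) to signal when a partial tile is ready to be summed.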
+
+ const size_t ws_elems_per_thr = (size_t)k * 48 + 64;
+ const size_t ws_size_per_thr
+ = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
+ if (k > STACK_K_CAPACITY) {
+ ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
+ }
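+    // Per-thread packing workspace for A when K does not fit the on-stack
+    // buffer: 48 * k floats plus padding, rounded up to whole 4 KiB pages.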
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_m, ithr_n, ithr_k, ithr_mn;
+ int m_from, m_to, myM;
+ int n_from, n_to, myN;
+ int k_from, k_to, myK;
+ int cbase, ibase;
+ const float *myA, *myB, *myBias = nullptr;
+ float *myC = C, myBeta;
+ float *ws = ws_buffers ?
+ ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
+ dim_t ld = ldc;
+
+ int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
+
+ if (ithr < nthr_m * nthr_n * nthr_k) {
+
+ ithr_mn = ithr % nthr_mn;
+ ithr_m = ithr_mn % nthr_m;
+ ithr_n = ithr_mn / nthr_m;
+ ithr_k = ithr / nthr_mn;
+
+ /* swap ithr_k for performance improvement */
+ if (ithr_k == 0)
+ ithr_k = nthr_k - 1;
+ else if (ithr_k == nthr_k - 1)
+ ithr_k = 0;
+
+ m_from = MB * (ithr_m);
+ m_to = MB * (ithr_m + 1);
+ if (m_to > m)
+ m_to = m;
+ myM = m_to - m_from;
+
+ n_from = NB * (ithr_n);
+ n_to = NB * (ithr_n + 1);
+ if (n_to > n)
+ n_to = n;
+ myN = n_to - n_from;
+
+ k_from = KB * (ithr_k);
+ k_to = KB * (ithr_k + 1);
+ if (k_to > k)
+ k_to = k;
+ myK = k_to - k_from;
+
+ cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+ ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
+
+ if ((myM > 0) && (myN > 0)) {
+
+ if (*transa == 'N' || *transa == 'n') {
+ myA = &(A[m_from + k_from * lda]);
+ } else {
+ myA = &(A[k_from + m_from * lda]);
+ }
+ if (*transb == 'N' || *transb == 'n') {
+ myB = &(B[k_from + n_from * ldb]);
+ } else {
+ myB = &(B[n_from + k_from * ldb]);
+ }
+ if (ithr_k == 0) {
+ myC = &(C[m_from + n_from * ldc]);
+ myBeta = beta;
+ ld = ldc;
+ if (bias)
+ myBias = &(bias[m_from]);
+ } else {
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
+ myBeta = 0.0;
+ ld = MB;
+ myBias = nullptr;
+ }
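+                // Threads with ithr_k != 0 accumulate into a private MB x NB
+                // buffer with beta = 0; the ithr_k == 0 thread writes
+                // directly to C with the caller's beta and bias.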
+
+ sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
+ lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
+
+ if (nthr_k > 1 && !sum_later)
+ ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
+ }
+
+ if (nthr_k > 1 && !sum_later) {
+
+ // sum matrices partitioned along K dimension
+ int n1, n2;
+
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+ if (ithr_k > 0) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
+ /* need to wait until main thread finishes */
+ while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
+ };
+
+ /* my cache is hot */
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ if (ik != ithr_k) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
+
+ while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
+ };
+
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+ }
+ }
+ }
+ });
+
+
+    // Second pass: if the first parallel region ran with fewer threads than
+    // the partition assumed, the K-dimension reduction of C was deferred to
+    // here
+ if (nthr_k > 1 && ompstatus[0] == 0) {
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_m, ithr_n, ithr_k, ithr_mn;
+ int m_from, m_to, myM;
+ int n_from, n_to, myN;
+ int cbase;
+ float *myC = C;
+
+ if (ithr < nthr_m * nthr_n * nthr_k) {
+
+ ithr_mn = ithr % nthr_mn;
+ ithr_m = ithr_mn % nthr_m;
+ ithr_n = ithr_mn / nthr_m;
+ ithr_k = ithr / nthr_mn;
+
+ /* swap ithr_k for performance improvement */
+ if (ithr_k == 0)
+ ithr_k = nthr_k - 1;
+ else if (ithr_k == nthr_k - 1)
+ ithr_k = 0;
+
+ m_from = MB * (ithr_m);
+ m_to = MB * (ithr_m + 1);
+ if (m_to > m)
+ m_to = m;
+ myM = m_to - m_from;
+
+ n_from = NB * (ithr_n);
+ n_to = NB * (ithr_n + 1);
+ if (n_to > n)
+ n_to = n;
+ myN = n_to - n_from;
+
+ cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ if (nthr_k > 1) {
+ // sum matrices partitioned along K dimension
+ int n1, n2;
+
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+ if (ithr_k > 0) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
+
+ /* my cache is hot */
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ if (ik != ithr_k) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
+
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+ }
+ }
+ }
+ });
+ }
+
+ free(c_buffers);
+ free(ompstatus_);
+ free(ws_buffers);
+
+ return mkldnn_success;
+}
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp
new file mode 100644
index 0000000000..d581b7fd71
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx512_common_gemm_f32.hpp
@@ -0,0 +1,36 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX512_COMMON_GEMM_F32_HPP
+#define JIT_AVX512_COMMON_GEMM_F32_HPP
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t jit_avx512_common_gemm_f32(
+ const char *transa, const char *transb, const int *M,
+ const int *N, const int *K, const float *alpha, const float *A,
+ const int *lda, const float *B, const int *ldb, const float *beta,
+ float *C, const int *ldc, const float *bias = nullptr);
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp
new file mode 100644
index 0000000000..60d4220837
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.cpp
@@ -0,0 +1,2705 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <cmath>
+#include <mutex>
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "ref_gemm_f32.hpp"
+#include "gemm_utils_f32.hpp"
+#include "jit_avx_gemm_f32.hpp"
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+#define CACHE_LINE_SIZE 64
+
+#define STACKSIZE get_size_of_abi_save_regs()
+#if _WIN32
+#define STACK_K_CAPACITY 128
+#else
+#define STACK_K_CAPACITY 8192
+#endif
+#define SIZE 4
+#define OFFSET 32
+#define BASE_SHIFT 2
+#define SECOND_FETCH 14
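+
+// SIZE is sizeof(float) and BASE_SHIFT its log2; OFFSET is an element bias
+// added to the A/B pointers up front and subtracted back in every
+// displacement below. STACK_K_CAPACITY caps the K that can be packed on the
+// stack (smaller on Windows, where thread stacks are typically smaller), and
+// SECOND_FETCH is the number of trailing k/8 iterations run after the C tile
+// has been prefetched.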
+
+namespace avx_gemm_f32 {
+using namespace gemm_utils;
+
+struct xbyak_gemm : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx_gemm_f32_xbyak_gemm)
+
+ xbyak_gemm(char isTransA, char isTransB, float beta, bool hasBias = false,
+ void *code_ptr = nullptr,
+ size_t code_size = 80 * Xbyak::DEFAULT_MAX_CODE_SIZE)
+ : jit_generator(code_ptr, code_size)
+ {
+ using namespace Xbyak;
+
+ const bool is_avx2 = mayiuse(avx2);
+ assert(IMPLICATION(!is_avx2, mayiuse(avx)));
+
+ const int UNROLL_M = is_avx2 ? 16 : 8;
+ const int UNROLL_N = 6;
+
+ bool isBeta0 = (beta == 0.0);
+ bool isBetaN = (!isBeta0 && beta != 1.0);
+
+ // various definitions for convenience
+ auto ARG_M = abi_param1;
+ auto ARG_N = abi_param2;
+ auto K = abi_param3;
+ auto ARG_ALPHA = abi_param4;
+#ifdef _WIN32
+ auto ARG_A = ptr[rsp + OFFSET_SHADOWSPACE + STACKSIZE];
+ auto ARG_LDA = qword[rsp + OFFSET_SHADOWSPACE +
+ sizeof(float *) + STACKSIZE];
+ const auto stackOffset = OFFSET_SHADOWSPACE +
+ sizeof(float *) + STACKSIZE;
+ auto A = rsi;
+ auto LDA = rdi;
+#else
+ auto ARG_A = r8;
+ auto ARG_LDA = r9;
+ const auto stackOffset = STACKSIZE;
+ auto A = ARG_A;
+ auto LDA = ARG_LDA;
+#endif
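+        // On Windows the 5th and 6th arguments (A, LDA) are passed on the
+        // stack above the 32-byte shadow space, so A and LDA get dedicated
+        // registers (rsi/rdi); on System V they arrive directly in r8/r9.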
+ auto ARG_B = ptr[rsp + 8 + stackOffset];
+ auto ARG_LDB = ptr[rsp + 16 + stackOffset];
+ auto ARG_BETA = ptr[rsp + 24 + stackOffset];
+ auto ARG_C = ptr[rsp + 32 + stackOffset];
+ auto ARG_LDC = ptr[rsp + 40 + stackOffset];
+ auto ARG_BIAS = ptr[rsp + 48 + stackOffset];
+ auto ARG_WS = ptr[rsp + 56 + stackOffset];
+
+ auto B = r11;
+ auto LDB = rbx;
+ auto LDC = r13;
+ auto LL = rax;
+ auto AO1 = abi_param2;
+ auto BO1 = abi_param4;
+ auto BO2 = rbp;
+ auto CO1 = r14;
+ auto CO2 = r15;
+ auto LDB3 = r10;
+ auto LDA4 = abi_param1;
+ auto AA = r12;
+ auto BIAS1 = abi_param1;
+
+ auto M = qword[rsp + 0];
+ auto N = qword[rsp + 8];
+ auto FLAG = qword[rsp + 16];
+ auto I = qword[rsp + 24];
+ auto C = qword[rsp + 32];
+ auto BIAS = qword[rsp + 40];
+ auto ALPHA = qword[rsp + 48];
+ auto BETA = qword[rsp + 64];
+ auto ORIG_A = qword[rsp + 80];
+ auto MASK = dword[rsp + 88];
+ auto STRIDE = qword[rsp + 120];
+ auto ORIG_SP = qword[rsp + 152];
+
+ auto VALPHA = ymm1;
+ auto VBETA = ymm2;
+ auto VMASK = ymm3;
+ auto VBIAS1 = ymm2;
+ auto VBIAS2 = ymm4;
+
+ auto PREFETCHSIZEA = 128;
+ auto PREFETCHSIZEB = (!isTransB) ? -16 : 0;
+
+        // Packing helper: copies (and, for transposed A, transposes) a panel
+        // of A into the on-stack buffer
+ auto do_pack = [&](
+ int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
+ Label pack2, pack3, pack4, pack10;
+
+ int regIdx;
+ Reg64 reg;
+
+ mov(BO1, A);
+ lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
+
+ if (isTransA) {
+ lea(BO2, ptr[BO1 + LDA * 4]);
+ lea(CO1, ptr[LDA + LDA * 2]);
+ vmovupd(ymm7, STRIDE);
+ }
+
+ mov(LL, K);
+ sar(LL, 2);
+ jle(pack3, T_NEAR);
+ align(16);
+
+ L(pack2);
+ if (!isTransA) {
+ for (int i = 0; i < 4; i++) {
+ regIdx = (i % 2 == 0) ? 4 : 6;
+ if (isLoad1Unmasked) {
+ vmovups(Ymm(regIdx),
+ ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(Ymm(regIdx), VMASK,
+ ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m > 8) {
+ if (isLoad2Unmasked) {
+ vmovups(Ymm(regIdx + 1),
+ ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(Ymm(regIdx + 1), VMASK,
+ ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(BO1, LDA);
+
+ vmovups(ptr[AO1 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
+ Ymm(regIdx));
+ if (unroll_m > 8) {
+ vmovups(ptr[AO1
+ + (unroll_m * i + 1 * 8 - OFFSET)
+ * SIZE],
+ Ymm(regIdx + 1));
+ }
+ }
+
+ } else {
+ if (isLoad1Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ reg = (i % 2 == 0) ? BO1 : BO2;
+ vmovups(xmm0, ptr[reg + (0 * 8 - OFFSET) * SIZE]);
+ vmovups(xmm1,
+ ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[reg + LDA * 2]);
+ vunpcklps(xmm4, xmm0, xmm1);
+ vunpckhps(xmm5, xmm0, xmm1);
+ vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovups(xmm1,
+ ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(xmm6, xmm0, xmm1);
+ vunpckhps(xmm2, xmm0, xmm1);
+
+ vunpcklpd(xmm0, xmm4, xmm6);
+ vunpckhpd(xmm1, xmm4, xmm6);
+ vmovups(ptr[AO1
+ + (unroll_m * 0 + i * 4 - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * 1 + i * 4 - OFFSET)
+ * SIZE],
+ xmm1);
+ vunpcklpd(xmm0, xmm5, xmm2);
+ vunpckhpd(xmm1, xmm5, xmm2);
+ vmovups(ptr[AO1
+ + (unroll_m * 2 + i * 4 - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * 3 + i * 4 - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+ } else if (is_avx2) {
+ for (int i = 0; i < 2; i++) {
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm0,
+ ptr[BO1 + ymm7 + ((2 * i) - OFFSET) * SIZE],
+ xmm4);
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm1,
+ ptr[BO1 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
+ xmm4);
+
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i) + 0 * 4 - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i + 1) + 0 * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+
+ lea(BO2, ptr[BO1 + LDA * 4]);
+
+ for (int i = 0; i < 2; i++) {
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm0,
+ ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
+ xmm4);
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7 + ((2 * i + 1) - OFFSET) * SIZE],
+ xmm4);
+
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i) + 1 * 4 - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i + 1) + 1 * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ } else {
+ vxorps(xmm4, xmm4, xmm4);
+ lea(BO2, ptr[BO1 + LDA * 4]);
+
+ auto el_cp = [&](int section, int ld_step) {
+ RegExp src_addr = section == 0 ? BO1 : BO2;
+ if (ld_step == 1 || ld_step == 2)
+ src_addr = src_addr + LDA * ld_step;
+ else if (ld_step == 3)
+ src_addr = src_addr + CO1;
+ src_addr = src_addr - OFFSET * SIZE;
+
+ vmovups(Xmm(ld_step % 2), ptr[src_addr]);
+ RegExp dst_addr = AO1
+ + (ld_step + section * 4 - OFFSET) * SIZE;
+ for (int off = 0; off < 4; ++off)
+ pextrd(ptr[dst_addr + unroll_m * off * SIZE],
+ Xmm(ld_step % 2), off);
+ };
+
+ Label l_end;
+ el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
+ el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
+ el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
+ el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
+ el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
+ el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
+ el_cp(1, 2);
+ L(l_end);
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ }
+
+ if (unroll_m >= 16) {
+ assert(is_avx2);
+ if (isLoad2Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovups(xmm1, ptr[BO2 + LDA * 1
+ + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(xmm4, xmm0, xmm1);
+ vunpckhps(xmm5, xmm0, xmm1);
+ vmovups(xmm0, ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovups(xmm1, ptr[BO2 + LDA * 1
+ + (0 * 8 - OFFSET) * SIZE]);
+ if (i == 0)
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(xmm6, xmm0, xmm1);
+ vunpckhps(xmm2, xmm0, xmm1);
+
+ vunpcklpd(xmm0, xmm4, xmm6);
+ vunpckhpd(xmm1, xmm4, xmm6);
+ vmovups(ptr[AO1
+ + (unroll_m * 0 + (i + 2) * 4
+ - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * 1 + (i + 2) * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ vunpcklpd(xmm0, xmm5, xmm2);
+ vunpckhpd(xmm1, xmm5, xmm2);
+ vmovups(ptr[AO1
+ + (unroll_m * 2 + (i + 2) * 4
+ - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * 3 + (i + 2) * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+ } else {
+ for (int i = 0; i < 2; i++) {
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm0,
+ ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
+ xmm4);
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7
+ + ((2 * i + 1) - OFFSET) * SIZE],
+ xmm4);
+
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i) + 2 * 4
+ - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i + 1) + 2 * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+
+ for (int i = 0; i < 2; i++) {
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm0,
+ ptr[BO2 + ymm7 + ((2 * i) - OFFSET) * SIZE],
+ xmm4);
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7
+ + ((2 * i + 1) - OFFSET) * SIZE],
+ xmm4);
+
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i) + 3 * 4
+ - OFFSET)
+ * SIZE],
+ xmm0);
+ vmovups(ptr[AO1
+ + (unroll_m * (2 * i + 1) + 3 * 4
+ - OFFSET)
+ * SIZE],
+ xmm1);
+ }
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ }
+ }
+ add(BO1, (4 * SIZE));
+ }
+
+ add(AO1, unroll_m * 4 * SIZE);
+ sub(LL, 1);
+ jg(pack2, T_NEAR);
+ align(16);
+
+ L(pack3);
+ mov(LL, K);
+ and_(LL, 3);
+ jle(pack10, T_NEAR);
+ align(16);
+
+ L(pack4);
+ if (!isTransA) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm4, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm4, VMASK, ptr[BO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m > 8) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm5, ptr[BO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm5, VMASK,
+ ptr[BO1 + (1 + 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(BO1, LDA);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
+ ymm4);
+ if (unroll_m > 8) {
+ vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
+ ymm5);
+ }
+ } else {
+ if (isLoad1Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ reg = (i % 2 == 0) ? BO1 : BO2;
+ vmovss(Xmm(i + 1), ptr[reg + (0 * 8 - OFFSET) * SIZE]);
+ vmovss(xmm0,
+ ptr[reg + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[reg + LDA * 2]);
+ vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
+ }
+ vunpcklpd(xmm1, xmm1, xmm2);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
+ xmm1);
+
+ for (int i = 0; i < 2; i++) {
+ vmovss(Xmm(i + 1), ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovss(xmm0,
+ ptr[BO2 + LDA * 1 + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
+ }
+ vunpcklpd(xmm1, xmm1, xmm2);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
+ xmm1);
+ } else if (is_avx2) {
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm1, ptr[BO1 + ymm7 + (0 * 8 - OFFSET) * SIZE],
+ xmm4);
+ lea(BO2, ptr[BO1 + LDA * 4]);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 0 * 4 - OFFSET) * SIZE],
+ xmm1);
+
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm1, ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
+ xmm4);
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ vmovups(ptr[AO1 + (unroll_m * 0 + 1 * 4 - OFFSET) * SIZE],
+ xmm1);
+ } else {
+ vxorps(xmm4, xmm4, xmm4);
+ lea(BO2, ptr[BO1 + LDA * 4]);
+
+ auto el_cp = [&](int section, int ld_step) {
+ RegExp src_addr = section == 0 ? BO1 : BO2;
+ if (ld_step == 1 || ld_step == 2)
+ src_addr = src_addr + LDA * ld_step;
+ else if (ld_step == 3)
+ src_addr = src_addr + CO1;
+ src_addr = src_addr - OFFSET * SIZE;
+
+ vmovss(xmm1, ptr[src_addr]);
+ RegExp dst_addr = AO1
+ + (ld_step + section * 4 - OFFSET) * SIZE;
+ movss(ptr[dst_addr], xmm1);
+ };
+
+ Label l_end;
+ el_cp(0, 0); cmp(M, 4 * 0 + 0 + 1); je(l_end, T_NEAR);
+ el_cp(0, 1); cmp(M, 4 * 0 + 1 + 1); je(l_end, T_NEAR);
+ el_cp(0, 2); cmp(M, 4 * 0 + 2 + 1); je(l_end, T_NEAR);
+ el_cp(0, 3); cmp(M, 4 * 0 + 3 + 1); je(l_end, T_NEAR);
+ el_cp(1, 0); cmp(M, 4 * 1 + 0 + 1); je(l_end, T_NEAR);
+ el_cp(1, 1); cmp(M, 4 * 1 + 1 + 1); je(l_end, T_NEAR);
+ el_cp(1, 2);
+ L(l_end);
+
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ }
+
+ if (unroll_m >= 16) {
+ assert(is_avx2);
+ if (isLoad2Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ vmovss(Xmm(i + 1),
+ ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovss(xmm0, ptr[BO2 + LDA * 1
+ + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
+ }
+ vunpcklpd(xmm1, xmm1, xmm2);
+ } else {
+ vmovaps(xmm4, xmm3);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
+ xmm4);
+ lea(BO2, ptr[BO2 + LDA * 4]);
+ }
+ vmovups(ptr[AO1 + (unroll_m * 0 + 2 * 4 - OFFSET) * SIZE],
+ xmm1);
+
+ if (isLoad2Unmasked) {
+ for (int i = 0; i < 2; i++) {
+ vmovss(Xmm(i + 1),
+ ptr[BO2 + (0 * 8 - OFFSET) * SIZE]);
+ vmovss(xmm0, ptr[BO2 + LDA * 1
+ + (0 * 8 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDA * 2]);
+ vunpcklps(Xmm(i + 1), Xmm(i + 1), Xmm(0));
+ }
+ vunpcklpd(xmm1, xmm1, xmm2);
+ } else {
+ vextractf128(xmm4, ymm3, 1);
+ vgatherqps(xmm1,
+ ptr[BO2 + ymm7 + (0 * 8 - OFFSET) * SIZE],
+ xmm4);
+ }
+ vmovups(ptr[AO1 + (unroll_m * 0 + 3 * 4 - OFFSET) * SIZE],
+ xmm1);
+ }
+ add(BO1, SIZE);
+ }
+
+ add(AO1, unroll_m * SIZE);
+ sub(LL, 1);
+ jg(pack4, T_NEAR);
+ align(16);
+
+ L(pack10);
+ };
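+        // do_pack writes the A panel into the on-stack buffer as consecutive
+        // columns of unroll_m contiguous floats, so the compute loops can
+        // stream it with unit stride.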
+
+        // Fused multiply-add helper: reg2 += reg1 * reg0. Emits a single
+        // vfmadd231ps on AVX2, otherwise a vmulps/vaddps pair through a
+        // scratch register
+ auto fma = [&](bool useFma, Ymm reg0, Ymm reg1, Ymm reg2,
+ bool overWrite = false) {
+ if (useFma) {
+ if (is_avx2) {
+ vfmadd231ps(reg2, reg1, reg0);
+ } else {
+ assert(UNROLL_M == 8);
+ auto tent_vreg = overWrite ? reg1 : ymm1;
+ vmulps(tent_vreg, reg1, reg0);
+ vaddps(reg2, reg2, tent_vreg);
+ }
+ } else {
+ if (!overWrite) {
+ vmulps(ymm15, reg1, reg0);
+ vaddps(reg2, reg2, ymm15);
+ } else {
+ vmulps(reg1, reg1, reg0);
+ vaddps(reg2, reg2, reg1);
+ }
+ }
+ };
+
+ // Inner kernel with k=8
+ auto innerkernel8 = [&](int unroll_m, int unroll_n,
+ bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
+ bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
+ Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
+ Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
+ Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
+ Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
+ Ymm reg23) {
+
+ Ymm fmareg;
+
+ if (!isDirect) {
+ prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+
+ for (int i = 0; i < 8; i++) {
+ if (isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg00 : reg12;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg06 : reg18;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ if (i == 0) {
+ if (!isTransB) {
+ prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
+ }
+ }
+ if (unroll_n >= 2) {
+ if (!isTransB) {
+ if (i == 1) {
+ prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg01 : reg13;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg07 : reg19;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (isCopy) {
+ vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
+ ymm0);
+ if (unroll_m >= 16) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 1 * 8 - OFFSET)
+ * SIZE],
+ ymm1);
+ }
+ if (i == 7) {
+ sub(LDA4, -unroll_m * 8 * SIZE);
+ }
+ }
+
+ if (unroll_n >= 3) {
+ if (!isTransB) {
+ if (i == 2) {
+ prefetcht0(
+ ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg02 : reg14;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg08 : reg20;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (i == 7) {
+ if (!isTransB) {
+ sub(BO1, -8 * SIZE);
+ }
+ }
+
+ if (unroll_n >= 4) {
+ if (!isTransB) {
+ if (i == 3) {
+ prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg03 : reg15;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg09 : reg21;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 5) {
+ if (!isTransB) {
+ if (i == 4) {
+ prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg04 : reg16;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg10 : reg22;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 6) {
+ if (!isTransB) {
+ if (i == 5) {
+ prefetcht0(
+ ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg05 : reg17;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg11 : reg23;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+ if (isTransB) {
+ prefetcht0(ptr[BO1 + BO2]);
+ add(BO1, LDB);
+ }
+
+ if (i == 0) {
+ if (unroll_m >= 4) {
+ if (!isDirect) {
+ prefetcht0(
+ ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 1 || i == 2) {
+ if (unroll_m >= 8) {
+ if (!isDirect) {
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + (2 + 2 * i) * 8)
+ * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 3 || i == 4 || i == 5 || i == 6) {
+ if (unroll_m >= 16) {
+ if (!isDirect) {
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + (2 + 2 * i) * 8)
+ * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 7) {
+ if (!isTransB) {
+ if (unroll_n >= 4) {
+ sub(BO2, -8 * SIZE);
+ }
+ }
+ if (!isTransA) {
+ prefetcht2(ptr[AA]);
+ lea(AA, ptr[AA + LDA]);
+ }
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(
+ ymm0, VMASK,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
+ * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1
+ + (unroll_m * (i + 1) + 1 * 8
+ - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 1 * 8
+ - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+ }
+
+ if (!isDirect) {
+ sub(AO1, -unroll_m * 8 * SIZE);
+ }
+ sub(LL, 1);
+
+ };
+
+ // Inner kernel with k=4
+ auto innerkernel4 = [&](int unroll_m, int unroll_n,
+ bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
+ bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
+ Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
+ Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
+ Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
+ Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
+ Ymm reg23) {
+
+ Ymm fmareg;
+
+ if (!isDirect) {
+ prefetcht0(ptr[AO1 + (PREFETCHSIZEA + 0) * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+
+ for (int i = 0; i < 4; i++) {
+ if (isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg00 : reg12;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg06 : reg18;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ if (i == 0) {
+ if (!isTransB) {
+ prefetcht0(ptr[BO1 + PREFETCHSIZEB * SIZE]);
+ }
+ }
+ if (unroll_n >= 2) {
+ if (!isTransB) {
+ if (i == 1) {
+ prefetcht0(ptr[BO1 + LDB + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg01 : reg13;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg07 : reg19;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (isCopy) {
+ vmovups(ptr[LDA4 + (unroll_m * i + 0 * 8 - OFFSET) * SIZE],
+ ymm0);
+ if (unroll_m >= 16) {
+ vmovups(ptr[LDA4
+ + (unroll_m * i + 1 * 8 - OFFSET)
+ * SIZE],
+ ymm1);
+ }
+ if (i == 3) {
+ sub(LDA4, -unroll_m * 4 * SIZE);
+ }
+ }
+
+ if (unroll_n >= 3) {
+ if (!isTransB) {
+ if (i == 2) {
+ prefetcht0(
+ ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg02 : reg14;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg08 : reg20;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (i == 7) {
+ if (!isTransB) {
+ sub(BO1, -8 * SIZE);
+ }
+ }
+
+ if (unroll_n >= 4) {
+ if (!isTransB) {
+ if (i == 3) {
+ prefetcht0(ptr[BO2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(ymm2, ptr[BO2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg03 : reg15;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg09 : reg21;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 5) {
+ if (!isTransB) {
+ if (i == 4) {
+ prefetcht0(ptr[BO2 + LDB + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 1 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg04 : reg16;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg10 : reg22;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 6) {
+ if (!isTransB) {
+ if (i == 5) {
+ prefetcht0(
+ ptr[BO2 + LDB * 2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 2 + (i - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg05 : reg17;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg11 : reg23;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+ if (isTransB) {
+ prefetcht0(ptr[BO1 + BO2]);
+ add(BO1, LDB);
+ }
+
+ if (i == 0) {
+ if (unroll_m >= 4) {
+ if (!isDirect) {
+ prefetcht0(
+ ptr[AO1 + (PREFETCHSIZEA + 2 * 8) * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 1 || i == 2) {
+ if (unroll_m >= 8) {
+ if (!isDirect) {
+ prefetcht0(ptr[AO1
+ + (PREFETCHSIZEA + (2 + 2 * i) * 8)
+ * SIZE]);
+ } else {
+ prefetcht0(ptr[AO1 + LDA4]);
+ }
+ }
+ }
+ if (i == 3) {
+ if (!isTransB) {
+ sub(BO1, -4 * SIZE);
+ if (unroll_n >= 4) {
+ sub(BO2, -4 * SIZE);
+ }
+ }
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(
+ ymm0, VMASK,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 0 * 8 - OFFSET)
+ * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1
+ + (unroll_m * (i + 1) + 1 * 8
+ - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * (i + 1) + 1 * 8
+ - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+ }
+
+ if (!isDirect) {
+ sub(AO1, -unroll_m * 4 * SIZE);
+ }
+
+ };
+
+ // Inner kernel with k=2
+ auto innerkernel2 = [&](int unroll_m, int unroll_n,
+ bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
+ bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
+ Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
+ Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11, Ymm reg12,
+ Ymm reg13, Ymm reg14, Ymm reg15, Ymm reg16, Ymm reg17,
+ Ymm reg18, Ymm reg19, Ymm reg20, Ymm reg21, Ymm reg22,
+ Ymm reg23) {
+
+ Ymm fmareg;
+
+ for (int i = 0; i < 2; i++) {
+ if (isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg00 : reg12;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg06 : reg18;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ if (unroll_n >= 2) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg01 : reg13;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg07 : reg19;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 3) {
+ if (!isTransB) {
+ if (i == 2) {
+ prefetcht0(
+ ptr[BO1 + LDB * 2 + PREFETCHSIZEB * SIZE]);
+ }
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg02 : reg14;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg08 : reg20;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 4) {
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg03 : reg15;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg09 : reg21;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 5) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg04 : reg16;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg10 : reg22;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (unroll_n >= 6) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
+ }
+ fmareg = (i % 2 == 0) ? reg05 : reg17;
+ fma(useFma, ymm0, ymm2, fmareg);
+ if (unroll_m >= 16) {
+ fmareg = (i % 2 == 0) ? reg11 : reg23;
+ fma(useFma, ymm1, ymm2, fmareg);
+ }
+ }
+
+ if (isCopy) {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
+ ymm0);
+ if (unroll_m >= 16) {
+ vmovups(ptr[LDA4
+ + (unroll_m * 0 + 1 * 8 - OFFSET)
+ * SIZE],
+ ymm1);
+ }
+ sub(LDA4, -unroll_m * SIZE);
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1
+ + (unroll_m * 1 + 0 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1
+ + (unroll_m * 1 + 0 * 8 - OFFSET)
+ * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1,
+ ptr[AO1
+ + (unroll_m * 1 + 1 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * 1 + 1 * 8 - OFFSET)
+ * SIZE]);
+ }
+ }
+ sub(AO1, -unroll_m * SIZE);
+ }
+
+ if (!isTransB) {
+ sub(BO1, -SIZE);
+ if (unroll_n >= 4) {
+ sub(BO2, -SIZE);
+ }
+ } else {
+ add(BO1, LDB);
+ }
+ }
+
+ };
+
+ // Inner kernel with k=1
+ auto innerkernel1 = [&](int unroll_m, int unroll_n,
+ bool isLoad1Unmasked, bool isLoad2Unmasked, bool isDirect,
+ bool isCopy, bool useFma, Ymm reg00, Ymm reg01, Ymm reg02,
+ Ymm reg03, Ymm reg04, Ymm reg05, Ymm reg06, Ymm reg07,
+ Ymm reg08, Ymm reg09, Ymm reg10, Ymm reg11) {
+
+ if (isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK, ptr[AO1 + (0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1 + (1 * 8 - OFFSET) * SIZE]);
+ }
+ }
+ add(AO1, LDA);
+ }
+
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (0 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg00);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg06);
+ }
+
+ if (unroll_n >= 2) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (1 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg01);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg07);
+ }
+ }
+
+ if (unroll_n >= 3) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO1 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (2 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg02);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg08);
+ }
+ }
+
+ if (unroll_n >= 4) {
+ if (!isTransB) {
+ vbroadcastss(ymm2, ptr[BO2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (3 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg03);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg09);
+ }
+ }
+
+ if (unroll_n >= 5) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 1 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (4 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg04);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg10);
+ }
+ }
+
+ if (unroll_n >= 6) {
+ if (!isTransB) {
+ vbroadcastss(
+ ymm2, ptr[BO2 + LDB * 2 + (0 - OFFSET) * SIZE]);
+ } else {
+ vbroadcastss(ymm2, ptr[BO1 + (5 - OFFSET) * SIZE]);
+ }
+ fma(useFma, ymm0, ymm2, reg05);
+ if (unroll_m >= 16) {
+ fma(useFma, ymm1, ymm2, reg11);
+ }
+ }
+
+ if (isCopy) {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE],
+ ymm0);
+ if (unroll_m >= 16) {
+ vmovups(ptr[LDA4 + (unroll_m * 0 + 1 * 8 - OFFSET) * SIZE],
+ ymm1);
+ }
+ sub(LDA4, -unroll_m * SIZE);
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0,
+ ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (unroll_m * 1 + 0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1
+ + (unroll_m * 1 + 1 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * 1 + 1 * 8 - OFFSET)
+ * SIZE]);
+ }
+ }
+ sub(AO1, -unroll_m * SIZE);
+ }
+
+ if (!isTransB) {
+ sub(BO1, -SIZE);
+ if (unroll_n >= 4) {
+ sub(BO2, -SIZE);
+ }
+ } else {
+ add(BO1, LDB);
+ }
+
+ };
+
+        // Main kernel: prefetches, dispatches to innerkernel{8,4,2,1} based
+        // on the remaining k count, then applies alpha/beta (and bias) and
+        // writes the accumulators back to the C matrix
+ auto kernel = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy, bool useFma,
+ Ymm reg00 = Ymm(4), Ymm reg01 = Ymm(5), Ymm reg02 = Ymm(6),
+ Ymm reg03 = Ymm(7), Ymm reg04 = Ymm(8), Ymm reg05 = Ymm(9),
+ Ymm reg06 = Ymm(10), Ymm reg07 = Ymm(11), Ymm reg08 = Ymm(12),
+ Ymm reg09 = Ymm(13), Ymm reg10 = Ymm(14), Ymm reg11 = Ymm(15),
+ Ymm reg12 = Ymm(4), Ymm reg13 = Ymm(5), Ymm reg14 = Ymm(6),
+ Ymm reg15 = Ymm(7), Ymm reg16 = Ymm(8), Ymm reg17 = Ymm(9),
+ Ymm reg18 = Ymm(10), Ymm reg19 = Ymm(11), Ymm reg20 = Ymm(12),
+ Ymm reg21 = Ymm(13), Ymm reg22 = Ymm(14), Ymm reg23 = Ymm(15)) {
+ if (!isDirect) {
+ lea(AO1, ptr[rsp + 256 + OFFSET * SIZE]);
+ } else {
+ mov(AO1, A);
+ }
+
+ if (isCopy) {
+ lea(LDA4, ptr[rsp + 256 + OFFSET * SIZE]);
+ } else {
+ lea(LDA4, ptr[LDA * 8 + (8 - 1 - OFFSET) * SIZE]);
+ }
+
+ if (isTransB) {
+ lea(BO2, ptr[LDB * 4 + (8 - 1 - OFFSET) * SIZE]);
+ lea(BO2, ptr[BO2 + LDB * 2]);
+ }
+
+ if (!isDirect) {
+ if (isLoad1Unmasked) {
+ vmovups(ymm0,
+ ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
+ } else {
+ vmaskmovps(ymm0, VMASK,
+ ptr[AO1 + (unroll_m * 0 + 0 * 8 - OFFSET) * SIZE]);
+ }
+ if (unroll_m >= 16) {
+ if (isLoad2Unmasked) {
+ vmovups(ymm1, ptr[AO1
+ + (unroll_m * 0 + 1 * 8 - OFFSET)
+ * SIZE]);
+ } else {
+ vmaskmovps(ymm1, VMASK,
+ ptr[AO1
+ + (unroll_m * 0 + 1 * 8 - OFFSET)
+ * SIZE]);
+ }
+ }
+ }
+
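+            // Clear the accumulator registers ymm4..ymm15.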
+ for (int i = 4; i < 10; i++) {
+ vxorps(Ymm(i), Ymm(i), Ymm(i));
+ vxorps(Ymm(i + 6), Ymm(i + 6), Ymm(i + 6));
+ }
+
+ mov(LL, K);
+ sar(LL, 3);
+
+ Label kernel12, kernel13, kernel14, kernel15;
+ Label kernel16, kernel17, kernel18;
+
+ sub(LL, SECOND_FETCH);
+ jle(kernel13, T_NEAR);
+ align(16);
+
+ L(kernel12);
+ innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
+ reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
+ reg21, reg22, reg23);
+ jg(kernel12, T_NEAR);
+ align(16);
+
+ L(kernel13);
+ prefetcht0(ptr[CO1 + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 2)
+ prefetcht0(ptr[CO1 + LDC + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 3)
+ prefetcht0(ptr[CO1 + LDC * 2 + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 4)
+ prefetcht0(ptr[CO2 + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 5)
+ prefetcht0(ptr[CO2 + LDC + (unroll_m - 1) * SIZE]);
+ if (unroll_n >= 6)
+ prefetcht0(ptr[CO2 + LDC * 2 + (unroll_m - 1) * SIZE]);
+
+ add(LL, SECOND_FETCH);
+ jle(kernel15, T_NEAR);
+ align(16);
+
+ L(kernel14);
+ innerkernel8(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
+ reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
+ reg21, reg22, reg23);
+ jg(kernel14, T_NEAR);
+ align(16);
+
+ L(kernel15);
+ test(K, 4);
+ jle(kernel16, T_NEAR);
+ innerkernel4(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
+ reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
+ reg21, reg22, reg23);
+
+ L(kernel16);
+ test(K, 2);
+ jle(kernel17, T_NEAR);
+ innerkernel2(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11, reg12,
+ reg13, reg14, reg15, reg16, reg17, reg18, reg19, reg20,
+ reg21, reg22, reg23);
+ align(16);
+
+ L(kernel17);
+ if (unroll_m == 16) {
+ if (unroll_n <= 3) {
+ vaddps(reg00, reg00, reg12);
+ vaddps(reg01, reg01, reg13);
+ vaddps(reg02, reg02, reg14);
+ vaddps(reg06, reg06, reg18);
+ vaddps(reg07, reg07, reg19);
+ vaddps(reg08, reg08, reg20);
+ }
+ }
+
+ if (unroll_m <= 8) {
+ vaddps(reg00, reg00, reg12);
+ vaddps(reg01, reg01, reg13);
+ vaddps(reg02, reg02, reg14);
+ vaddps(reg03, reg03, reg15);
+ vaddps(reg04, reg04, reg16);
+ vaddps(reg05, reg05, reg17);
+ }
+
+ test(K, 1);
+ jle(kernel18, T_NEAR);
+ innerkernel1(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, reg00, reg01, reg02, reg03, reg04,
+ reg05, reg06, reg07, reg08, reg09, reg10, reg11);
+ align(16);
+
+ L(kernel18);
+ vbroadcastss(VALPHA, ALPHA);
+
+ if (isBetaN) {
+ vbroadcastss(VBETA, BETA);
+ }
+
+ // Write back the results; all beta and bias cases need to be
+ // handled
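+ // rax <- column stride (a multiple of LDC) used below to advance CO1/CO2.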
+ switch (unroll_n) {
+ case 1: mov(rax, LDC); break;
+ case 2: lea(rax, ptr[LDC * 2]); break;
+ case 3: lea(rax, ptr[LDC + LDC * 2]); break;
+ case 4: lea(rax, ptr[LDC + LDC * 4]); break;
+ case 5:
+ lea(rax, ptr[LDC * 4]);
+ add(rax, LDC);
+ break;
+ case 6:
+ lea(rax, ptr[LDC + LDC * 2]);
+ add(rax, rax);
+ break;
+ }
+
+ if (hasBias) {
+ mov(BIAS1, BIAS);
+ if (isLoad1Unmasked) {
+ vmovups(VBIAS1, ptr[BIAS1 + 0 * SIZE]);
+ } else {
+ vmaskmovps(VBIAS1, VMASK, ptr[BIAS1 + 0 * SIZE]);
+ }
+ }
+
+ for (int i = 0; i < unroll_n; i++) {
+ vmulps(Ymm(i + 4), Ymm(i + 4), VALPHA);
+ if (!isBeta0) {
+ if (isLoad1Unmasked) {
+ switch (i) {
+ case 0: vmovups(ymm0, ptr[CO1 + 0 * SIZE]); break;
+ case 1: vmovups(ymm0, ptr[CO1 + LDC + 0 * SIZE]); break;
+ case 2:
+ vmovups(ymm0, ptr[CO1 + LDC * 2 + 0 * SIZE]);
+ break;
+ case 3: vmovups(ymm0, ptr[CO2 + 0 * SIZE]); break;
+ case 4: vmovups(ymm0, ptr[CO2 + LDC + 0 * SIZE]); break;
+ case 5:
+ vmovups(ymm0, ptr[CO2 + LDC * 2 + 0 * SIZE]);
+ break;
+ }
+ } else {
+ switch (i) {
+ case 0:
+ vmaskmovps(ymm0, VMASK, ptr[CO1 + 0 * SIZE]);
+ break;
+ case 1:
+ vmaskmovps(ymm0, VMASK, ptr[CO1 + LDC + 0 * SIZE]);
+ break;
+ case 2:
+ vmaskmovps(
+ ymm0, VMASK, ptr[CO1 + LDC * 2 + 0 * SIZE]);
+ break;
+ case 3:
+ vmaskmovps(ymm0, VMASK, ptr[CO2 + 0 * SIZE]);
+ break;
+ case 4:
+ vmaskmovps(ymm0, VMASK, ptr[CO2 + LDC + 0 * SIZE]);
+ break;
+ case 5:
+ vmaskmovps(
+ ymm0, VMASK, ptr[CO2 + LDC * 2 + 0 * SIZE]);
+ break;
+ }
+ }
+
+ if (!isBetaN) {
+ vaddps(Ymm(i + 4), ymm0, Ymm(i + 4));
+ } else {
+ fma(useFma, VBETA, ymm0, Ymm(i + 4), true);
+ }
+ }
+ if (hasBias) {
+ vaddps(Ymm(i + 4), VBIAS1, Ymm(i + 4));
+ }
+ if (isLoad1Unmasked) {
+ switch (i) {
+ case 0: vmovups(ptr[CO1 + 0 * SIZE], Ymm(i + 4)); break;
+ case 1:
+ vmovups(ptr[CO1 + LDC + 0 * SIZE], Ymm(i + 4));
+ break;
+ case 2:
+ vmovups(ptr[CO1 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
+ break;
+ case 3: vmovups(ptr[CO2 + 0 * SIZE], Ymm(i + 4)); break;
+ case 4:
+ vmovups(ptr[CO2 + LDC + 0 * SIZE], Ymm(i + 4));
+ break;
+ case 5:
+ vmovups(ptr[CO2 + LDC * 2 + 0 * SIZE], Ymm(i + 4));
+ break;
+ }
+ } else {
+ switch (i) {
+ case 0:
+ vmaskmovps(ptr[CO1 + 0 * SIZE], VMASK, Ymm(i + 4));
+ break;
+ case 1:
+ vmaskmovps(
+ ptr[CO1 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
+ break;
+ case 2:
+ vmaskmovps(ptr[CO1 + LDC * 2 + 0 * SIZE], VMASK,
+ Ymm(i + 4));
+ break;
+ case 3:
+ vmaskmovps(ptr[CO2 + 0 * SIZE], VMASK, Ymm(i + 4));
+ break;
+ case 4:
+ vmaskmovps(
+ ptr[CO2 + LDC + 0 * SIZE], VMASK, Ymm(i + 4));
+ break;
+ case 5:
+ vmaskmovps(ptr[CO2 + LDC * 2 + 0 * SIZE], VMASK,
+ Ymm(i + 4));
+ break;
+ }
+ }
+
+ if (unroll_m >= 16) {
+ // Re-use ymm4 (VBIAS2)
+ if (i == 0) {
+ if (hasBias) {
+ if (isLoad1Unmasked) {
+ vmovups(VBIAS2, ptr[BIAS1 + 8 * SIZE]);
+ } else {
+ vmaskmovps(
+ VBIAS2, VMASK, ptr[BIAS1 + 8 * SIZE]);
+ }
+ }
+ }
+ vmulps(Ymm(i + 10), Ymm(i + 10), VALPHA);
+ if (!isBeta0) {
+ if (isLoad2Unmasked) {
+ switch (i) {
+ case 0: vmovups(ymm0, ptr[CO1 + 8 * SIZE]); break;
+ case 1:
+ vmovups(ymm0, ptr[CO1 + LDC + 8 * SIZE]);
+ break;
+ case 2:
+ vmovups(ymm0, ptr[CO1 + LDC * 2 + 8 * SIZE]);
+ break;
+ case 3: vmovups(ymm0, ptr[CO2 + 8 * SIZE]); break;
+ case 4:
+ vmovups(ymm0, ptr[CO2 + LDC + 8 * SIZE]);
+ break;
+ case 5:
+ vmovups(ymm0, ptr[CO2 + LDC * 2 + 8 * SIZE]);
+ break;
+ }
+ } else {
+ switch (i) {
+ case 0:
+ vmaskmovps(ymm0, VMASK, ptr[CO1 + 8 * SIZE]);
+ break;
+ case 1:
+ vmaskmovps(
+ ymm0, VMASK, ptr[CO1 + LDC + 8 * SIZE]);
+ break;
+ case 2:
+ vmaskmovps(ymm0, VMASK,
+ ptr[CO1 + LDC * 2 + 8 * SIZE]);
+ break;
+ case 3:
+ vmaskmovps(ymm0, VMASK, ptr[CO2 + 8 * SIZE]);
+ break;
+ case 4:
+ vmaskmovps(
+ ymm0, VMASK, ptr[CO2 + LDC + 8 * SIZE]);
+ break;
+ case 5:
+ vmaskmovps(ymm0, VMASK,
+ ptr[CO2 + LDC * 2 + 8 * SIZE]);
+ break;
+ }
+ }
+ if (!isBetaN) {
+ vaddps(Ymm(i + 10), ymm0, Ymm(i + 10));
+ } else {
+ fma(useFma, VBETA, ymm0, Ymm(i + 10), true);
+ }
+ }
+ if (hasBias) {
+ vaddps(Ymm(i + 10), VBIAS2, Ymm(i + 10));
+ }
+ if (isLoad2Unmasked) {
+ switch (i) {
+ case 0:
+ vmovups(ptr[CO1 + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 1:
+ vmovups(ptr[CO1 + LDC + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 2:
+ vmovups(ptr[CO1 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 3:
+ vmovups(ptr[CO2 + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 4:
+ vmovups(ptr[CO2 + LDC + 8 * SIZE], Ymm(i + 10));
+ break;
+ case 5:
+ vmovups(ptr[CO2 + LDC * 2 + 8 * SIZE], Ymm(i + 10));
+ break;
+ }
+ } else {
+ switch (i) {
+ case 0:
+ vmaskmovps(ptr[CO1 + 8 * SIZE], VMASK, Ymm(i + 10));
+ break;
+ case 1:
+ vmaskmovps(ptr[CO1 + LDC + 8 * SIZE], VMASK,
+ Ymm(i + 10));
+ break;
+ case 2:
+ vmaskmovps(ptr[CO1 + LDC * 2 + 8 * SIZE], VMASK,
+ Ymm(i + 10));
+ break;
+ case 3:
+ vmaskmovps(ptr[CO2 + 8 * SIZE], VMASK, Ymm(i + 10));
+ break;
+ case 4:
+ vmaskmovps(ptr[CO2 + LDC + 8 * SIZE], VMASK,
+ Ymm(i + 10));
+ break;
+ case 5:
+ vmaskmovps(ptr[CO2 + LDC * 2 + 8 * SIZE], VMASK,
+ Ymm(i + 10));
+ break;
+ }
+ }
+ }
+ if (i == 2)
+ add(CO1, rax);
+ }
+ if (unroll_n >= 4) {
+ add(CO2, rax);
+ }
+
+ // Compute next address of B
+ if (!isTransB) {
+ lea(rax, ptr[K * SIZE]);
+ switch (unroll_n) {
+ case 1:
+ add(BO1, LDB);
+ add(BO2, LDB);
+ break;
+ case 2:
+ lea(BO1, ptr[BO1 + LDB * 2]);
+ lea(BO2, ptr[BO2 + LDB * 2]);
+ break;
+ case 3:
+ lea(BO1, ptr[BO1 + LDB3]);
+ lea(BO2, ptr[BO2 + LDB3]);
+ break;
+ case 4:
+ lea(BO1, ptr[BO1 + LDB * 4]);
+ lea(BO2, ptr[BO2 + LDB * 4]);
+ break;
+ case 5:
+ lea(BO1, ptr[BO1 + LDB * 4]);
+ add(BO1, LDB);
+ lea(BO2, ptr[BO2 + LDB * 4]);
+ add(BO2, LDB);
+ break;
+ case 6:
+ lea(BO1, ptr[BO1 + LDB3 * 2]);
+ lea(BO2, ptr[BO2 + LDB3 * 2]);
+ break;
+ }
+ sub(BO1, rax);
+ sub(BO2, rax);
+ } else {
+ mov(rax, LDB);
+ imul(rax, K);
+ sub(BO1, rax);
+ add(BO1, unroll_n * SIZE);
+ }
+ };
+
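+ // Thin wrappers around kernel() that select the accumulator register
+ // assignment and whether FMA is used for each unroll_m x unroll_n case.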
+ auto kernel_16x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, true);
+ };
+
+ auto kernel_16x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, true);
+ };
+
+ auto kernel_16x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, true);
+ };
+
+ auto kernel_16x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy,
+ bool useFma = true) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
+ Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
+ Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
+ Ymm(13), Ymm(14), Ymm(15));
+ };
+
+ auto kernel_16x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, false);
+ };
+
+ auto kernel_16x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_16x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, false);
+ };
+
+ auto kernel_8x6 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy,
+ bool useFma = true) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
+ Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
+ Ymm(15), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
+ Ymm(15));
+ };
+
+ auto kernel_8x5 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy);
+ };
+
+ auto kernel_8x4 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_8x6(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy);
+ };
+
+ auto kernel_8x3 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy,
+ bool useFma = true) {
+ kernel(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, useFma, Ymm(4), Ymm(5), Ymm(6), Ymm(7),
+ Ymm(8), Ymm(9), Ymm(10), Ymm(11), Ymm(12), Ymm(13), Ymm(14),
+ Ymm(15), Ymm(7), Ymm(8), Ymm(9), Ymm(7), Ymm(8), Ymm(9),
+ Ymm(13), Ymm(14), Ymm(15));
+ };
+
+ auto kernel_8x2 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, false);
+ };
+
+ auto kernel_8x1 = [&](int unroll_m, int unroll_n, bool isLoad1Unmasked,
+ bool isLoad2Unmasked, bool isDirect, bool isCopy) {
+ kernel_8x3(unroll_m, unroll_n, isLoad1Unmasked, isLoad2Unmasked,
+ isDirect, isCopy, false);
+ };
+
+ // High-level subroutine: packs A (in the transposed-A case), then walks the
+ // C matrix in chunks of 16 rows by 6 columns; masked loads/stores cover the
+ // tail when M is not divisible by 8.
+ auto subloop = [&](
+ int unroll_m, bool isLoad1Unmasked, bool isLoad2Unmasked) {
+ if (isTransA) {
+ do_pack(unroll_m, isLoad1Unmasked, isLoad2Unmasked);
+ }
+
+ Label subloop11, subloop11mask;
+ Label subloop20, subloop21, subloop22, subloop23;
+ Label subloop24, subloop25;
+ Label subloop30, subloop31, subloop32, subloop33;
+ Label subloop34, subloop35;
+ Label subloop98, subloop98mask;
+ Label subloop99, subloop99mask;
+
+ mov(CO1, C);
+ lea(CO2, ptr[CO1 + LDC * 2]);
+ add(CO2, LDC);
+ add(C, unroll_m * SIZE);
+ mov(BO1, B);
+ if (!isTransB) {
+ lea(BO2, qword[B + LDB3]);
+ }
+
+ if (!isTransA) {
+ lea(AA, ptr[A + (unroll_m * 2 - 1 - OFFSET) * SIZE]);
+ cmp(M, UNROLL_M);
+ jg(subloop98, T_NEAR);
+
+ mov(AA, ORIG_A);
+ lea(AA, ptr[AA + (unroll_m - 1 - OFFSET) * SIZE]);
+ L(subloop98);
+ }
+
+ mov(LL, N);
+ mov(I, LL);
+ if (!isTransA) {
+ // If N is too small, skip copy operation
+ cmp(LL, UNROLL_N * 3);
+ jle(subloop30, T_NEAR);
+
+ // If A is not aligned to cache line
+ cmp(FLAG, 0);
+ je(subloop30, T_NEAR);
+ } else {
+ cmp(LL, UNROLL_N);
+ jl(subloop20, T_NEAR);
+ }
+ align(16);
+
+ if (!isTransA) {
+ if (unroll_m == 16) {
+ kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, true, true);
+ } else {
+ kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, true, true);
+ }
+ } else {
+ if (unroll_m == 16) {
+ kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, false, false);
+ } else {
+ kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, false, false);
+ }
+ }
+
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jl(subloop20, T_NEAR);
+ align(16);
+
+ L(subloop11);
+ if (unroll_m == 16) {
+ kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, false, false);
+ } else {
+ kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ }
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop11, T_NEAR);
+ align(16);
+
+ L(subloop20);
+ cmp(I, 1);
+ jne(subloop21, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop21);
+ cmp(I, 2);
+ jne(subloop22, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop22);
+ cmp(I, 3);
+ jne(subloop23, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop23);
+ cmp(I, 4);
+ jne(subloop24, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop24);
+ cmp(I, 5);
+ jne(subloop99, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
+ false, false);
+ } else {
+ kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked, false,
+ false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ if (!isTransA) {
+ L(subloop30);
+ cmp(I, UNROLL_N);
+ jl(subloop25, T_NEAR);
+ align(16);
+
+ L(subloop31);
+ if (unroll_m == 16) {
+ kernel_16x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, true, false);
+ } else {
+ kernel_8x6(unroll_m, UNROLL_N, isLoad1Unmasked,
+ isLoad2Unmasked, true, false);
+ }
+ sub(I, UNROLL_N);
+ cmp(I, UNROLL_N);
+ jge(subloop31, T_NEAR);
+ align(16);
+
+ L(subloop25);
+ cmp(I, 1);
+ jne(subloop32, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x1(unroll_m, 1, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop32);
+ cmp(I, 2);
+ jne(subloop33, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x2(unroll_m, 2, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop33);
+ cmp(I, 3);
+ jne(subloop34, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x3(unroll_m, 3, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop34);
+ cmp(I, 4);
+ jne(subloop35, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x4(unroll_m, 4, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ jmp(subloop99, T_NEAR);
+ align(16);
+
+ L(subloop35);
+ cmp(I, 5);
+ jne(subloop99, T_NEAR);
+ if (unroll_m == 16) {
+ kernel_16x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ } else {
+ kernel_8x5(unroll_m, 5, isLoad1Unmasked, isLoad2Unmasked,
+ true, false);
+ }
+ align(16);
+ }
+
+ L(subloop99);
+ // Compute next address of A
+ if (!isTransA) {
+ add(A, unroll_m * SIZE);
+ } else {
+ mov(rax, LDA);
+ imul(rax, rax, unroll_m);
+ add(A, rax);
+ }
+
+ // Compute next address of BIAS
+ if (hasBias) {
+ add(BIAS, unroll_m * SIZE);
+ }
+ };
+
+ preamble();
+
+ Label buffer_in_ws, buffer_allocated;
+
+ // Get the registers
+ mov(B, ARG_B);
+ mov(LDB, ARG_LDB);
+ mov(r15, ARG_BETA);
+ mov(r12, ARG_C);
+ if (hasBias)
+ mov(r10, ARG_BIAS);
+ mov(LDC, ARG_LDC);
+ mov(rbp, rsp);
+
+ vmovss(xmm0, ptr[ARG_ALPHA]);
+ vmovss(xmm1, ptr[r15]);
+
+#if _WIN32
+ mov(A, ARG_A);
+ mov(LDA, ARG_LDA);
+#endif
+
+ cmp(K, STACK_K_CAPACITY);
+ jg(buffer_in_ws, T_NEAR);
+
+ // Create buffer and align to 4kB page
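+ // rax = 16 * K * SIZE + 256: space for the packed A panel (up to 16 rows of
+ // K elements) plus a 256-byte scratch area at the bottom of the buffer.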
+ lea(rax, ptr[K * SIZE]);
+ sal(rax, 4);
+ add(rax, 256);
+ sub(rsp, rax);
+ and_(rsp, -PAGE_4K);
+ jmp(buffer_allocated, T_NEAR);
+
+ L(buffer_in_ws);
+ mov(rsp, ARG_WS);
+
+ L(buffer_allocated);
+
+ mov(ORIG_SP, rbp);
+ mov(M, ARG_M);
+ mov(N, ARG_N);
+ mov(C, r12);
+ if (hasBias)
+ mov(BIAS, r10);
+ vmovss(ALPHA, xmm0);
+ vmovss(BETA, xmm1);
+ sub(A, -OFFSET * SIZE);
+ sub(B, -OFFSET * SIZE);
+ mov(ORIG_A, A);
+ sal(LDA, BASE_SHIFT);
+ sal(LDB, BASE_SHIFT);
+ sal(LDC, BASE_SHIFT);
+ lea(LDB3, ptr[LDB + LDB * 2]);
+
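+ // Store the lane indices 0..7 at rsp + 88; the tail masks are built below by
+ // comparing M against them.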
+ for (int i = 0; i < 8; i++) {
+ mov(dword[rsp + 88 + i * 4], i);
+ }
+
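+ // STRIDE = {0, LDA, 2*LDA, 3*LDA} (byte offsets), used by the transposed-A path.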
+ if (isTransA && is_avx2) {
+ movq(xmm0, LDA);
+ vpbroadcastq(ymm1, xmm0);
+ vinsertf128(ymm0, ymm0, xmm0, 1);
+ vpermilpd(ymm0, ymm0, 5);
+ vpaddq(ymm1, ymm1, ymm1);
+ vperm2f128(ymm1, ymm1, ymm1, 8);
+ vpaddq(ymm0, ymm0, ymm1);
+ vmovups(STRIDE, ymm0);
+ }
+
+ // Check A alignment and leading dimension; take copy-based path as
+ // needed
+ mov(rax, LDA);
+ or_(rax, A);
+ and_(rax, 0x1f);
+ mov(FLAG, rax);
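+ // FLAG != 0 when A or LDA is not 32-byte aligned; the copy-based path is
+ // taken in that case.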
+
+ Label main0, main1, main2, main3, main999;
+
+ cmp(M, UNROLL_M);
+ jl(main0, T_NEAR);
+ align(16);
+
+ L(main1);
+ subloop(UNROLL_M, true, true);
+ sub(M, UNROLL_M);
+ cmp(M, UNROLL_M);
+ jge(main1, T_NEAR);
+ align(16);
+
+ L(main0);
+ cmp(M, 0);
+ jle(main999, T_NEAR);
+
+ if (UNROLL_M > 8) {
+ cmp(M, 8);
+ jle(main2, T_NEAR);
+
+ sub(M, 8);
+ vbroadcastss(VMASK, M);
+ vpcmpgtd(VMASK, VMASK, MASK);
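+ // VMASK lane i is set when (M - 8) > i, so the second 8-wide load/store only
+ // touches the leftover rows.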
+
+ subloop(16, true, false);
+ jmp(main999, T_NEAR);
+ align(16);
+
+ L(main2);
+ cmp(M, 8);
+ jne(main3, T_NEAR);
+ subloop(8, true, true);
+ jmp(main999, T_NEAR);
+ }
+
+ align(16);
+
+ L(main3);
+ vbroadcastss(VMASK, M);
+ if (is_avx2) {
+ vpcmpgtd(VMASK, VMASK, MASK);
+ } else {
+ auto xmask = Xmm(VMASK.getIdx());
+ auto xmm_tmp = xmm4;
+
+ vextractf128(xmm_tmp, VMASK, 1);
+ vpcmpgtd(xmask, xmask, MASK);
+ vpcmpgtd(xmm_tmp, xmm_tmp, dword[rsp + 88 + 4 * 4]); // MASK + 4
+ vinsertf128(VMASK, VMASK, xmm_tmp, 1);
+ }
+ subloop(8, false, false);
+ align(16);
+
+ L(main999);
+ // Restore original stack
+ mov(rsp, ORIG_SP);
+
+ vzeroupper();
+ postamble();
+
+ ker_ = this->getCode<ker_t>();
+ }
+
+ typedef void (*ker_t)(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws);
+
+ void operator()(dim_t m, dim_t n, dim_t k,
+ const float *alpha, const float *a, dim_t lda,
+ const float *b, dim_t ldb, const float *beta, float *c,
+ dim_t ldc, const float *bias, float *ws) const
+ {
+ ker_(m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, bias, ws);
+ }
+
+private:
+ ker_t ker_;
+};
+
+const xbyak_gemm *get_xbyak_gemm(
+ bool isTransA, bool isTransB, float beta, bool hasBias) {
+ auto beta_idx = [](float beta) {
+ return (beta == 0.0) ? 0 : (beta == 1.0 ? 1 : 2);
+ };
+
+ // Kernel table [isTransA][isTransB][hasBias][beta (0, 1, other)]
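+ // e.g. isTransA = false, isTransB = true, beta == 0, no bias maps to
+ // kernel_table[0][1][0][0].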
+ static xbyak_gemm *kernel_table[2][2][2][3];
+ static std::once_flag initialized;
+ std::call_once(initialized, [=]{
+ for (bool isTransA: {false, true})
+ for (bool isTransB: {false, true})
+ for (bool hasBias: {false, true})
+ for (float beta: {0.0f, 1.0f, 2.0f}) {
+ // nocopy sgemm with bias for beta != 0.0 is not supported
+ if (hasBias && beta != 0.0)
+ continue;
+ kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)] =
+ new xbyak_gemm(isTransA, isTransB, beta, hasBias);
+ }
+ });
+
+ return kernel_table[isTransA][isTransB][hasBias][beta_idx(beta)];
+}
+
+void sgemm_nocopy_driver(const char *transa,
+ const char *transb, int m, int n, int k, const float *alpha,
+ const float *a, dim_t lda, const float *b, dim_t ldb, const float *beta,
+ float *c, dim_t ldc, const float *bias, float *ws)
+{
+ bool isTransA = (*transa == 'T' || *transa == 't');
+ bool isTransB = (*transb == 'T' || *transb == 't');
+
+ int Bm, sizeM, Bn, sizeN, Bk, sizeK;
+
+ int i, j;
+
+ if ((m <= 0) || (n <= 0))
+ return;
+
+ if ((k <= 0) || (alpha[0] == 0.)) {
+
+ if (beta[0] == 0.) {
+ for (j = 0; j < n; j++)
+ for (i = 0; i < m; i++)
+ c[i + j * ldc] = 0.0;
+ } else if (beta[0] != 1.) {
+ for (j = 0; j < n; j++)
+ for (i = 0; i < m; i++)
+ c[i + j * ldc] *= beta[0];
+ }
+
+ return;
+ }
+
+ assert(IMPLICATION(bias != nullptr, *beta == 0.0));
+
+ // XXX: this kernel lookup happens on every thread...
+ bool hasBias = (bias != nullptr);
+ auto ker_bn = get_xbyak_gemm(isTransA, isTransB, *beta, hasBias);
+ auto ker_b1 = get_xbyak_gemm(isTransA, isTransB, 1.0, false);
+ auto ker_b0 = get_xbyak_gemm(isTransA, isTransB, 0.0, false);
+ assert(ker_bn && ker_b1 && ker_b0);
+
+ int BM = 4032;
+ int BN = isTransA ? 96 : 48;
+ int BK = isTransB ? 96 : 256;
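+ // Blocking sizes for the K/M/N loops below; BN and BK depend on the
+ // transpose case.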
+ const float *curA, *curB, *curBias = nullptr;
+ float *curC;
+
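+ // Each dimension takes a full block while at least two blocks remain;
+ // otherwise an oversized remainder is split roughly in half instead of
+ // leaving a small tail.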
+ for (Bk = 0; Bk < k; Bk += sizeK) {
+ sizeK = k - Bk;
+ if (sizeK >= BK * 2)
+ sizeK = BK;
+ else {
+ if (sizeK > BK)
+ sizeK = (sizeK + 1) / 2;
+ }
+
+ for (Bm = 0; Bm < m; Bm += sizeM) {
+ sizeM = m - Bm;
+ if (sizeM >= BM * 2)
+ sizeM = BM;
+ else {
+ if (sizeM > BM + BM / 2)
+ sizeM = (sizeM + 1) / 2;
+ }
+
+ for (Bn = 0; Bn < n; Bn += sizeN) {
+ sizeN = n - Bn;
+ if (sizeN >= BN * 2)
+ sizeN = BN;
+ else {
+ if (sizeN > BN + BN / 2)
+ sizeN = (sizeN + 1) / 2;
+ }
+
+ if (!isTransA) {
+ curA = a + Bm + Bk * lda;
+ } else {
+ curA = a + Bk + Bm * lda;
+ }
+ if (!isTransB) {
+ curB = b + Bk + Bn * ldb;
+ } else {
+ curB = b + Bn + Bk * ldb;
+ }
+ curC = c + Bm + (size_t)Bn * ldc;
+ if (bias != nullptr) {
+ if (Bk == 0) {
+ curBias = bias + Bm;
+ } else {
+ curBias = nullptr;
+ }
+ }
+ if (Bk == 0) {
+ if (*beta == 0.0 && bias == nullptr)
+ (*ker_b0)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ else
+ (*ker_bn)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ } else {
+ (*ker_b1)((dim_t)sizeM, (dim_t)sizeN, (dim_t)sizeK,
+ alpha, curA, lda, curB, ldb, beta, curC, ldc,
+ curBias, ws);
+ }
+ }
+ }
+ }
+}
+
+}
+
+mkldnn_status_t jit_avx_gemm_f32(
+ const char *transa, const char *transb,
+ const int *p_m, const int *p_n, const int *p_k, const float *p_alpha,
+ const float *A, const int *p_lda, const float *B, const int *p_ldb,
+ const float *p_beta, float *C, const int *p_ldc, const float *bias)
+{
+ using namespace mkldnn::impl::utils;
+ using namespace avx_gemm_f32;
+ using namespace gemm_utils;
+
+ if (*p_beta != 0 && bias)
+ return ref_gemm(transa, transb, p_m, p_n, p_k,
+ p_alpha, A, p_lda, B, p_ldb, p_beta, C, p_ldc, bias);
+
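+ // Run single-threaded if already inside a parallel region; otherwise use all
+ // available threads.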
+ int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+
+ int m = *p_m;
+ int n = *p_n;
+ int k = *p_k;
+ dim_t lda = *p_lda;
+ dim_t ldb = *p_ldb;
+ dim_t ldc = *p_ldc;
+ float beta = *p_beta;
+ int MB, NB, KB;
+
+ int nthr_m, nthr_n, nthr_k, nthr_mn;
+
+ // Determine threading partitioning
+ calc_nthr_nocopy_avx(
+ m, n, k, nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
+
+ // May not happen, but just in case
+ if (nthr < nthr_m * nthr_n * nthr_k)
+ nthr = nthr_m * nthr_n * nthr_k;
+
+ nthr_mn = nthr_m * nthr_n;
+
+ unsigned char * ompstatus_ = nullptr;
+ unsigned char volatile *ompstatus = nullptr;
+
+ float *c_buffers = nullptr;
+ float *ws_buffers = nullptr;
+
+ if (nthr_k > 1) {
+ ompstatus_ = (unsigned char *) malloc(
+ nthr * CACHE_LINE_SIZE,
+ CACHE_LINE_SIZE);
+ ompstatus = (unsigned char volatile *) ompstatus_;
+ assert(ompstatus);
+
+ for (int i = 0; i < nthr; i++)
+ ompstatus[i * CACHE_LINE_SIZE] = 0;
+
+ c_buffers = (float *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
+ * sizeof(float), PAGE_4K);
+ }
+
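+ // Per-thread workspace for packing A (16 * k floats plus padding, rounded up
+ // to a page), only needed when K exceeds the on-stack capacity.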
+ const size_t ws_elems_per_thr = (size_t)k * 16 + 64;
+ const size_t ws_size_per_thr
+ = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
+ if (k > STACK_K_CAPACITY) {
+ ws_buffers = (float *)malloc(nthr * ws_size_per_thr, PAGE_4K);
+ }
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_m, ithr_n, ithr_k, ithr_mn;
+ int m_from, m_to, myM;
+ int n_from, n_to, myN;
+ int k_from, k_to, myK;
+ int cbase, ibase;
+ const float *myA, *myB, *myBias = nullptr;
+ float *myC = C, myBeta;
+ float *ws = ws_buffers ?
+ ws_buffers + ithr * ws_size_per_thr / sizeof(float) : 0;
+ dim_t ld = ldc;
+
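+ // When fewer threads are running than partitions, skip the in-pass K
+ // reduction and leave it to the clean-up pass below.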
+ int sum_later = (mkldnn_get_num_threads() < nthr_m * nthr_n * nthr_k);
+
+ if (ithr < nthr_m * nthr_n * nthr_k) {
+
+ ithr_mn = ithr % nthr_mn;
+ ithr_m = ithr_mn % nthr_m;
+ ithr_n = ithr_mn / nthr_m;
+ ithr_k = ithr / nthr_mn;
+
+ /* swap ithr_k for performance improvement */
+ if (ithr_k == 0)
+ ithr_k = nthr_k - 1;
+ else if (ithr_k == nthr_k - 1)
+ ithr_k = 0;
+
+ m_from = MB * (ithr_m);
+ m_to = MB * (ithr_m + 1);
+ if (m_to > m)
+ m_to = m;
+ myM = m_to - m_from;
+
+ n_from = NB * (ithr_n);
+ n_to = NB * (ithr_n + 1);
+ if (n_to > n)
+ n_to = n;
+ myN = n_to - n_from;
+
+ k_from = KB * (ithr_k);
+ k_to = KB * (ithr_k + 1);
+ if (k_to > k)
+ k_to = k;
+ myK = k_to - k_from;
+
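+ // cbase: base index of this (m, n) tile's private C buffers;
+ // ibase: base index of its per-k completion flags in ompstatus.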
+ cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+ ibase = (ithr_m + nthr_m * ithr_n) * nthr_k;
+
+ if ((myM > 0) && (myN > 0)) {
+
+ if (*transa == 'N' || *transa == 'n') {
+ myA = &(A[m_from + k_from * lda]);
+ } else {
+ myA = &(A[k_from + m_from * lda]);
+ }
+ if (*transb == 'N' || *transb == 'n') {
+ myB = &(B[k_from + n_from * ldb]);
+ } else {
+ myB = &(B[n_from + k_from * ldb]);
+ }
+ if (ithr_k == 0) {
+ myC = &(C[m_from + n_from * ldc]);
+ myBeta = beta;
+ ld = ldc;
+ if (bias)
+ myBias = &(bias[m_from]);
+ } else {
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
+ myBeta = 0.0;
+ ld = MB;
+ myBias = nullptr;
+ }
+
+ sgemm_nocopy_driver(transa, transb, myM, myN, myK, p_alpha, myA,
+ lda, myB, ldb, &myBeta, myC, ld, myBias, ws);
+
+ if (nthr_k > 1 && !sum_later)
+ ompstatus[(ibase + ithr_k) * CACHE_LINE_SIZE] = 1;
+ }
+
+ if (nthr_k > 1 && !sum_later) {
+
+ // sum matrices partitioned along K dimension
+ int n1, n2;
+
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+ if (ithr_k > 0) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
+ /* need to wait until main thread finishes */
+ while (ompstatus[ibase * CACHE_LINE_SIZE] != 1) {
+ };
+
+ /* my cache is hot */
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ if (ik != ithr_k) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
+
+ while (ompstatus[(ibase + ik) * CACHE_LINE_SIZE] != 1) {
+ };
+
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+ }
+ }
+ }
+ });
+
+ // handle C summation later
+ if (nthr_k > 1 && ompstatus[0] == 0) {
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_m, ithr_n, ithr_k, ithr_mn;
+ int m_from, m_to, myM;
+ int n_from, n_to, myN;
+ int cbase;
+ float *myC = C;
+
+ if (ithr < nthr_m * nthr_n * nthr_k) {
+
+ ithr_mn = ithr % nthr_mn;
+ ithr_m = ithr_mn % nthr_m;
+ ithr_n = ithr_mn / nthr_m;
+ ithr_k = ithr / nthr_mn;
+
+ /* swap ithr_k for performance improvement */
+ if (ithr_k == 0)
+ ithr_k = nthr_k - 1;
+ else if (ithr_k == nthr_k - 1)
+ ithr_k = 0;
+
+ m_from = MB * (ithr_m);
+ m_to = MB * (ithr_m + 1);
+ if (m_to > m)
+ m_to = m;
+ myM = m_to - m_from;
+
+ n_from = NB * (ithr_n);
+ n_to = NB * (ithr_n + 1);
+ if (n_to > n)
+ n_to = n;
+ myN = n_to - n_from;
+
+ cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ if (nthr_k > 1) {
+ // sum matrices partitioned along K dimension
+ int n1, n2;
+
+ partition_unit_diff(ithr_k, nthr_k, myN, &n1, &n2);
+
+ if (ithr_k > 0) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1)
+ + (dim_t)n1 * MB;
+
+ /* my cache is hot */
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ if (ik != ithr_k) {
+
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ik - 1)
+ + (dim_t)n1 * MB;
+
+ sum_two_matrices(myM, n2, myC, MB,
+ &C[m_from + (n_from + n1) * ldc], ldc);
+ }
+ }
+ }
+ }
+ });
+ }
+
+
+ free(c_buffers);
+ free(ompstatus_);
+ free(ws_buffers);
+
+ return mkldnn_success;
+}
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp
new file mode 100644
index 0000000000..aabf520a3c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/jit_avx_gemm_f32.hpp
@@ -0,0 +1,37 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX_GEMM_F32_HPP
+#define JIT_AVX_GEMM_F32_HPP
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t jit_avx_gemm_f32(
+ const char *transa, const char *transb, const int *M,
+ const int *N, const int *K, const float *alpha, const float *A,
+ const int *lda, const float *B, const int *ldb, const float *beta,
+ float *C, const int *ldc, const float *bias = nullptr);
+
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp
new file mode 100644
index 0000000000..5147885a89
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.cpp
@@ -0,0 +1,346 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_types.h"
+
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "jit_generator.hpp"
+
+#include "gemm_utils_f32.hpp"
+#include "ref_gemm_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace gemm_utils;
+
+namespace {
+
+template <typename data_t>
+void copy_A(
+ bool isTransA, int K, const data_t *A, const dim_t lda, data_t *ws) {
+ for (int k = 0; k < K; k++) {
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < unroll_factor<data_t>::m; i++) {
+ ws[i] = isTransA ? A[i * lda + k] : A[i + k * lda];
+ }
+ ws += unroll_factor<data_t>::m;
+ }
+}
+
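+// Register-blocked micro-kernel: accumulates an unroll_factor<data_t>::m x
+// unroll_factor<data_t>::n tile of C over K, then applies alpha and beta.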
+template <typename data_t, bool isTransA, bool isTransB>
+void kernel_mxn(int K, const data_t *A, const dim_t lda,
+ const data_t *B, const dim_t ldb, data_t *C, const dim_t ldc,
+ const data_t alpha, const data_t beta) {
+ data_t c[unroll_factor<data_t>::m * unroll_factor<data_t>::n] =
+ { static_cast<data_t>(0.) };
+ for (int k = 0; k < K; k++) {
+ for (int j = 0; j < unroll_factor<data_t>::n; j++) {
+ data_t b = isTransB ? B[j + k * ldb] : B[k + j * ldb];
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < unroll_factor<data_t>::m; i++) {
+ data_t a = isTransA ? A[i * lda + k] : A[i + lda * k];
+ c[i + unroll_factor<data_t>::m * j] += a * b;
+ }
+ }
+ }
+ for (int j = 0; j < unroll_factor<data_t>::n; j++) {
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < unroll_factor<data_t>::m; i++) {
+ C[i + j * ldc] = (beta == static_cast<data_t>(0.))
+ ? alpha * c[i + unroll_factor<data_t>::m * j]
+ : alpha * c[i + unroll_factor<data_t>::m * j]
+ + beta * C[i + j * ldc];
+ }
+ }
+}
+
+template <typename data_t, bool isTransA, bool isTransB>
+void block_ker(const int M, const int N, const int K,
+ const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
+ data_t *C, const dim_t ldc, const data_t alpha, const data_t beta,
+ data_t *ws, bool do_copy) {
+ int Nu = rnd_dn(N, unroll_factor<data_t>::n);
+ int Mu = rnd_dn(M, unroll_factor<data_t>::m);
+ for (int i = 0; i < Mu; i += unroll_factor<data_t>::m) {
+ for (int j = 0; j < Nu; j += unroll_factor<data_t>::n) {
+ const data_t *b = isTransB ? &B[j] : &B[j * ldb];
+ const data_t *a = isTransA ? &A[i * lda] : &A[i];
+ if (do_copy) {
+ if (j == 0) {
+ copy_A<data_t>(isTransA, K, a, lda, ws);
+ }
+ kernel_mxn<data_t, false, isTransB>(
+ K, ws, unroll_factor<data_t>::m, b, ldb,
+ &C[i + j * ldc], ldc, alpha, beta);
+ } else {
+ kernel_mxn<data_t, isTransA, isTransB>(
+ K, a, lda, b, ldb, &C[i + j * ldc], ldc, alpha, beta);
+ }
+ }
+ }
+ // tail processing
+ for (int i = 0; i < M; i++) {
+ for (int j = Nu; j < N; j++) {
+ data_t c = beta == static_cast<data_t>(0.)
+ ? static_cast<data_t>(0.)
+ : beta * C[i + j * ldc];
+ for (int p = 0; p < K; p++) {
+ data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
+ data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
+ c += alpha * a * b;
+ }
+ C[i + j * ldc] = c;
+ }
+ }
+ for (int i = Mu; i < M; i++) {
+ for (int j = 0; j < Nu; j++) {
+ data_t c = beta == static_cast<data_t>(0.)
+ ? static_cast<data_t>(0.)
+ : beta * C[i + j * ldc];
+ for (int p = 0; p < K; p++) {
+ data_t b = isTransB ? B[j + p * ldb] : B[p + j * ldb];
+ data_t a = isTransA ? A[p + i * lda] : A[i + p * lda];
+ c += alpha * a * b;
+ }
+ C[i + j * ldc] = c;
+ }
+ }
+}
+
+template <typename data_t, bool isTransA, bool isTransB>
+void gemm_ithr(const int M, const int N, const int K, const data_t alpha,
+ const data_t *A, const dim_t lda, const data_t *B, const dim_t ldb,
+ const data_t beta, data_t *C, const dim_t ldc, bool do_copy,
+ data_t *ws) {
+ constexpr int BM = gemm_traits<data_t, isTransA, isTransB>::BM;
+ constexpr int BN = gemm_traits<data_t, isTransA, isTransB>::BN;
+ constexpr int BK = gemm_traits<data_t, isTransA, isTransB>::BK;
+
+ const data_t *curA;
+ const data_t *curB;
+ data_t *curC;
+
+ if ((M <= 0) || (N <= 0))
+ return;
+
+ if ((K <= 0) || (alpha == static_cast<data_t>(0))) {
+ dim_t MN = N * M;
+ if (beta == static_cast<data_t>(0.)) {
+ for (dim_t j = 0; j < MN; j++)
+ C[j] = static_cast<data_t>(0.);
+ } else if (beta != static_cast<data_t>(1.)) {
+ for (dim_t j = 0; j < MN; j++)
+ C[j] *= beta;
+ }
+ return;
+ }
+
+ for (int Bk = 0; Bk < K; Bk += BK) {
+ int kb = nstl::min(K - Bk, BK);
+ for (int Bm = 0; Bm < M; Bm += BM) {
+ int mb = nstl::min(M - Bm, BM);
+ for (int Bn = 0; Bn < N; Bn += BN) {
+ int nb = nstl::min(N - Bn, BN);
+ curA = isTransA ? A + Bk + Bm * lda : A + Bm + Bk * lda;
+ curB = isTransB ? B + Bn + Bk * ldb : B + Bk + Bn * ldb;
+ curC = C + Bm + Bn * ldc;
+ if (Bk == 0) {
+ block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
+ curB, ldb, curC, ldc, alpha, beta, ws, do_copy);
+ } else {
+ block_ker<data_t, isTransA, isTransB>(mb, nb, kb, curA, lda,
+ curB, ldb, curC, ldc, alpha, static_cast<data_t>(1.0),
+ ws, do_copy);
+ }
+ }
+ }
+ }
+}
+
+}
+
+template <typename data_t>
+mkldnn_status_t ref_gemm(
+ const char *transa_, const char *transb_, const int *M_,
+ const int *N_, const int *K_, const data_t *alpha_, const data_t *A,
+ const int *lda_, const data_t *B, const int *ldb_, const data_t *beta_,
+ data_t *C, const int *ldc_, const data_t *bias) {
+
+ bool isTransA = (*transa_ == 'T' || *transa_ == 't');
+ bool isTransB = (*transb_ == 'T' || *transb_ == 't');
+ const int M = *M_, N = *N_, K = *K_;
+ const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_;
+ const data_t alpha = *alpha_, beta = *beta_;
+
+ int max_nthr = mkldnn_in_parallel() ? 1 : mkldnn_get_max_threads();
+ int nthr_m, nthr_n, nthr_k;
+ int MB, NB, KB;
+ // thread balancing over M, N, K & size of blocking dimensions
+ calc_nthr_nocopy_avx(
+ M, N, K, max_nthr, &nthr_m, &nthr_n, &nthr_k, &MB, &NB, &KB);
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_k == 1));
+
+ data_t *c_buffers = nullptr;
+ data_t *ws_buffers = nullptr;
+ if (nthr_k > 1) {
+ c_buffers = (data_t *)malloc(nthr_m * nthr_n * (nthr_k - 1) * MB * NB
+ * sizeof(data_t), PAGE_4K);
+ if (!c_buffers) {
+ nthr_k = 1;
+ KB = K;
+ }
+ }
+
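+ // Pack A only when the N blocking spans more than three register-blocked
+ // column tiles.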
+ bool do_copy = (NB / unroll_factor<data_t>::n > 3);
+ const int nthr_mn = nthr_m * nthr_n;
+ const int nthr = nthr_mn * nthr_k;
+ const size_t ws_elems_per_thr = K * unroll_factor<data_t>::m;
+ const size_t ws_size_per_thr
+ = rnd_up(ws_elems_per_thr * sizeof(data_t), PAGE_4K);
+ if (do_copy) {
+ ws_buffers = (data_t*)malloc(nthr * ws_size_per_thr, PAGE_4K);
+ if (!ws_buffers)
+ do_copy = false;
+ }
+
+ auto get_thr_block = [&](int &from, int &to, int &myN, int NB, int N,
+ int ithr) {
+ from = NB * (ithr);
+ to = NB * (ithr + 1);
+ if (to > N)
+ to = N;
+ myN = to - from;
+ };
+
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_mn = ithr % nthr_mn;
+ int ithr_m = ithr_mn % nthr_m;
+ int ithr_n = ithr_mn / nthr_m;
+ int ithr_k = ithr / nthr_mn;
+
+ int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ data_t *ws = do_copy
+ ? ws_buffers + ithr * ws_size_per_thr / sizeof(data_t)
+ : nullptr;
+
+ int m_from = 0, m_to = 0, myM = 0, n_from = 0, n_to = 0, myN = 0,
+ k_from = 0, k_to = 0, myK = 0;
+
+ get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
+ get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
+ get_thr_block(k_from, k_to, myK, KB, K, ithr_k);
+
+ if (myM > 0 && myN > 0) {
+ data_t myBeta, *myC;
+ dim_t ld;
+ if (ithr_k == 0) {
+ myC = &(C[m_from + n_from * ldc]);
+ myBeta = beta;
+ ld = ldc;
+ } else {
+ myC = c_buffers + (dim_t)MB * NB * (cbase + ithr_k - 1);
+ myBeta = 0.0f;
+ ld = MB;
+ }
+ const data_t *myA = isTransA
+ ? &(A[k_from + m_from * lda])
+ : &(A[m_from + k_from * lda]);
+ const data_t *myB = isTransB
+ ? &(B[n_from + k_from * ldb])
+ : &(B[k_from + n_from * ldb]);
+
+ if (!isTransA) {
+ if (!isTransB) {
+ gemm_ithr<data_t, false, false>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
+ } else {
+ gemm_ithr<data_t, false, true>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
+ }
+ } else {
+ if (!isTransB) {
+ gemm_ithr<data_t, true, false>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
+ } else {
+ gemm_ithr<data_t, true, true>(myM, myN, myK, alpha, myA,
+ lda, myB, ldb, myBeta, myC, ld, do_copy, ws);
+ }
+ }
+ }
+ });
+
+ if (nthr_k > 1) {
+ parallel_nd(nthr, [&](const int ithr) {
+ int ithr_mn = ithr % nthr_mn;
+ int ithr_m = ithr_mn % nthr_m;
+ int ithr_k = ithr / nthr_mn;
+ int ithr_n = ithr_mn / nthr_m;
+
+ int n_from = 0, n_to = 0, myN = 0;
+ int m_from = 0, m_to = 0, myM = 0;
+
+ int cbase = (ithr_m + nthr_m * ithr_n) * (nthr_k - 1);
+
+ get_thr_block(n_from, n_to, myN, NB, N, ithr_n);
+ get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
+
+ // sum matrices partitioned along K dimension
+ int offset = 0, block = 0;
+ gemm_utils::partition_unit_diff(ithr_k, nthr_k, myN, &offset,
+ &block);
+ for (int ik = 1; ik < nthr_k; ++ik) {
+ data_t *myC = c_buffers
+ + MB * ((dim_t)NB * (cbase + ik - 1) + offset);
+
+ gemm_utils::sum_two_matrices(myM, block, myC, MB,
+ &C[m_from + (n_from + offset) * ldc], ldc);
+ }
+ });
+ }
+
+ if (bias) {
+ parallel_nd(N, M, [&](int i, int j) {
+ C[i*ldc + j] += bias[j];
+ });
+ }
+
+ free(ws_buffers);
+ free(c_buffers);
+
+ return mkldnn_success;
+}
+
+template mkldnn_status_t ref_gemm<float>(
+ const char *transa_, const char *transb_,
+ const int *M_, const int *N_, const int *K_, const float *alpha_,
+ const float *A, const int *lda_, const float *B, const int *ldb_,
+ const float *beta_, float *C, const int *ldc_, const float *bias);
+
+template mkldnn_status_t ref_gemm<double>(
+ const char *transa_, const char *transb_,
+ const int *M_, const int *N_, const int *K_, const double *alpha_,
+ const double *A, const int *lda_, const double *B, const int *ldb_,
+ const double *beta_, double *C, const int *ldc_, const double *bias);
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp
new file mode 100644
index 0000000000..7c90ba6277
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/f32/ref_gemm_f32.hpp
@@ -0,0 +1,36 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef REF_GEMM_F32_HPP
+#define REF_GEMM_F32_HPP
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <typename data_t>
+mkldnn_status_t ref_gemm(const char *transa, const char *transb, const int *M,
+ const int *N, const int *K, const data_t *alpha, const data_t *A,
+ const int *lda, const data_t *B, const int *ldb, const data_t *beta,
+ data_t *C, const int *ldc, const data_t *bias);
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp
new file mode 100644
index 0000000000..3dbe07d743
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.cpp
@@ -0,0 +1,280 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "mkldnn_traits.hpp"
+#include "nstl.hpp"
+
+#include "jit_generator.hpp"
+
+#include "gemm.hpp"
+
+#include "f32/jit_avx512_common_gemm_f32.hpp"
+#include "f32/jit_avx_gemm_f32.hpp"
+#include "f32/ref_gemm_f32.hpp"
+
+#include "s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp"
+#include "s8x8s32/simple_gemm_s8s8s32.hpp"
+#include "s8x8s32/ref_gemm_s8x8s32.hpp"
+
+#include "os_blas.hpp"
+
+/* USE_MKL USE_CBLAS effect
+ * ------- --------- ------
+ * yes yes use Intel(R) MKL CBLAS
+ * yes no use jit
+ * no yes system-dependent CBLAS
+ * no no use jit
+ */
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t check_gemm_input(const char *transa, const char *transb,
+ const int *M, const int *N, const int *K, const int *lda,
+ const int *ldb, const int *ldc, const float *alpha, const float *beta,
+ const bool with_bias) {
+ if (utils::any_null(transa, transb, M, N, K, lda, ldb, ldc, alpha, beta))
+ return mkldnn_invalid_arguments;
+ if (with_bias && *beta != 0)
+ return mkldnn_unimplemented;
+ bool consistency = true
+ && utils::one_of(*transa, 'T', 't', 'N', 'n')
+ && utils::one_of(*transb, 'T', 't', 'N', 'n')
+ && *M >= 0
+ && *N >= 0
+ && *K >= 0;
+
+ if (!consistency)
+ return mkldnn_invalid_arguments;
+ bool isTransA = utils::one_of(*transa, 'T', 't');
+ bool isTransB = utils::one_of(*transb, 'T', 't');
+ int nrowA = isTransA ? *K : *M;
+ int nrowB = isTransB ? *N : *K;
+ consistency = true
+ && *lda >= nstl::max(1, nrowA)
+ && *ldb >= nstl::max(1, nrowB)
+ && *ldc >= nstl::max(1, *M);
+ if (!consistency)
+ return mkldnn_invalid_arguments;
+
+ return mkldnn_success;
+}
+
+mkldnn_status_t check_gemm_x8x8x32_input(const char *offsetc,
+ const char *transa, const char *transb, const int *M, const int *N,
+ const int *K, const int *lda, const int *ldb, const int *ldc,
+ const float *alpha, const float *beta, const bool with_bias) {
+ if (offsetc == nullptr)
+ return mkldnn_invalid_arguments;
+ if (!utils::one_of(*offsetc, 'F', 'f', 'C', 'c', 'R', 'r'))
+ return mkldnn_invalid_arguments;
+
+ return check_gemm_input(transa, transb, M, N, K, lda, ldb, ldc, alpha,
+ beta, with_bias);
+}
+
+mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
+ const int *M, const int *N, const int *K, const float *alpha,
+ const float *A, const int *lda, const float *B, const int *ldb,
+ const float *beta, float *C, const int *ldc,
+ const float *bias, const bool force_jit_gemm) {
+ mkldnn_status_t status = check_gemm_input(transa, transb, M, N, K,
+ lda, ldb, ldc, alpha, beta, bias != nullptr);
+ if (status != mkldnn_success)
+ return status;
+
+#ifdef USE_CBLAS
+ if (!force_jit_gemm) {
+ bool trA = *transa == 't' || *transa == 'T';
+ bool trB = *transb == 't' || *transb == 'T';
+ CBLAS_TRANSPOSE Cblas_trA = trA ? CblasTrans : CblasNoTrans;
+ CBLAS_TRANSPOSE Cblas_trB = trB ? CblasTrans : CblasNoTrans;
+ cblas_sgemm(CblasColMajor, Cblas_trA, Cblas_trB,
+ *M, *N, *K, *alpha, A, *lda, B, *ldb, *beta, C, *ldc);
+
+ if (bias) {
+ // Add bias if necessary (bias is applied to columns of C)
+ cblas_int incx = 1, incy = 1;
+ parallel_nd(*N, [&](int n) {
+ ptrdiff_t offset = (ptrdiff_t)n * (*ldc);
+ cblas_saxpy(*M, 1.0, bias, incx, C + offset, incy);
+ });
+ }
+ return mkldnn_success;
+ }
+#endif
+
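+ // Dispatch by ISA: AVX-512 JIT kernel, then AVX JIT kernel, then the
+ // reference implementation.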
+ if (mayiuse(avx512_common))
+ return jit_avx512_common_gemm_f32(transa, transb,
+ M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
+ else if (mayiuse(avx))
+ return jit_avx_gemm_f32(transa, transb,
+ M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
+ else
+ return ref_gemm<float>(transa, transb,
+ M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, bias);
+}
+
+template <typename b_dt>
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co) {
+ mkldnn_status_t status = check_gemm_x8x8x32_input(offsetc, transa, transb,
+ M, N, K, LDA, LDB, LDC, alpha, beta, false);
+ if (status != mkldnn_success)
+ return status;
+
+ if (*M == 0 || *N == 0 || *K == 0)
+ return mkldnn_success;
+
+#if USE_MKL_IGEMM
+ bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
+ bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
+ bool AisN = (*transa == 'N' || *transa == 'n');
+ bool BisN = (*transb == 'N' || *transb == 'n');
+
+ if (data_traits<b_dt>::data_type == data_type::u8) {
+ CBLAS_TRANSPOSE Cblas_trA = AisN ? CblasNoTrans : CblasTrans;
+ CBLAS_TRANSPOSE Cblas_trB = BisN ? CblasNoTrans : CblasTrans;
+ CBLAS_OFFSET Cblas_offsetc =
+ OCisR
+ ? CblasRowOffset
+ : OCisC
+ ? CblasColOffset
+ : CblasFixOffset;
+ cblas_gemm_s8u8s32(CblasColMajor, Cblas_trA, Cblas_trB, Cblas_offsetc,
+ *M, *N, *K, *alpha, A, *LDA, *ao, (uint8_t *)B, *LDB, *bo,
+ *beta, C, *LDC, co);
+ return mkldnn_success;
+ } else {
+ assert(data_traits<b_dt>::data_type == data_type::s8);
+ // TODO CBLAS implementation of gemm_s8s8s32 goes here.
+ // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
+ if (utils::everyone_is(0, *ao, *bo)) {
+ return simple_gemm_s8s8s32(transa, transb, offsetc, M,
+ N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
+ C, LDC, co);
+ } else {
+ return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
+ alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
+ }
+ }
+#else
+ cpu_isa_t isa = isa_any;
+ if (mayiuse(avx512_core_vnni)) {
+ isa = avx512_core_vnni;
+ } else if (mayiuse(avx512_core)) {
+ isa = avx512_core;
+ }
+
+ if (data_traits<b_dt>::data_type == data_type::u8) {
+ switch (isa) {
+ case avx512_core:
+ case avx512_core_vnni:
+ return jit_avx512_core_gemm_s8u8s32(transa, transb, offsetc, M,
+ N, K, alpha, A, LDA, ao, (uint8_t *)B, LDB, bo, beta,
+ C, LDC, co);
+ default:
+ return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
+ alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
+ }
+ } else {
+ assert(data_traits<b_dt>::data_type == data_type::s8);
+ // mkldnn_gemm_s8s8s32 doesn't support non-zero ao and bo
+ if ((mayiuse(avx512_core) || mayiuse(avx512_core_vnni))
+ && *ao == 0 && *bo == 0) {
+ return simple_gemm_s8s8s32(transa, transb, offsetc, M,
+ N, K, alpha, A, LDA, ao, (int8_t *)B, LDB, bo, beta,
+ C, LDC, co);
+ } else {
+ return ref_gemm_s8x8s32(transa, transb, offsetc, M, N, K,
+ alpha, A, LDA, ao, B, LDB, bo, beta, C, LDC, co);
+ }
+ }
+#endif
+}
+
+template
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const int8_t *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co);
+
+template
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const uint8_t *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co);
+
+}
+}
+}
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::cpu;
+
+mkldnn_status_t mkldnn_sgemm(const char *transa, const char *transb,
+ const int64_t *M, const int64_t *N, const int64_t *K, const float *alpha,
+ const float *A, const int64_t *lda, const float *B, const int64_t *ldb,
+ const float *beta, float *C, const int64_t *ldc) {
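+ // The public API takes 64-bit dimensions; the internal kernels use 32-bit
+ // ints, so narrow them here.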
+ int M_s32 = (int)*M;
+ int N_s32 = (int)*N;
+ int K_s32 = (int)*K;
+ int lda_s32 = (int)*lda;
+ int ldb_s32 = (int)*ldb;
+ int ldc_s32 = (int)*ldc;
+
+ return extended_sgemm(transa, transb, &M_s32, &N_s32, &K_s32,
+ alpha, A, &lda_s32, B, &ldb_s32, beta, C, &ldc_s32);
+}
+
+mkldnn_status_t mkldnn_gemm_s8u8s32(const char *transa, const char *transb,
+ const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K,
+ const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao,
+ const uint8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta,
+ int32_t *C, const int64_t *ldc, const int32_t *co) {
+ int M_s32 = (int)*M;
+ int N_s32 = (int)*N;
+ int K_s32 = (int)*K;
+ int lda_s32 = (int)*lda;
+ int ldb_s32 = (int)*ldb;
+ int ldc_s32 = (int)*ldc;
+ return gemm_s8x8s32(transa, transb, offsetc, &M_s32, &N_s32, &K_s32,
+ alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co);
+}
+
+mkldnn_status_t mkldnn_gemm_s8s8s32(const char *transa, const char *transb,
+ const char *offsetc, const int64_t *M, const int64_t *N, const int64_t *K,
+ const float *alpha, const int8_t *A, const int64_t *lda, const int8_t *ao,
+ const int8_t *B, const int64_t *ldb, const int8_t *bo, const float *beta,
+ int32_t *C, const int64_t *ldc, const int32_t *co) {
+ int M_s32 = (int)*M;
+ int N_s32 = (int)*N;
+ int K_s32 = (int)*K;
+ int lda_s32 = (int)*lda;
+ int ldb_s32 = (int)*ldb;
+ int ldc_s32 = (int)*ldc;
+
+ return gemm_s8x8s32<int8_t>(transa, transb, offsetc, &M_s32, &N_s32, &K_s32,
+ alpha, A, &lda_s32, ao, B, &ldb_s32, bo, beta, C, &ldc_s32, co);
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp
new file mode 100644
index 0000000000..dc15ff7130
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/gemm.hpp
@@ -0,0 +1,58 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GEMM_HPP
+#define GEMM_HPP
+
+#include "mkldnn_types.h"
+#include "os_blas.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t extended_sgemm(const char *transa, const char *transb,
+ const int *M, const int *N, const int *K, const float *alpha,
+ const float *A, const int *lda, const float *B, const int *ldb,
+ const float *beta, float *C, const int *ldc,
+ const float *bias = nullptr, bool force_jit_gemm = false);
+
+template <typename b_dt>
+mkldnn_status_t gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *lda, const int8_t *ao,
+ const b_dt *B, const int *ldb, const int8_t *bo, const float *beta,
+ int32_t *c, const int *ldc, const int32_t *co);
+
+#ifdef USE_CBLAS
+#define GEMM_IMPL_STR "gemm:blas"
+#else
+#define GEMM_IMPL_STR "gemm:jit"
+#endif
+
+#if USE_MKL_IGEMM
+#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:blas"
+#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:blas"
+#else
+#define IGEMM_S8U8S32_IMPL_STR "igemm_s8u8s32:jit"
+#define IGEMM_S8S8S32_IMPL_STR "igemm_s8s8s32:jit"
+#endif
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp
new file mode 100644
index 0000000000..4d34ede0bd
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/os_blas.hpp
@@ -0,0 +1,86 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef OS_BLAS_HPP
+#define OS_BLAS_HPP
+
+/** \file
+ * Common handling of the USE_MKL and USE_CBLAS compile flags
+ *
+ * USE_MKL USE_CBLAS effect
+ * ------- --------- ------
+ * yes yes normal compile: jit *may* be preferred over Intel(R) MKL CBLAS
+ * yes no jit calls OK; assert if cblas is ever called
+ * no yes system-dependent CBLAS
+ * no no gemm convolution (or other blas) N/A; create stubs
+ */
+
+#if defined(USE_MKL)
+
+#include "mkl_version.h"
+
+#define USE_MKL_PACKED_GEMM (INTEL_MKL_VERSION >= 20190001)
+#define USE_MKL_IGEMM \
+ (INTEL_MKL_VERSION >= 20180000 && __INTEL_MKL_BUILD_DATE >= 20170628)
+
+#include "mkl_cblas.h"
+#if !defined(USE_CBLAS)
+#define cblas_sgemm(...) assert(!"CBLAS is unavailable")
+#endif
+
+#else /* defined(USE_MKL) */
+
+#define USE_MKL_PACKED_GEMM 0
+#define USE_MKL_IGEMM 0
+
+#if defined(_SX)
+/* TODO: _SX should also define USE_CBLAS in case the latter is available */
+extern "C" {
+#include "cblas.h" // CHECK: does SX also have a fortran API sgemm?
+}
+
+#elif defined(USE_CBLAS)
+#include "cblas.h" // Maybe a system/cmake cblas works for you?
+#else
+/* put the stubs to make a code compilable but not workable */
+#define cblas_sgemm(...) assert(!"CBLAS is unavailable")
+#endif /* defined(_SX) */
+
+#endif /* defined(USE_MKL) */
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+#if defined(USE_MKL) && defined(USE_CBLAS)
+typedef MKL_INT cblas_int;
+
+#elif defined(USE_CBLAS)
+typedef int cblas_int;
+
+#if defined(_SX)
+/* this cblas.h is peculiar... */
+typedef CBLAS_ORDER CBLAS_LAYOUT;
+#endif
+#endif
+
+}
+}
+}
+
+#endif /* OS_BLAS_HPP */
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp
new file mode 100644
index 0000000000..dde72f4a17
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/common.hpp
@@ -0,0 +1,206 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef COMMON_H
+#define COMMON_H
+
+#define GEMM_CODE_SIZE (4096L * 32)
+
+#define AVX512_UNROLL_M 48
+#define AVX512_UNROLL_N 8
+#define AVX512_UNROLL_K 1
+#define AVX512_BM 9984
+#define AVX512_BN 384
+#define AVX512_BK 768
+#define AVX512_BK_VNNI 1536
+#define AVX512_BK_TRADITIONAL 384
+#define AVX512_BLOCKING_SMALL_K 48
+#define AVX512_BN_SMALL_K 24
+
+
+#define PAGESIZE 4096
+
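+// Round a buffer of x elements of the given element size up to a whole number
+// of 4 KiB pages: PADD_BYTESIZE_ONPAGE returns the padded size in bytes,
+// NEXT_THR_STRIDE the corresponding number of elements.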
+#define PADD_BYTESIZE_ONPAGE(x, size) (((x) * (size) + PAGESIZE - 1) / PAGESIZE) * PAGESIZE
+#define NEXT_THR_STRIDE(x, size) (PADD_BYTESIZE_ONPAGE(x, size)) / size
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+enum {
+ PARTITION_1D_ROW,
+ PARTITION_1D_COL,
+ PARTITION_2D_COL_MAJOR,
+ PARTITION_2D = PARTITION_2D_COL_MAJOR,
+};
+
+enum {
+ COPY_NONE,
+ COPY_A,
+};
+
+enum {
+ NO_OFFSET,
+ FIX_OFFSET,
+ COL_OFFSET,
+ ROW_OFFSET,
+};
+
+// Alias for any dimension related variable.
+typedef long long int dim_t;
+
+typedef struct {
+ // Interface arguments.
+ int transa, transb, offsetc;
+ dim_t m, n, k;
+ dim_t lda, ldb, ldc;
+ const int8_t *a;
+ const uint8_t *b;
+ int32_t *c;
+ const float *alpha, *beta;
+
+ int8_t ao, bo;
+ const int32_t *co;
+
+ // Kernel parameters.
+ dim_t um, un, uk, bm, bn, bk;
+ dim_t bn_small_k, bk_traditional, blocking_small_k;
+
+ int (*copyA)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ int (*copyB)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
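+    // GEMM kernel variants selected by the drivers: the "_b0" suffix means
+    // beta == 0 (C is overwritten instead of accumulated), while "_c", "_r"
+    // and "_b" apply column offsets, row offsets, or both to C.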
+ int (*kernel)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_r)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_c)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ int (*kernel_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ // Gemv kernels
+ void (*gemv_s8u8s32_kernel)(const dim_t, const dim_t, const float,
+ const int8_t*, const dim_t, const uint8_t*,
+ const float, int32_t*);
+
+ void (*gemv_u8s8s32_kernel)(const dim_t, const dim_t, const float,
+ const uint8_t*, const dim_t, const int8_t*,
+ const float, int32_t*);
+
+ // Gemv parameters
+ int swap;
+
+} blas_t;
+
+
+class jit_avx512_core_u8_copy_an_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_an_kern);
+
+ public:
+ jit_avx512_core_u8_copy_an_kern();
+};
+
+class jit_avx512_core_u8_copy_at_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_at_kern);
+
+ public:
+ jit_avx512_core_u8_copy_at_kern();
+};
+
+class jit_avx512_core_u8_copy_bn_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bn_kern);
+
+ public:
+ jit_avx512_core_u8_copy_bn_kern();
+};
+
+class jit_avx512_core_u8_copy_bt_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_bt_kern);
+
+ public:
+ jit_avx512_core_u8_copy_bt_kern();
+};
+
+class jit_avx512_core_u8_copy_sum_an_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_an_kern);
+
+ public:
+ jit_avx512_core_u8_copy_sum_an_kern();
+};
+
+class jit_avx512_core_u8_copy_sum_at_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_at_kern);
+
+ public:
+ jit_avx512_core_u8_copy_sum_at_kern();
+};
+
+class jit_avx512_core_u8_copy_sum_bn_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bn_kern);
+
+ public:
+ jit_avx512_core_u8_copy_sum_bn_kern();
+};
+
+class jit_avx512_core_u8_copy_sum_bt_kern : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8_copy_sum_bt_kern);
+
+ public:
+ jit_avx512_core_u8_copy_sum_bt_kern();
+};
+
+}
+}
+}
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp
new file mode 100644
index 0000000000..db9dd9ef97
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/gemv.hpp
@@ -0,0 +1,28 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg);
+int gemv_threading_driver(blas_t *arg);
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp
new file mode 100644
index 0000000000..e4b8e1cde2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.cpp
@@ -0,0 +1,1409 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <cstdint>
+#include <mutex>
+
+#include "common.hpp"
+#include "mkldnn_types.h"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "jit_avx512_core_gemm_s8u8s32.hpp"
+#include "jit_avx512_core_gemm_s8u8s32_kern.hpp"
+#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp"
+#include "gemv.hpp"
+
+#if defined(_MSC_VER)
+#include <malloc.h>
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+typedef struct {
+ int nthrs_m, nthrs_n;
+ int partition;
+ int copy_type;
+} blas_thread_t;
+
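+// Round a double to the nearest int32_t, saturating at INT32_MIN / INT32_MAX.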
+static inline void round_to_nearest(int32_t *rounded_val, double fp_val) {
+ if (fp_val >= 0.) {
+ fp_val += 0.5;
+ if (fp_val > INT32_MAX) {
+ fp_val = INT32_MAX;
+ }
+ } else {
+ fp_val -= 0.5;
+ if (fp_val < INT32_MIN) {
+ fp_val = INT32_MIN;
+ }
+ }
+ *rounded_val = (int32_t) fp_val;
+}
+
+static inline void add_results(const dim_t m, const dim_t n, const dim_t k,
+ const float alpha, const float beta, const int32_t *c_partial_sum,
+ const dim_t ldcp, int32_t *c_data, const dim_t ldc,
+ const int32_t *a_row_sum, const int32_t *b_col_sum, const int8_t ao,
+ const int8_t bo, const int32_t *co, const int offsetc)
+{
+ for (dim_t j = 0; j < n; ++j) {
+ for (dim_t i = 0; i < m; ++i) {
+ int32_t ctemp = c_partial_sum[i + j * ldcp];
+
+ if (alpha == 1.0f) {
+ if (beta == 0.0f) {
+ c_data[i + j * ldc] = ctemp;
+ } else {
+ double c_float = (double) beta
+ * (double) c_data[i + j * ldc];
+ c_float += (double) ctemp;
+ round_to_nearest(&c_data[i + j * ldc], c_float);
+ }
+ } else if (alpha == -1.0f) {
+ if (beta == 0.0f) {
+ c_data[i + j * ldc] = -ctemp;
+ } else {
+ double c_float = (double) beta
+ * (double) c_data[i + j * ldc];
+ c_float -= (double) ctemp;
+ round_to_nearest(&c_data[i + j * ldc], c_float);
+ }
+ } else {
+ if (beta == 0.0f) {
+ double c_float = alpha * (double) ctemp;
+ round_to_nearest(&c_data[i + j * ldc], c_float);
+ } else {
+ double c_float = alpha * (double) ctemp +
+ beta * (double) c_data[i + j * ldc];
+ round_to_nearest(&c_data[i + j * ldc], c_float);
+ }
+ }
+
+ if (offsetc == FIX_OFFSET) {
+ c_data[i + j * ldc] += co[0];
+ } else if (offsetc == ROW_OFFSET) {
+ c_data[i + j * ldc] += co[j];
+ } else if (offsetc == COL_OFFSET) {
+ c_data[i + j * ldc] += co[i];
+ }
+ }
+ }
+}
+
+// TODO: find a better place for these functions.
+static inline dim_t ld_padd(const dim_t x)
+{
+ return ((x + ((2048 / sizeof(int32_t)) - 1)) / (2048 / sizeof(int32_t)))
+ * (2048 / sizeof(int32_t)) + (64 / sizeof(int32_t));
+}
+
+void igemm_inner_kernel(const dim_t m, const dim_t n, const dim_t k,
+ const int8_t *a, const uint8_t *b, float beta, int32_t *c,
+ const dim_t ldc, const int32_t *a_row_sum, const int32_t *b_col_sum,
+ const int32_t *co, const int offsetc, const blas_t *arg)
+{
+ int8_t ao = arg->ao;
+ int8_t bo = arg->bo;
+ int32_t co_0 = (offsetc == NO_OFFSET)? 0 : co[0];
+
+    // m and n are bounded by the blocking sizes, so these stack arrays take
+    // at most ~32 KB and should not overflow the stack.
+#if !defined(_MSC_VER)
+ int32_t col_offset[m];
+ int32_t row_offset[n];
+#else
+ int32_t *col_offset = (int32_t *) _alloca(sizeof(*col_offset) * m);
+ int32_t *row_offset = (int32_t *) _alloca(sizeof(*row_offset) * n);
+#endif
+
+ int col_req = 0;
+ int row_req = 0;
+
+ if ((bo != 0) || (offsetc == COL_OFFSET))
+ col_req = 1;
+ if ((ao != 0) || (offsetc == ROW_OFFSET))
+ row_req = 1;
+
+    // Only one of the column or row offset buffers is needed, not both.
+ if (((ao != 0) && (bo != 0)) || ((offsetc == FIX_OFFSET) && (co_0 != 0))) {
+ if ((col_req == 0) && (row_req == 0)) {
+ if (m <= n) {
+ col_req = 1;
+ } else {
+ row_req = 1;
+ }
+ }
+ }
+
+ if (col_req) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] = 0;
+
+ if (offsetc == COL_OFFSET) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] += co[i];
+ }
+
+ if (bo != 0) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] += bo * a_row_sum[i];
+ }
+ }
+
+ if (row_req) {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] = 0;
+
+ if (offsetc == ROW_OFFSET) {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] += co[i];
+ }
+
+ if (ao != 0) {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] += ao * b_col_sum[i];
+ }
+ }
+
+ if ((offsetc == FIX_OFFSET) && (co_0 != 0)) {
+ if (col_req) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] += co_0;
+ } else {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] += co_0;
+ }
+ }
+
+ if ((ao != 0) && (bo != 0)) {
+ if (col_req) {
+ for (dim_t i = 0; i < m; i++)
+ col_offset[i] += (int32_t) k * ao * bo;
+ } else {
+ for (dim_t i = 0; i < n; i++)
+ row_offset[i] += (int32_t) k * ao * bo;
+ }
+ }
+
+ if (col_req == 0) {
+ if (row_req == 0) {
+ if (beta == 0.0) {
+ arg->kernel_b0(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ } else {
+ arg->kernel(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ }
+ } else {
+ if (beta == 0.0) {
+ arg->kernel_b0_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ } else {
+ arg->kernel_r(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ }
+ }
+ } else {
+ if (row_req == 0) {
+ if (beta == 0.0) {
+ arg->kernel_b0_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ } else {
+ arg->kernel_c(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ }
+ } else {
+ if (beta == 0.0) {
+ arg->kernel_b0_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ } else {
+ arg->kernel_b(&m, &n, &k, NULL, a, b, c, ldc, col_offset,
+ row_offset);
+ }
+ }
+ }
+}
+
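+// Round a pointer up to the next multiple of 'alignment'.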
+static inline void *align(void *ptr, size_t alignment)
+{
+ return (void *) utils::rnd_up((uintptr_t) ptr, alignment);
+}
+
+static int gemm_kernel_driver(const dim_t m, const dim_t n, const dim_t k,
+ const int8_t *a, const uint8_t *b, int32_t *c, const int32_t *co,
+ const blas_t *arg)
+{
+ dim_t lda = arg->lda;
+ dim_t ldb = arg->ldb;
+ dim_t ldc = arg->ldc;
+ int8_t ao = arg->ao;
+ int8_t bo = arg->bo;
+ float alpha = *arg->alpha;
+ float beta = *arg->beta;
+
+ if (m <= 0 || n <= 0) {
+ return 0;
+ }
+
+ // Padding along K dimension.
+ dim_t k_padd = 0;
+ if (k <= arg->bk_traditional) {
+ k_padd = utils::rnd_up(k, arg->uk);
+ k_padd = nstl::max(128LL, k_padd);
+ } else if (k < 2 * arg->bk) {
+ k_padd = utils::rnd_up(k / 2, arg->uk);
+ } else {
+ k_padd = arg->bk;
+ }
+
+ // Padding along M dimension.
+ dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
+ arg->um);
+
+ // Padding along N dimension.
+ dim_t n_padd = 0;
+ if (k < arg->blocking_small_k) {
+ n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
+ arg->bn_small_k), arg->un);
+ } else {
+ n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
+ arg->un);
+ }
+
+ // Padding for temporary buffer for C
+ dim_t ldc_buf = ld_padd(m_padd);
+
+ dim_t strideAm = (arg->transa == 0)? 1 : lda;
+ dim_t strideAn = (arg->transa != 0)? 1 : lda;
+ dim_t strideBm = (arg->transb == 0)? 1 : ldb;
+ dim_t strideBn = (arg->transb != 0)? 1 : ldb;
+
+ size_t a_buf_nelems = m_padd * k_padd;
+ size_t b_buf_nelems = k_padd * n_padd;
+ size_t a_row_sum_nelems = m_padd;
+ size_t b_col_sum_nelems = n_padd;
+
+ size_t mem_size = a_buf_nelems * sizeof(*a) + PAGE_4K
+ + b_buf_nelems * sizeof(*b) + PAGE_4K
+ + a_row_sum_nelems * sizeof(*c) + PAGE_4K
+ + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
+
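+    // The kernels only handle alpha == 1 with beta == 0 (overwrite) or
+    // beta == 1 (accumulate); any other alpha/beta combination goes through a
+    // temporary int32 buffer and is finished by add_results().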
+ bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
+ if (need_c_buffer) {
+ size_t c_buf_nelems = ldc_buf * n_padd;
+ mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
+ }
+
+ char *mem = (char *) malloc(mem_size, 128);
+
+ if (!mem) {
+ return -1;
+ }
+
+ int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
+ uint8_t *bufferB = (uint8_t *) align(bufferA + a_buf_nelems, PAGE_4K);
+ int32_t *a_row_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
+ int32_t *b_col_sum = (int32_t *) align(a_row_sum + a_row_sum_nelems,
+ PAGE_4K);
+
+ int32_t *bufferC = NULL;
+ if (need_c_buffer) {
+ bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
+ }
+
+ float beta_saved = beta;
+
+ int a_block_copied = 0;
+ dim_t sizeM = 0;
+ for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
+ sizeM = m - Bm;
+ if (sizeM > m_padd)
+ sizeM = m_padd;
+
+ dim_t sizeK = 0;
+ for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
+ sizeK = k - Bk;
+ if (sizeK > k_padd)
+ sizeK = k_padd;
+
+            // Scale C blocks by beta only for the first k-block.
+ if (Bk == 0)
+ beta = beta_saved;
+ else
+ beta = 1.0f;
+
+            // Apply the C offset only to the last k-block of the partial sum.
+ int offsetc = NO_OFFSET;
+ if (Bk + sizeK == k)
+ offsetc = arg->offsetc;
+
+ dim_t sizeN = 0;
+ for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
+ sizeN = n - Bn;
+ if (sizeN > n_padd)
+ sizeN = n_padd;
+
+ const uint8_t *b_block = b + Bk * strideBm + Bn * strideBn;
+ arg->copyB(&sizeK, &sizeN, b_block, &ldb, NULL, bufferB, NULL,
+ NULL, b_col_sum);
+
+ dim_t sizeUM = 0;
+ for (dim_t Um = 0; Um < sizeM; Um += sizeUM) {
+ sizeUM = sizeM - Um;
+ if (sizeUM > arg->um)
+ sizeUM = arg->um;
+
+ /*
+ * Use the whole A buffer only if we have multiple B blocks
+ * for k-dimension, otherwise we are wasting cache to store
+ * B and C blocks.
+ */
+ dim_t Um_forA = 0;
+ if (sizeN < n)
+ Um_forA = Um;
+
+ const int8_t *a_block = a + (Bm + Um) * strideAm
+ + Bk * strideAn;
+ if (!a_block_copied) {
+ arg->copyA(&sizeK, &sizeUM, a_block, &lda, NULL,
+ bufferA + Um_forA * sizeK, NULL, NULL,
+ a_row_sum + Um_forA);
+ }
+
+ int32_t *c_block = c + (Bm + Um) + Bn * ldc;
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = Bn;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = Bm + Um;
+ }
+ if (need_c_buffer) {
+ igemm_inner_kernel(sizeUM, sizeN, sizeK,
+ bufferA + Um_forA * sizeK, bufferB, 0.0f,
+ bufferC + Um, ldc_buf, a_row_sum + Um_forA,
+ b_col_sum, NULL, NO_OFFSET, arg);
+
+ // Finish the block adding the necessary alpha, beta
+ // and offsets.
+ add_results(sizeUM, sizeN, sizeK, alpha, beta,
+ bufferC + Um, ldc_buf, c_block, ldc,
+ a_row_sum + Um_forA, b_col_sum, ao, bo,
+ co + co_stride, offsetc);
+ } else {
+ igemm_inner_kernel(sizeUM, sizeN, sizeK,
+ bufferA + Um_forA * sizeK, bufferB, beta,
+ c_block, ldc, a_row_sum + Um_forA, b_col_sum,
+ co + co_stride, offsetc, arg);
+ }
+ }
+ a_block_copied = 1;
+ }
+ a_block_copied = 0;
+ }
+ }
+
+ free(mem);
+
+ return 0;
+}
+
+static int kernel_driver_parallel_acopiedbcopy(const dim_t m, const dim_t n,
+ const dim_t k, const int8_t *bufferA, const uint8_t *b,
+ const float beta, int32_t *c, const int offsetc, const int32_t *co,
+ const int32_t *a_row_sum, const blas_t *arg)
+{
+ dim_t ldb = arg->ldb;
+ dim_t ldc = arg->ldc;
+ int8_t ao = arg->ao;
+ int8_t bo = arg->bo;
+ float alpha = *arg->alpha;
+
+ if (m <= 0 || n <= 0) {
+ return 0;
+ }
+
+ // Padding along N dimension.
+ dim_t n_padd = 0;
+ if (k < arg->blocking_small_k) {
+ n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un),
+ arg->bn_small_k), arg->un);
+ } else {
+ n_padd = utils::rnd_up(nstl::min(nstl::max(n, arg->un), arg->bn),
+ arg->un);
+ }
+
+ // Padding for temporary buffer for C
+ dim_t ldc_buf = ld_padd(m);
+
+ dim_t strideBn = (arg->transb != 0)? 1 : ldb;
+
+ size_t b_buf_nelems = k * n_padd;
+ size_t b_col_sum_nelems = n_padd;
+
+ size_t mem_size = b_buf_nelems * sizeof(*b) + PAGE_4K
+ + b_col_sum_nelems * sizeof(*c) + PAGE_4K;
+
+ bool need_c_buffer = alpha != 1.0f || (beta != 1 && beta != 0);
+ if (need_c_buffer) {
+ size_t c_buf_nelems = ldc_buf * n_padd;
+ mem_size += c_buf_nelems * sizeof(*c) + PAGE_4K;
+ }
+
+ char *mem = (char *) malloc(mem_size, 128);
+
+ if (!mem) {
+ return -1;
+ }
+
+ uint8_t *bufferB = (uint8_t *) align(mem, PAGE_4K);
+ int32_t *b_col_sum = (int32_t *) align(bufferB + b_buf_nelems, PAGE_4K);
+
+ int32_t *bufferC = NULL;
+ if (need_c_buffer) {
+ bufferC = (int32_t *) align(b_col_sum + b_col_sum_nelems, PAGE_4K);
+ }
+
+ dim_t sizeN = 0;
+ for (dim_t Bn = 0; Bn < n; Bn += sizeN) {
+ sizeN = n - Bn;
+ if (sizeN > n_padd)
+ sizeN = n_padd;
+
+        // Copy the B block and compute its column sums.
+ const uint8_t *b_block = b + Bn * strideBn;
+ arg->copyB(&k, &sizeN, b_block, &ldb, NULL, bufferB, NULL, NULL,
+ b_col_sum);
+
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = Bn;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = 0;
+ }
+ int32_t *c_block = c + Bn * ldc;
+ if (need_c_buffer) {
+ igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, 0.0f, bufferC,
+ ldc_buf, a_row_sum, b_col_sum, NULL, NO_OFFSET, arg);
+
+ // Finish the block adding the necessary alpha, beta and offsets.
+ add_results(m, sizeN, k, alpha, beta, bufferC, ldc_buf, c_block,
+ ldc, a_row_sum, b_col_sum, ao, bo, co + co_stride,
+ offsetc);
+ } else {
+ igemm_inner_kernel(m, sizeN, k, bufferA, bufferB, beta, c_block,
+ ldc, a_row_sum, b_col_sum, co + co_stride, offsetc, arg);
+ }
+ }
+
+ free(mem);
+
+ return 0;
+
+}
+
+#define N2D_MAX_AVX512 384
+#define M2D_MIN_AVX512 384
+#define VECLEN 16
+#define NCONS 1
+static inline void set_thread_opts_avx512(int *p_nthrs,
+ blas_thread_t *thread_info, const blas_t *arg)
+{
+ int nthrs = *p_nthrs;
+ dim_t m = arg->m;
+ dim_t n = arg->n;
+
+ thread_info->nthrs_m = 0;
+ thread_info->nthrs_n = 0;
+ thread_info->copy_type = COPY_NONE; // By default don't do parallel copy.
+
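+    // Use a 2D decomposition of B when m and n are reasonably balanced
+    // (m/n between nthrs/256 and 256/nthrs); otherwise fall back to 1D.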
+ int condition_2D_bsrc = -1;
+ if ((256 * m > nthrs * n) && (nthrs * m < 256 * n)) {
+ condition_2D_bsrc = 1;
+ } else {
+ condition_2D_bsrc = 0;
+ }
+
+ int condition_1D_copya = 0;
+ if ((m >= 1000) && (n >= nthrs * N2D_MAX_AVX512 / 4)) {
+ condition_2D_bsrc = 0;
+ condition_1D_copya = 1;
+ }
+
+    // If offsets are in use, force the 1D copy-A path to reduce update overhead.
+ if (arg->ao != 0 || arg->bo != 0 || arg->co[0] != 0
+ || arg->offsetc != FIX_OFFSET) {
+ condition_2D_bsrc = 0;
+ condition_1D_copya = 1;
+ }
+
+ if (condition_2D_bsrc == 1) {
+ int nthrs_m = 1;
+ int nthrs_n = nthrs;
+
+ while ((nthrs_n % 2 == 0) &&
+ (n / nthrs > N2D_MAX_AVX512 ||
+ n / nthrs_n <= N2D_MAX_AVX512 / 2) &&
+ (m / nthrs_m >= 2 * M2D_MIN_AVX512) &&
+ (nthrs_m < 4)) {
+ nthrs_m *= 2;
+ nthrs_n /= 2;
+ }
+
+ thread_info->nthrs_m = nthrs_m;
+ thread_info->nthrs_n = nthrs_n;
+ thread_info->partition = PARTITION_2D;
+
+ // Reset the total number of threads that will be used.
+ *p_nthrs = nthrs_m * nthrs_n;
+
+ } else if (condition_1D_copya && mkldnn_thr_syncable()) {
+ // Use parallel copy A algorithm
+ thread_info->copy_type = COPY_A;
+ thread_info->partition = PARTITION_1D_COL;
+ } else {
+ if ((m > n) && (m / nthrs >= VECLEN || n < NCONS * nthrs)) {
+ thread_info->partition = PARTITION_1D_ROW;
+ } else {
+ thread_info->partition = PARTITION_1D_COL;
+ }
+ }
+}
+#undef N2D_MAX_AVX512
+#undef M2D_MIN_AVX512
+#undef VECLEN
+#undef NCONS
+
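+// Split a dimension of size n into nthrs contiguous bands; thread ithr gets
+// the range [*t_offset, *t_offset + *t_block), with the last thread taking
+// whatever remains.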
+static inline void partition_1d(const int ithr, const int nthrs, const dim_t n,
+ dim_t *t_offset, dim_t *t_block)
+{
+ dim_t band = n / nthrs;
+
+ dim_t tail = n - (nthrs - 1) * band;
+ if (tail > (band + 1))
+ band++;
+ tail = n - (nthrs - 1) * band;
+
+ if (ithr < (nthrs - 1))
+ *t_block = band;
+ else
+ *t_block = tail;
+
+ *t_offset = ithr * band;
+
+ if (*t_offset >= n) {
+ *t_block = 0;
+ *t_offset = 0;
+ } else if ((*t_offset + *t_block) > n) {
+ *t_block = n - *t_offset;
+ }
+}
+
+static inline void partition_2d(const int ithr, int *nthrs, const int ithr_i,
+ const int ithr_j, const int nthrs_m, const int nthrs_n, const dim_t m,
+ const dim_t n, dim_t *p_m_disp, dim_t *p_m_band, dim_t *p_n_disp,
+ dim_t *p_n_band)
+{
+ dim_t m_disp = 0, n_disp = 0;
+ dim_t m_band = 0, n_band = 0;
+
+ int mdiv = nthrs_m;
+ int ndiv = nthrs_n;
+
+ dim_t m_bandt = m / mdiv; /* size per thread */
+ dim_t n_bandt = n / ndiv; /* size per thread */
+ int firstmgroup = mdiv - 1;
+ int firstngroup = ndiv - 1;
+ dim_t firstmval = m_bandt;
+ dim_t firstnval = n_bandt;
+
+ int mthr_used = mdiv;
+ if (m - (mdiv - 1) * m_bandt > m_bandt + 1) {
+ if (m - (mdiv - 1) * m_bandt > mdiv)
+ ++m_bandt;
+
+ firstmval = m_bandt + 1;
+ mthr_used = (int) (m / firstmval);
+
+ if (mthr_used * firstmval < m)
+ ++mthr_used;
+
+ firstmgroup = mthr_used - 1;
+ }
+
+ int nthr_used = ndiv;
+ if (n - (ndiv - 1) * n_bandt > n_bandt + 1) {
+ firstnval = n_bandt + 1;
+ nthr_used = (int) (n / firstnval);
+
+ if (nthr_used * firstnval < n)
+ ++nthr_used;
+
+ firstngroup = nthr_used - 1;
+ }
+
+ *nthrs = mthr_used * nthr_used;
+
+ if (ithr < *nthrs) {
+ if (ithr_i < firstmgroup) {
+ m_band = firstmval;
+ m_disp = ithr_i * firstmval;
+ } else if (ithr_i <= mthr_used - 2) {
+ m_band = m_bandt;
+ m_disp = firstmgroup * firstmval + (ithr_i - firstmgroup) * m_bandt;
+ } else {
+ m_disp = firstmgroup * firstmval
+ + (mthr_used - 1 - firstmgroup) * m_bandt;
+ m_band = nstl::max(0LL, m - m_disp);
+ }
+
+ if (ithr_j < firstngroup) {
+ n_band = firstnval;
+ n_disp = ithr_j * firstnval;
+ } else if (ithr_j <= nthr_used - 2) {
+ n_band = n_bandt;
+ n_disp = firstngroup * firstnval + (ithr_j - firstngroup) * n_bandt;
+ } else {
+ n_disp = firstngroup * firstnval
+ + (nthr_used - 1 - firstngroup) * n_bandt;
+ n_band = nstl::max(0LL, n - n_disp);
+ }
+ m_disp = nstl::max(nstl::min(m_disp, m - 1), 0LL);
+ n_disp = nstl::max(nstl::min(n_disp, n - 1), 0LL);
+ }
+
+ if (ithr < *nthrs) {
+ *p_m_disp = m_disp;
+ *p_n_disp = n_disp;
+ *p_m_band = m_band;
+ *p_n_band = n_band;
+ } else {
+ *p_m_disp = 0;
+ *p_n_disp = 0;
+ *p_m_band = 0;
+ *p_n_band = 0;
+ }
+
+ return;
+}
+
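+// Compute the sub-matrices of A, B, C and the C-offset vector that a given
+// thread works on, according to the partitioning chosen in
+// set_thread_opts_avx512().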
+static inline void decompose_matrices(const int ithr, int *nthrs, dim_t *m,
+ dim_t *n, dim_t *k, const int8_t **a, const uint8_t **b, int32_t **c,
+ const int32_t **co, const blas_thread_t *thread_info, const blas_t *arg)
+{
+ dim_t strideAm = (arg->transa == 0)? 1 : arg->lda;
+ dim_t strideBn = (arg->transb != 0)? 1 : arg->ldb;
+ int offsetc = arg->offsetc;
+
+ switch (thread_info->partition) {
+ case PARTITION_1D_ROW:
+ {
+ dim_t offset = 0;
+ dim_t block = 0;
+ partition_1d(ithr, *nthrs, arg->m, &offset, &block);
+
+ *m = block;
+ *n = arg->n;
+ *k = arg->k;
+
+ // Set matrix A.
+ *a = arg->a + offset * strideAm;
+
+ // Set matrix B.
+ *b = arg->b;
+
+ // Set matrix C.
+ *c = arg->c + offset;
+
+ // Set offset vector for C matrix
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = offset;
+ }
+ *co = arg->co + co_stride;
+ break;
+ }
+
+ case PARTITION_1D_COL:
+ {
+ dim_t offset = 0;
+ dim_t block = 0;
+ partition_1d(ithr, *nthrs, arg->n, &offset, &block);
+
+ *m = arg->m;
+ *n = block;
+ *k = arg->k;
+
+ // Set matrix A.
+ *a = arg->a;
+
+ // Set matrix B.
+ *b = arg->b + offset * strideBn;
+
+ // Set matrix C.
+ *c = arg->c + offset * arg->ldc;
+
+ // Set offset vector for C matrix
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = offset;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = 0;
+ }
+ *co = arg->co + co_stride;
+ break;
+ }
+
+ case PARTITION_2D_COL_MAJOR:
+ {
+ int nthrs_m = thread_info->nthrs_m;
+ int nthrs_n = thread_info->nthrs_n;
+ int ithr_i = ithr % nthrs_m;
+ int ithr_j = ithr / nthrs_m;
+
+ dim_t m_disp = 0;
+ dim_t m_band = 0;
+ dim_t n_disp = 0;
+ dim_t n_band = 0;
+
+ partition_2d(ithr, nthrs, ithr_i, ithr_j, nthrs_m, nthrs_n,
+ arg->m, arg->n, &m_disp, &m_band, &n_disp, &n_band);
+
+ *m = m_band;
+ *n = n_band;
+ *k = arg->k;
+
+ // Set matrix A.
+ *a = arg->a + m_disp * strideAm;
+
+ // Set matrix B.
+ *b = arg->b + n_disp * strideBn;
+
+ // Set matrix C.
+ *c = arg->c + m_disp + n_disp * arg->ldc;
+
+ // Set offset vector for C matrix
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = n_disp;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = m_disp;
+ }
+ *co = arg->co + co_stride;
+ break;
+ }
+ }
+}
+
+#define MULTIPLIER 10
+static int parallel_a_copy(const int ithr, const int nthrs, const dim_t m,
+ const dim_t n, const dim_t k, const int8_t *a, const uint8_t *b,
+ int32_t *c, const int32_t *co, const blas_t *arg,
+ char **p_shared_mem)
+{
+ const dim_t lda = arg->lda;
+ const dim_t ldb = arg->ldb;
+ const dim_t strideAm = (arg->transa == 0)? 1 : lda;
+ const dim_t strideAn = (arg->transa != 0)? 1 : lda;
+ const dim_t strideBm = (arg->transb == 0)? 1 : ldb;
+
+ // Padding along M dimension.
+ dim_t m_padd = utils::rnd_up(nstl::min(nstl::max(m, arg->um), arg->bm),
+ arg->um);
+
+ // Padding along K dimension.
+ dim_t k_padd = 0;
+ if (k <= arg->bk_traditional) {
+ k_padd = utils::rnd_up(k, arg->uk);
+ k_padd = nstl::max(128LL, k_padd);
+ } else if (k < 2 * arg->bk) {
+ k_padd = utils::rnd_up(k / 2, arg->uk);
+ } else {
+ k_padd = arg->bk;
+ }
+
+ m_padd *= nthrs > MULTIPLIER ? MULTIPLIER : nthrs;
+ if (m_padd > m) {
+ m_padd = utils::rnd_up(m, arg->um);
+ }
+
+ size_t a_buf_nelems = m_padd * k_padd;
+
+ // Allocate shared memory for A and its row sum buffers in master thread.
+    if (ithr == 0) { // Master thread
+ size_t a_row_sum_nelems = m_padd;
+
+ size_t mem_size = (a_buf_nelems * sizeof(*a) + PAGE_4K)
+ + a_row_sum_nelems * sizeof(*c) + PAGE_4K;
+
+ *p_shared_mem = (char *) malloc(mem_size, 128);
+
+ }
+ mkldnn_thr_barrier();
+
+ char *mem = *p_shared_mem;
+ int8_t *bufferA = (int8_t *) align(mem, PAGE_4K);
+ int32_t *a_row_sum = (int32_t *) align(bufferA + a_buf_nelems, PAGE_4K);
+
+ if (!mem) {
+ return -1;
+ }
+
+ int result = 0; // Return status
+
+ dim_t sizeK = 0;
+ for (dim_t Bk = 0; Bk < k; Bk += sizeK) {
+ sizeK = k - Bk;
+ if (sizeK > k_padd)
+ sizeK = k_padd;
+
+        // Scale C blocks by beta only for the first term of the partial sum.
+ float beta = 1.0f;
+ if (Bk == 0)
+ beta = *(arg->beta);
+
+ // Apply C offset for the last k-block of the partial sum.
+ int offsetc = NO_OFFSET;
+ if (Bk + sizeK == k)
+ offsetc = arg->offsetc;
+
+ dim_t sizeM = 0;
+ for (dim_t Bm = 0; Bm < m; Bm += sizeM) {
+ sizeM = m - Bm;
+ if (sizeM > m_padd)
+ sizeM = m_padd;
+
+ if (ithr < nthrs) {
+ dim_t band = (sizeM + nthrs - 1) / nthrs;
+ band = utils::rnd_up(band, arg->um);
+
+ dim_t offset = band * ithr;
+
+                // If the offset is too large, don't use this thread for copying.
+ if (offset >= sizeM) {
+ offset = 0;
+ band = 0;
+ }
+
+ // Handle the tail of the copy.
+ if (offset + band > sizeM) {
+ band = sizeM - offset;
+ }
+
+ if (band > 0) {
+ const int8_t *a_block = a + (Bm + offset) * strideAm
+ + Bk * strideAn;
+ arg->copyA(&sizeK, &band, a_block, &lda, NULL,
+ bufferA + offset * sizeK, NULL, NULL,
+ a_row_sum + offset);
+ }
+ }
+ mkldnn_thr_barrier(); // Wait for finishing parallel copy.
+
+ const uint8_t *b_block = b + Bk * strideBm;
+ int32_t *c_block = c + Bm;
+ dim_t co_stride = 0;
+ if (offsetc == FIX_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == ROW_OFFSET) {
+ co_stride = 0;
+ } else if (offsetc == COL_OFFSET) {
+ co_stride = Bm;
+ }
+
+ result = kernel_driver_parallel_acopiedbcopy(sizeM, n, sizeK,
+ bufferA, b_block, beta, c_block, offsetc, co + co_stride,
+ a_row_sum, arg);
+
+ mkldnn_thr_barrier(); // Wait for kernel computations to finish.
+ }
+ }
+
+ // Free memory allocated in master thread
+ if (ithr == 0) {
+ free(mem);
+ }
+
+ return result;
+}
+#undef MULTIPLIER
+
+static inline void get_omp_thread_count(dim_t m, dim_t n, dim_t k,
+ double fp_per_cycle, int *nthrs)
+{
+ double omp_overhead_small_core = 3.0e+3;
+ double omp_intercept_big_core = 4.0e+3;
+ double omp_slope_big_core = 5.0e+2;
+
+ double gemm_cycles = 8.0 * m * n * k / fp_per_cycle;
+
+ int i = *nthrs;
+
+ // Use a different model for omp overheads if nthrs is <= 4
+ if (*nthrs <= 4 && omp_overhead_small_core > 0) {
+ double omp_cycles = omp_overhead_small_core;
+ if (gemm_cycles < omp_cycles) {
+ *nthrs = 1;
+ return;
+ } else {
+ while (i > 1) {
+ if (omp_cycles * i < gemm_cycles * (i - 1)) break;
+ --i;
+ }
+ }
+ } else {
+ if (gemm_cycles < (omp_intercept_big_core + 2 * omp_slope_big_core)) {
+ *nthrs = 1;
+ return;
+ }
+
+        // adaptive decrement to march faster
+ while (i > 1) {
+ double omp_cycles = omp_intercept_big_core + i * omp_slope_big_core;
+ if (omp_cycles * i < gemm_cycles * (i - 1))
+ break;
+
+ if (i < 10)
+ i -= 2;
+ else if (i < 30)
+ i -= 4;
+ else
+ i -= 8;
+ }
+ }
+
+ if (i < 1)
+ i = 1;
+
+ *nthrs = i;
+}
+
+#define CACHE_LINE_SIZE 64
+static int gemm_threading_driver(blas_t *arg)
+{
+ if ((arg->m <= 0) || (arg->n <= 0))
+ return mkldnn_success;
+
+ if (gemm_s8u8s32_jump_to_gemv_s8u8s32(arg)) {
+ return mkldnn_success;
+ }
+
+ int nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+ get_omp_thread_count(arg->m, arg->n, arg->k, 64.0, &nthr);
+
+ if (nthr == 1) {
+ return gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a, arg->b,
+ arg->c, arg->co, arg);
+ }
+
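+    // One status slot per thread, spaced CACHE_LINE_SIZE ints apart so that
+    // each thread writes to its own cache line.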
+ int *results = (int *) malloc(sizeof(*results) * nthr * CACHE_LINE_SIZE,
+ PAGE_4K);
+
+ if (!results) {
+ return -1;
+ }
+
+ for (int i = 0; i < nthr; i++) {
+ results[i * CACHE_LINE_SIZE] = 0; // Initialize to success
+ }
+
+ char *shared_mem = NULL;
+
+ parallel(nthr, [&](const int ithr, const int nthr) {
+ int nthrs = nthr;
+ if (nthrs == 1) {
+ results[0] = gemm_kernel_driver(arg->m, arg->n, arg->k, arg->a,
+ arg->b, arg->c, arg->co, arg);
+ } else {
+ blas_thread_t thread_info;
+ set_thread_opts_avx512(&nthrs, &thread_info, arg);
+
+ const int8_t *a = NULL;
+ const uint8_t *b = NULL;
+ int32_t *c = NULL;
+ const int32_t *co = NULL;
+ dim_t m = -1;
+ dim_t n = -1;
+ dim_t k = -1;
+ decompose_matrices(ithr, &nthrs, &m, &n, &k, &a, &b, &c, &co,
+ &thread_info, arg);
+
+ if (ithr < nthrs) {
+ switch (thread_info.copy_type) {
+ case COPY_A:
+ results[ithr * CACHE_LINE_SIZE] =
+ parallel_a_copy(ithr, nthrs, m, n, k, a, b, c, co, arg,
+ &shared_mem);
+ break;
+
+ default:
+ case COPY_NONE:
+ results[ithr * CACHE_LINE_SIZE] =
+ gemm_kernel_driver(m, n, k, a, b, c, co, arg);
+ break;
+ }
+ }
+ }
+ });
+
+ int result = 0; // Initialize to success
+ for (int i = 0; i < nthr; i++) {
+        if (results[i * CACHE_LINE_SIZE] != 0) {
+ result = results[i * CACHE_LINE_SIZE];
+ break;
+ }
+ }
+
+ free(results);
+
+ return result;
+}
+#undef CACHE_LINE_SIZE
+
+static jit_avx512_core_u8_copy_an_kern *copy_an;
+static jit_avx512_core_u8_copy_at_kern *copy_at;
+static jit_avx512_core_u8_copy_bn_kern *copy_bn;
+static jit_avx512_core_u8_copy_bt_kern *copy_bt;
+static jit_avx512_core_u8_copy_sum_an_kern *copy_sum_an;
+static jit_avx512_core_u8_copy_sum_at_kern *copy_sum_at;
+static jit_avx512_core_u8_copy_sum_bn_kern *copy_sum_bn;
+static jit_avx512_core_u8_copy_sum_bt_kern *copy_sum_bt;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_r;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_c;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_b;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_r;
+static jit_avx512_core_gemm_s8u8s32_kern *kernel_b0_c;
+static jit_avx512_core_gemv_s8u8s32_kern *gemv_s8u8s32_kernel;
+static jit_avx512_core_gemv_s8u8s32_kern *gemv_u8s8s32_kernel;
+
+static void jit_init(blas_t *arg)
+{
+ static int (*copyAn)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copyAt)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copyBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copyBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copySumAn)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copySumAt)(const dim_t *m, const dim_t *n, const int8_t *a,
+ const dim_t *lda, const int8_t *alpha, int8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copySumBn)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*copySumBt)(const dim_t *m, const dim_t *n, const uint8_t *a,
+ const dim_t *lda, const uint8_t *alpha, uint8_t *b,
+ const dim_t *dummy1, const dim_t *dummy2, int32_t *row_col_sum);
+
+ static int (*kern)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_r)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_c)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b0)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b0_b)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b0_r)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static int (*kern_b0_c)(const dim_t *m, const dim_t *n, const dim_t *k,
+ const float *alpha, const int8_t *a, const uint8_t *b, int32_t *c,
+ const dim_t ldc, const int32_t *col_offset,
+ const int32_t *row_offset);
+
+ static void (*gemv_s8u8s32_kern)(const dim_t, const dim_t, const float,
+ const int8_t*, const dim_t, const uint8_t*,
+ const float, int32_t*);
+
+ static void (*gemv_u8s8s32_kern)(const dim_t, const dim_t, const float,
+ const uint8_t*, const dim_t, const int8_t*,
+ const float, int32_t*);
+
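+    // Select blocking and unrolling parameters; VNNI-capable parts use a
+    // larger K block (AVX512_BK_VNNI) than the emulated dot-product path.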
+ if (mayiuse(avx512_core_vnni)) {
+ arg->um = AVX512_UNROLL_M;
+ arg->un = AVX512_UNROLL_N;
+ arg->uk = AVX512_UNROLL_K;
+ arg->bm = AVX512_BM;
+ arg->bn = AVX512_BN;
+ arg->bk = AVX512_BK_VNNI;
+
+ arg->bk_traditional = AVX512_BK_TRADITIONAL;
+ arg->bn_small_k = AVX512_BN_SMALL_K;
+ arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
+ } else {
+ arg->um = AVX512_UNROLL_M;
+ arg->un = AVX512_UNROLL_N;
+ arg->uk = AVX512_UNROLL_K;
+ arg->bm = AVX512_BM;
+ arg->bn = AVX512_BN;
+ arg->bk = AVX512_BK;
+
+ arg->bk_traditional = AVX512_BK_TRADITIONAL;
+ arg->bn_small_k = AVX512_BN_SMALL_K;
+ arg->blocking_small_k = AVX512_BLOCKING_SMALL_K;
+ }
+
+ static std::once_flag initialized;
+ std::call_once(initialized, []{
+
+ copy_an = new jit_avx512_core_u8_copy_an_kern();
+ copy_at = new jit_avx512_core_u8_copy_at_kern();
+ copy_bn = new jit_avx512_core_u8_copy_bn_kern();
+ copy_bt = new jit_avx512_core_u8_copy_bt_kern();
+
+ copy_sum_an = new jit_avx512_core_u8_copy_sum_an_kern();
+ copy_sum_at = new jit_avx512_core_u8_copy_sum_at_kern();
+ copy_sum_bn = new jit_avx512_core_u8_copy_sum_bn_kern();
+ copy_sum_bt = new jit_avx512_core_u8_copy_sum_bt_kern();
+
+ kernel = new jit_avx512_core_gemm_s8u8s32_kern(false, false, false);
+ kernel_b = new jit_avx512_core_gemm_s8u8s32_kern(false, true, true);
+ kernel_r = new jit_avx512_core_gemm_s8u8s32_kern(false, false, true);
+ kernel_c = new jit_avx512_core_gemm_s8u8s32_kern(false, true, false);
+ kernel_b0 = new jit_avx512_core_gemm_s8u8s32_kern(true, false, false);
+ kernel_b0_b = new jit_avx512_core_gemm_s8u8s32_kern(true, true, true);
+ kernel_b0_r = new jit_avx512_core_gemm_s8u8s32_kern(true, false, true);
+ kernel_b0_c = new jit_avx512_core_gemm_s8u8s32_kern(true, true, false);
+
+ gemv_s8u8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
+ gemv_u8s8s32_kernel = new jit_avx512_core_gemv_s8u8s32_kern();
+
+
+ copyAn = copy_an->getCode<int (*)(const dim_t *, const dim_t *,
+ const int8_t *, const dim_t *, const int8_t *, int8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copyAt = copy_at->getCode<int (*)(const dim_t *, const dim_t *,
+ const int8_t *, const dim_t *, const int8_t *, int8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copyBn = copy_bn->getCode<int (*)(const dim_t *, const dim_t *,
+ const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copyBt = copy_bt->getCode<int (*)(const dim_t *, const dim_t *,
+ const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copySumAn = copy_sum_an->getCode<int (*)(const dim_t *, const dim_t *,
+ const int8_t *, const dim_t *, const int8_t *, int8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copySumAt = copy_sum_at->getCode<int (*)(const dim_t *, const dim_t *,
+ const int8_t *, const dim_t *, const int8_t *, int8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copySumBn = copy_sum_bn->getCode<int (*)(const dim_t *, const dim_t *,
+ const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ copySumBt = copy_sum_bt->getCode<int (*)(const dim_t *, const dim_t *,
+ const uint8_t *, const dim_t *, const uint8_t *, uint8_t *,
+ const dim_t *, const dim_t *, int32_t *)>();
+
+ kern = kernel->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b = kernel_b->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_r = kernel_r->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_c = kernel_c->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b0 = kernel_b0->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b0_b = kernel_b0_b->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b0_r = kernel_b0_r->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+ kern_b0_c = kernel_b0_c->getCode<int (*)(const dim_t *, const dim_t *,
+ const dim_t *, const float *, const int8_t *, const uint8_t *,
+ int32_t *, const dim_t, const int32_t *, const int32_t *)>();
+
+        gemv_s8u8s32_kern = gemv_s8u8s32_kernel->generate<
+                jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(
+                mayiuse(avx512_core_vnni));
+        gemv_u8s8s32_kern = gemv_u8s8s32_kernel->generate<
+                jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>(
+                mayiuse(avx512_core_vnni));
+ });
+
+ if (arg->bo == 0) { // No need to compute A row sum if bo is zero
+ if (arg->transa == 0) {
+ arg->copyA = copyAn;
+ } else {
+ arg->copyA = copyAt;
+ }
+ } else {
+ if (arg->transa == 0) {
+ arg->copyA = copySumAn;
+ } else {
+ arg->copyA = copySumAt;
+ }
+ }
+
+ if (arg->ao == 0) { // No need to compute B column sum if ao is zero
+ if (arg->transb == 0) {
+ arg->copyB = copyBn;
+ } else {
+ arg->copyB = copyBt;
+ }
+ } else {
+ if (arg->transb == 0) {
+ arg->copyB = copySumBn;
+ } else {
+ arg->copyB = copySumBt;
+ }
+ }
+
+ arg->kernel = kern;
+ arg->kernel_b = kern_b;
+ arg->kernel_r = kern_r;
+ arg->kernel_c = kern_c;
+ arg->kernel_b0 = kern_b0;
+ arg->kernel_b0_b = kern_b0_b;
+ arg->kernel_b0_r = kern_b0_r;
+ arg->kernel_b0_c = kern_b0_c;
+    arg->gemv_s8u8s32_kernel = gemv_s8u8s32_kern;
+    arg->gemv_u8s8s32_kernel = gemv_u8s8s32_kern;
+}
+
+mkldnn_status_t jit_avx512_core_gemm_s8u8s32(
+ const char *transA, const char *transB, const char *offsetC,
+ const int *m, const int *n, const int *k,
+ const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
+ const uint8_t *b, const int *ldb, const int8_t *ob,
+ const float *beta, int32_t *c, const int *ldc, const int32_t *oc)
+{
+ char transa = *transA;
+ char transb = *transB;
+ char offsetc = *offsetC;
+
+ blas_t args;
+
+ // Initialize blas structure
+ args.m = *m;
+ args.n = *n;
+ args.k = *k;
+ args.alpha = alpha;
+ args.a = a;
+ args.lda = *lda;
+ args.b = b;
+ args.ldb = *ldb;
+ args.beta = beta;
+ args.c = c;
+ args.ldc = *ldc;
+ args.transa = (transa == 'N' || transa == 'n') ? 0 : 1;
+ args.transb = (transb == 'N' || transb == 'n') ? 0 : 1;
+ args.um = 0;
+ args.un = 0;
+ args.bm = 0;
+ args.bn = 0;
+ args.bk = 0;
+ args.copyA = NULL;
+ args.copyB = NULL;
+ args.kernel = NULL;
+ args.kernel_b0 = NULL;
+ args.ao = *oa;
+ args.bo = *ob;
+ args.co = oc;
+
+ if (offsetc == 'F' || offsetc == 'f') {
+ args.offsetc = FIX_OFFSET;
+ } else if (offsetc == 'R' || offsetc == 'r') {
+ args.offsetc = ROW_OFFSET;
+ } else { // offsetc == 'C' || offsetc == 'c'
+ args.offsetc = COL_OFFSET;
+ }
+
+ jit_init(&args);
+ int result = gemm_threading_driver(&args);
+
+ return (result < 0) ? mkldnn_out_of_memory : mkldnn_success;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp
new file mode 100644
index 0000000000..b2e2902a12
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32.hpp
@@ -0,0 +1,38 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX512_CORE_GEMM_S8U8S32_HPP
+#define JIT_AVX512_CORE_GEMM_S8U8S32_HPP
+
+#include <cstdint>
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t jit_avx512_core_gemm_s8u8s32(
+ const char *transA, const char *transB, const char *offsetC,
+ const int *m, const int *n, const int *k,
+ const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
+ const uint8_t *b, const int *ldb, const int8_t *ob,
+ const float *beta, int32_t *c, const int *ldc, const int32_t *oc);
+
+}
+}
+}
+
+#endif // JIT_AVX512_CORE_GEMM_S8U8S32_HPP
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp
new file mode 100644
index 0000000000..57554a1852
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.cpp
@@ -0,0 +1,539 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_avx512_core_gemm_s8u8s32_kern.hpp"
+
+
+#ifdef _WIN32
+static const bool is_windows = 1;
+#else
+static const bool is_windows = 0;
+#endif
+
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace Xbyak;
+
+
+
+
+// Convert between vector register lengths.
+static inline Xmm make_xmm(const Xmm &v) { return Xmm(v.getIdx()); }
+static inline Ymm make_ymm(const Xmm &v) { return Ymm(v.getIdx()); }
+
+// Load from or store to C.
+void jit_avx512_core_gemm_s8u8s32_kern::c_load(const Xbyak::Xmm &dst,
+ const Xbyak::Address &src, int nelems)
+{
+ switch (nelems) {
+ default: vmovups(dst, src); break;
+ case 8: vmovups(make_ymm(dst), src); break;
+ case 4: vmovups(make_xmm(dst), src); break;
+ case 2: vmovlps(make_xmm(dst), src); break;
+ case 1: vmovss(make_xmm(dst), src); break;
+ }
+}
+void jit_avx512_core_gemm_s8u8s32_kern::c_store(const Xbyak::Address &dst,
+ const Xbyak::Xmm &src, int nelems)
+{
+ switch (nelems) {
+ default: vmovups(dst, src); break;
+ case 8: vmovups(dst, make_ymm(src)); break;
+ case 4: vmovups(dst, make_xmm(src)); break;
+ case 2: vmovsd(dst, make_xmm(src)); break;
+ case 1: vmovss(dst, make_xmm(src)); break;
+ }
+}
+
+// Perform length-4 dot product accumulations of unsigned and signed bytes
+// in parallel.
+// Use vpdpbusd if VNNI available, otherwise emulate.
+void jit_avx512_core_gemm_s8u8s32_kern::dot_product(const Xmm &dst,
+ const Xmm &src1, const Xmm &src2)
+{
+ if (vnni)
+ vpdpbusd(dst, src1, src2);
+ else {
+ vpmaddubsw(dp_scratch, src1, src2);
+ vpmaddwd(dp_scratch, ones, dp_scratch);
+ vpaddd(dst, dst, dp_scratch);
+ }
+}
+
+// Inner kernel.
+void jit_avx512_core_gemm_s8u8s32_kern::kernel_loop(int unroll_m, int unroll_n,
+ bool cfetch)
+{
+ int um_vecs = (unroll_m + 15) >> 4;
+ Label label_kernel_loop;
+
+ L_aligned(label_kernel_loop); {
+ for (int h = 0; h < 4; h++) {
+ for (int j = 0; j < unroll_n; j++) {
+ const Zmm b = b_regs[j & 1];
+
+ vpbroadcastd(b, ptr[BO + isize *
+ (2 * j + 2 * h * unroll_n - offset_b)]);
+ dot_product(c_regs[0][j], b, a_regs[0]);
+
+ if (j == 1 && !(h & 1))
+ prefetch_b(ptr[BO + isize * (prefetch_size_b
+ + 2 * h * unroll_n - offset_b)]);
+ else if (j % 3 == 0)
+ prefetch_a(ptr[AO + isize * (prefetch_size_a
+ + 32 * (j / 3) + 2 * h * unroll_m - offset_a)]);
+
+ for (int i = 1; i < um_vecs; i++)
+ dot_product(c_regs[i][j], b, a_regs[i]);
+
+ if (cfetch && (j == std::min(1, unroll_n - 1))) {
+ if (h == 3)
+ lea(CO2, ptr[CO2 + LDC]);
+ else if (h < um_vecs)
+ prefetch_c(ptr[CO2 + (16 * h * size)]);
+ }
+
+ if (h == 3 && j == std::min(3, unroll_n - 1))
+ lea(AA, ptr[AA + (32 * isize)]);
+ }
+
+ for (int i = 0; i < um_vecs; i++)
+ vmovups(a_regs[i], ptr[AO + isize *
+ (32 * i + 2 * (h + 1) * unroll_m - offset_a)]);
+
+ if (h == 2)
+ prefetch_x(ptr[AA - (offset_a * isize)]);
+ }
+
+ add(AO, 8 * isize * unroll_m);
+ add(BO, 8 * isize * unroll_n);
+ sub(LoopCount, 1);
+ jg(label_kernel_loop, T_NEAR);
+ }
+}
+
+// k remainder loop for kernel.
+void jit_avx512_core_gemm_s8u8s32_kern::remainder_kernel(int unroll_m,
+ int unroll_n, int unroll_k, int bwidth)
+{
+ if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
+ || (unroll_m < 0) || (unroll_n < 0))
+ return;
+
+ int um_vecs = (unroll_m + 15) >> 4;
+
+ for (int h = 0; h < unroll_k; h++) {
+ for (int j = 0; j < unroll_n; j++) {
+ Zmm b = b_regs[j & 1];
+ auto b_src = ptr[BO + (-isize * offset_b
+ + bwidth * (j + h * unroll_n))];
+
+ switch (bwidth) {
+ case 4:
+ vpbroadcastd(b, b_src);
+ break;
+ case 2:
+ vpbroadcastw(b, b_src);
+ break;
+ case 1:
+ vpbroadcastb(b, b_src);
+ break;
+ }
+ for (int i = 0; i < um_vecs; i++)
+ dot_product(c_regs[i][j], b, a_regs[i]);
+ }
+
+ if (unroll_k > 1) {
+ for (int i = 0; i < um_vecs; i++)
+ vmovups(a_regs[i], ptr[AO + isize * (32 * i
+ + (h + 1) * 2 * unroll_m - offset_a)]);
+ }
+ }
+
+ add(AO, unroll_k * unroll_m * bwidth);
+ add(BO, unroll_k * unroll_n * bwidth);
+}
+
+// Inner loop.
+void jit_avx512_core_gemm_s8u8s32_kern::innerloop(int unroll_m, int unroll_n)
+{
+ if ((unroll_m > IGEMM_UNROLL_M) || (unroll_n > IGEMM_UNROLL_N)
+ || (unroll_m < 0) || (unroll_n < 0))
+ return;
+
+ int um_vecs = (unroll_m + 15) >> 4;
+ int stage1 = unroll_n, stage2 = unroll_n;
+
+ Label label_kernel_loop_1, label_k_main_loop_2, label_kernel_loop_2;
+ Label label_k_main_loop_3, label_kernel_loop_3;
+ Label label_k_remainder_loop_begin, label_k_rem_4, label_k_rem_2;
+ Label label_k_rem_1, label_update_begin;
+
+ mov(AO, A);
+ for (int i = 0; i < um_vecs; i++)
+ vmovups(a_regs[i], ptr[AO + isize * (32 * i - offset_a)]);
+
+ mov(LoopCount, K);
+ sar(LoopCount, 4);
+ jle(label_k_remainder_loop_begin, T_NEAR);
+
+ // Main k loops, broken into three parts to time C prefetching.
+ sub(LoopCount, stage1 + stage2);
+ jle(label_k_main_loop_2, T_NEAR);
+
+ kernel_loop(unroll_m, unroll_n, false);
+
+ L_aligned(label_k_main_loop_2);
+ lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
+ add(LoopCount, stage1);
+ jle(label_k_main_loop_3, T_NEAR);
+
+ kernel_loop(unroll_m, unroll_n, true);
+
+ L_aligned(label_k_main_loop_3);
+ lea(CO2, ptr[CO1 + size * (std::min(unroll_m, 16) - 1)]);
+ add(LoopCount, stage2);
+ jle(label_k_remainder_loop_begin, T_NEAR);
+
+ kernel_loop(unroll_m, unroll_n, true);
+
+ // k remainder handling
+ L_aligned(label_k_remainder_loop_begin);
+ mov(LoopCount, K);
+ test(LoopCount, 8);
+ je(label_k_rem_4, T_NEAR);
+
+ remainder_kernel(unroll_m, unroll_n, 2, 4);
+
+ L_aligned(label_k_rem_4);
+ mov(LoopCount, K);
+ test(LoopCount, 4);
+ je(label_k_rem_2, T_NEAR);
+
+ remainder_kernel(unroll_m, unroll_n, 1, 4);
+
+ L_aligned(label_k_rem_2);
+ mov(LoopCount, K);
+ test(LoopCount, 2);
+ je(label_k_rem_1, T_NEAR);
+
+ Zmm zero = zmm6;
+ Zmm tmp = zmm5;
+
+ vpxorq(zero, zero, zero);
+ for (int i = 0; i < um_vecs; i++) {
+ Zmm a = a_regs[i];
+ vbroadcasti64x4(a, ptr[AO + isize * (16 * i - offset_a)]);
+ vpunpcklwd(tmp, a, zero);
+ vpunpckhwd(a, a, zero);
+ vshufi32x4(a, tmp, a, 0x44);
+ vshufi32x4(a, a, a, 0xD8);
+ }
+
+ remainder_kernel(unroll_m, unroll_n, 1, 2);
+
+ L_aligned(label_k_rem_1);
+ mov(LoopCount, K);
+ test(LoopCount, 1);
+ je(label_update_begin, T_NEAR);
+
+ vpxorq(zero, zero, zero);
+ for (int i = 0; i < um_vecs; i++) {
+ Zmm a = a_regs[i];
+ vbroadcasti32x4(a, ptr[AO + isize * (8 * i - offset_a)]);
+ vpunpcklbw(tmp, a, zero);
+ vpunpckhbw(a, a, zero);
+ vinsertf128(make_ymm(a), make_ymm(tmp), make_xmm(a), 1);
+ vpunpcklwd(tmp, a, zero);
+ vpunpckhwd(a, a, zero);
+ vshufi32x4(a, tmp, a, 0x44);
+ vshufi32x4(a, a, a, 0xD8);
+ }
+
+ remainder_kernel(unroll_m, unroll_n, 1, 1);
+
+ // Add offsets and update C.
+ L_aligned(label_update_begin);
+
+ if (enable_offset_r) {
+ // Add row offsets.
+ mov(rax, coffset_ry);
+ for (int j = 0; j < unroll_n; j++) {
+ Zmm row_offset = zmm0;
+
+ vbroadcastss(row_offset, ptr[rax + size * j]);
+
+ for (int i = 0; i < um_vecs; i++)
+ vpaddd(c_regs[i][j], c_regs[i][j], row_offset);
+ }
+ add(coffset_ry, size * unroll_n);
+ }
+
+ if (enable_offset_c) {
+ // Add column offsets.
+ mov(rax, coffset_cy);
+ for (int i = 0; i < um_vecs; i++) {
+ Zmm col_offset = zmm0;
+
+ c_load(col_offset, ptr[rax + size * 16 * i], unroll_m);
+
+ for (int j = 0; j < unroll_n; j++)
+ vpaddd(c_regs[i][j], c_regs[i][j], col_offset);
+ }
+ }
+
+ Reg64 LDC3 = rax;
+ lea(LDC3, ptr[LDC + LDC * 2]);
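+    // LDC3 = 3 * LDC (in bytes), used to address the fourth column within
+    // each group of four C columns updated below.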
+
+ // C updates.
+ int c_off_j = 0;
+ for (int j = 0; j < unroll_n; j++) {
+ if (j > 0 && (j & 3) == 0) {
+ lea(CO1, ptr[CO1 + LDC * 4]);
+ c_off_j += 4;
+ }
+
+ int jj = j - c_off_j;
+
+ for (int i = 0; i < um_vecs; i++) {
+ Zmm c = c_regs[i][j];
+ Zmm c_old = zmm0;
+ decltype(LDC * jj) ldc_mult = (jj == 3) ? LDC3 : LDC * jj;
+
+ auto c_mem = ptr[CO1 + ldc_mult + size * 16 * i];
+
+ if (beta_zero)
+ c_store(c_mem, c, unroll_m);
+ else {
+ c_load(c_old, c_mem, unroll_m);
+ vpaddd(c_old, c, c_old);
+ c_store(c_mem, c_old, unroll_m);
+ }
+
+ vpxorq(c, c, c);
+ }
+ }
+
+ lea(CO1, ptr[CO1 + LDC * (unroll_n - c_off_j)]);
+}
+
+// Outer loop.
+void jit_avx512_core_gemm_s8u8s32_kern::outerloop(int unroll_x, int unroll_y,
+ Label *&cur_outerloop_label)
+{
+ Label label_m_loop, label_n_loop, label_n_remainder_loops[6];
+
+ L(*cur_outerloop_label);
+ cur_outerloop_label++;
+ if (unroll_x >= IGEMM_UNROLL_M) {
+ mov(J, M);
+ cmp(J, unroll_x);
+ jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label.
+ } else {
+ test(J, unroll_x);
+ jle(*cur_outerloop_label, T_NEAR);
+ }
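+    // For the full-width case J (the M counter) is loaded from M and
+    // compared against unroll_x; the remainder cases only test the matching
+    // bit of the M remainder left in J by the main loop.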
+
+ L_aligned(label_m_loop); {
+ mov(CO1, C);
+ add(C, unroll_x * size);
+
+ mov(BO, B);
+
+ mov(AA, K);
+ imul(AA, AA, unroll_x * isize);
+ lea(AA, ptr[A + AA + isize * prefetch_size_a]);
+
+ if (enable_offset_c) {
+ mov(rax, coffset_cx);
+ mov(coffset_cy, rax);
+ add(rax, unroll_x * size);
+ mov(coffset_cx, rax);
+ }
+
+ if (enable_offset_r) {
+ mov(rax, coffset_rx);
+ mov(coffset_ry, rax);
+ }
+
+ mov(I, N);
+ cmp(I, unroll_y);
+ jl(label_n_remainder_loops[0], T_NEAR);
+
+ L_aligned(label_n_loop); {
+ innerloop(unroll_x, unroll_y);
+ sub(I, unroll_y);
+ cmp(I, unroll_y);
+ jge(label_n_loop, T_NEAR);
+ }
+
+ align(16);
+
+ int label_idx = 0;
+ for (int uy = 16; uy > 0; uy >>= 1) {
+ L(label_n_remainder_loops[label_idx++]);
+ if (unroll_y > uy) {
+ test(I, uy);
+ jle(label_n_remainder_loops[label_idx], T_NEAR);
+
+ innerloop(unroll_x, uy);
+ align(16);
+ }
+ }
+ L(label_n_remainder_loops[label_idx]);
+
+ mov(A, AO);
+ if (unroll_x >= IGEMM_UNROLL_M) {
+ sub(J, unroll_x);
+ cmp(J, unroll_x);
+ jge(label_m_loop);
+ }
+ }
+
+ align(16);
+}
+
+void jit_avx512_core_gemm_s8u8s32_kern::generate()
+{
+ // Prologue
+ preamble();
+ sub(rsp, stack_alloc_size);
+
+ if (is_windows) {
+ mov(A, arg_a);
+ mov(B, arg_b);
+ }
+
+ mov(C, arg_c);
+ mov(LDC, arg_ldc);
+
+ sub(A, -offset_a * isize);
+ sub(B, -offset_b * isize);
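+    // Bias the A and B pointers by offset_{a,b} elements; the loads in the
+    // compute loops subtract the same offsets in their displacements
+    // (presumably to keep the displacements centred around the pointers).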
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(K, qword[K]);
+
+ lea(LDC, ptr[LDC * size]);
+
+ if (enable_offset_c) {
+ mov(rax, arg_coffset_c);
+ mov(coffset_cx, rax);
+ }
+ if (enable_offset_r) {
+ mov(rax, arg_coffset_r);
+ mov(coffset_rx, rax);
+ }
+
+ for (int i = 0; i < (max_unroll_m >> 4); i++) {
+ for (int j = 0; j < max_unroll_n; j++) {
+ auto &c = c_regs[i][j];
+ vpxorq(c, c, c);
+ }
+ }
+
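+    // Without VNNI, keep a vector of 16-bit ones; it is presumably consumed
+    // by dot_product() to sum adjacent 16-bit products via vpmaddwd.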
+ if (!vnni) {
+ mov(rax, 1);
+ movq(make_xmm(ones), rax);
+ vpbroadcastw(ones, make_xmm(ones));
+ }
+
+ Label outerloop_labels[8];
+ Label *cur_outerloop_label = &outerloop_labels[0];
+
+ // Main m loop.
+ outerloop(IGEMM_UNROLL_M, IGEMM_UNROLL_N, cur_outerloop_label);
+
+ // m remainder loops.
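+    // (unroll_m = 32, 16, 8, 4, 2, 1; each handles one bit of the M
+    // remainder left over by the main 48-wide loop.)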
+ for (int um = 32; um > 0; um >>= 1)
+ if (IGEMM_UNROLL_M > um)
+ outerloop(um, IGEMM_UNROLL_N, cur_outerloop_label);
+
+ L(*cur_outerloop_label);
+
+ // Epilogue.
+ add(rsp, stack_alloc_size);
+ postamble();
+}
+
+
+jit_avx512_core_gemm_s8u8s32_kern::jit_avx512_core_gemm_s8u8s32_kern(bool
+ beta_zero_, bool enable_offset_c_, bool enable_offset_r_) :
+ jit_generator(nullptr, 100000), arg_a(0), arg_b(0), arg_c(0), arg_ldc(0),
+ arg_coffset_c(0), arg_coffset_r(0), coffset_cx(0), coffset_cy(0),
+ coffset_rx(0), coffset_ry(0)
+{
+ beta_zero = beta_zero_;
+ enable_offset_c = enable_offset_c_;
+ enable_offset_r = enable_offset_r_;
+ vnni = mayiuse(avx512_core_vnni);
+
+ // Assign integer registers
+ M = is_windows ? rcx : rdi;
+ N = is_windows ? rdx : rsi;
+ K = is_windows ? r8 : rdx;
+ A = is_windows ? rsi : r8;
+ B = r9;
+ C = r10;
+ LDC = r11;
+ I = r12;
+ J = r13;
+ LoopCount = rax;
+ AO = r14;
+ BO = r15;
+ CO1 = rbx;
+ CO2 = rbp;
+ AA = is_windows ? rdi : rcx;
+
+ // Assign vector registers
+ dp_scratch = zmm6;
+ ones = zmm7;
+ for (int i = 0; i < (max_unroll_m >> 4); i++)
+ a_regs[i] = Zmm(i);
+ b_regs[0] = zmm4;
+ b_regs[1] = zmm5;
+
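+    // Accumulators: zmm8..zmm31 hold the 3 x 8 C sub-tiles of the maximal
+    // 48 x 8 unrolling; the low registers hold A, B and scratch values.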
+ int rn = 0;
+ for (int i = 0; i < (max_unroll_m >> 4); i++)
+ for (int j = 0; j < max_unroll_n; j++)
+ c_regs[i][j] = Zmm(8 + rn++);
+
+ // Assign stack variables.
+ stack_alloc_size = 32;
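+    // Arguments beyond those passed in registers are read from the caller's
+    // stack: the extra 8 bytes skip the return address and the 48 bytes on
+    // Windows are assumed to cover the Win64 home/shadow space.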
+ auto args_offset = stack_alloc_size + get_size_of_abi_save_regs()
+ + 8 + (is_windows ? 48 : 0);
+
+ arg_a = ptr[rsp + (args_offset - 16)];
+ arg_b = ptr[rsp + (args_offset - 8)];
+ arg_c = ptr[rsp + (args_offset + 0)];
+ arg_ldc = ptr[rsp + (args_offset + 8)];
+ arg_coffset_c = ptr[rsp + (args_offset + 16)];
+ arg_coffset_r = ptr[rsp + (args_offset + 24)];
+
+ coffset_cx = qword[rsp + 0];
+ coffset_cy = qword[rsp + 8];
+ coffset_rx = qword[rsp + 16];
+ coffset_ry = qword[rsp + 24];
+
+ generate();
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp
new file mode 100644
index 0000000000..e8efcc1cc8
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemm_s8u8s32_kern.hpp
@@ -0,0 +1,101 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef IGEMM_KERNEL_GENERATOR_HPP
+#define IGEMM_KERNEL_GENERATOR_HPP
+
+#include "jit_generator.hpp"
+
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+class jit_avx512_core_gemm_s8u8s32_kern : public jit_generator {
+public:
+ jit_avx512_core_gemm_s8u8s32_kern(bool beta_zero_, bool enable_offset_c_,
+ bool enable_offset_r_);
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemm_s8u8s32_kern);
+
+protected:
+ bool beta_zero;
+ bool enable_offset_c, enable_offset_r;
+ bool vnni;
+
+ void prefetch_a(const Xbyak::Address &src) {
+ prefetcht0(src);
+ }
+ void prefetch_b(const Xbyak::Address &src) {
+ prefetcht0(src);
+ }
+ void prefetch_c(const Xbyak::Address &src) {
+ prefetchw(src);
+ }
+ void prefetch_x(const Xbyak::Address &src) {
+ prefetcht0(src);
+ }
+
+ void c_load(const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems);
+ void c_store(const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems);
+
+ void dot_product(const Xbyak::Xmm &dst, const Xbyak::Xmm &src1,
+ const Xbyak::Xmm &src2);
+ void kernel_loop(int unroll_m, int unroll_n, bool cfetch);
+ void remainder_kernel(int unroll_m, int unroll_n, int unroll_k, int bwidth);
+ void innerloop(int unroll_m, int unroll_n);
+ void outerloop(int unroll_x, int unroll_y, Xbyak::Label *&outerloop_label);
+
+ void generate();
+
+
+private:
+ static const int IGEMM_UNROLL_M = 48;
+ static const int IGEMM_UNROLL_N = 8;
+
+ static const int isize = 2;
+ static const int size = 4;
+
+ // Prefetch configuration
+ static const int prefetch_size_a = 32 * 5;
+ static const int prefetch_size_b = 32 * 4;
+
+ static const int offset_a = 256, offset_b = 256;
+ static const int max_unroll_m = 48, max_unroll_n = 8;
+
+ // Integer register assignments
+ Xbyak::Reg64 M, N, K, A, B, C, LDC, I, J, LoopCount;
+ Xbyak::Reg64 AO, BO, CO1, CO2, AA;
+
+ // Vector register assignments
+ Xbyak::Zmm dp_scratch, ones, a_regs[max_unroll_m >> 4], b_regs[2];
+ Xbyak::Zmm c_regs[max_unroll_m >> 4][max_unroll_n];
+
+ // Stack variable assignments
+ int stack_alloc_size;
+ Xbyak::Address arg_a, arg_b, arg_c, arg_ldc, arg_coffset_c, arg_coffset_r;
+ Xbyak::Address coffset_cx, coffset_cy, coffset_rx, coffset_ry;
+
+ void L_aligned(Xbyak::Label &label, int alignment = 16) {
+ align(alignment);
+ L(label);
+ }
+};
+
+}
+}
+}
+
+#endif /* header guard */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp
new file mode 100644
index 0000000000..4f0b10dadd
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_gemv_s8u8s32.cpp
@@ -0,0 +1,290 @@
+/*******************************************************************************
+ * Copyright 2019 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include "gemv.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+int gemm_s8u8s32_jump_to_gemv_s8u8s32(blas_t *arg) {
+
+ blas_t arg_gemv = *arg;
+
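+    // Dispatch to the GEMV path only when the call is effectively a
+    // matrix-vector product (n == 1 or m == 1, for the supported transpose
+    // combinations) with alpha == 1, beta in {0, 1} and no offsets to apply.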
+ if ((arg -> offsetc == FIX_OFFSET) && // Fix offset
+ (arg -> ao == 0) &&
+ (arg -> bo == 0) &&
+ (arg -> co[0] == 0) &&
+ (*(arg -> alpha) == 1.0f) &&
+ ((*(arg -> beta) == 1.0f) || *(arg -> beta) == 0.0f)) {
+
+ if (arg -> n == 1) {
+
+ if (arg -> transa == 1) { // A transpose
+ arg_gemv.n = arg -> k;
+ arg_gemv.ldc = 1;
+ arg_gemv.swap = 0;
+ if (arg -> transb == 0) { // B non transpose
+ arg_gemv.ldb = 1;
+ }
+                // B transpose: arg_gemv.ldb already equals arg -> ldb (copied from *arg above)
+ gemv_threading_driver(&arg_gemv);
+ return 1;
+ }
+ }
+
+ if (arg -> m == 1) {
+
+ if (arg -> transb == 0) { // B non transpose
+ arg_gemv.transa = 1;
+ arg_gemv.m = arg -> n;
+ arg_gemv.n = arg -> k;
+ arg_gemv.a = (int8_t *) arg -> b;
+ arg_gemv.lda = arg -> ldb;
+ arg_gemv.b = (uint8_t *) arg -> a;
+ arg_gemv.swap = 1;
+ if (arg -> transa == 0) { // A non transpose
+ arg_gemv.ldb = arg -> lda;
+ }
+ else { // A transpose
+ arg_gemv.ldb = 1;
+ }
+ gemv_threading_driver(&arg_gemv);
+ return 1;
+ }
+ }
+ }
+
+ return 0;
+}
+
+
+int gemv_kernel_driver(blas_t *arg) {
+
+ dim_t m = arg -> m;
+ dim_t n = arg -> n;
+ uint8_t *a = (uint8_t *) arg -> a;
+ dim_t lda = arg -> lda;
+ int8_t *b = (int8_t *) arg -> b;
+ float beta = *(arg -> beta);
+
+ if (arg -> swap) {
+ arg -> gemv_u8s8s32_kernel(m, n, 1.0f, a, lda, b, beta, arg -> c);
+ }
+ else {
+ arg -> gemv_s8u8s32_kernel(arg -> m, arg -> n, 1.0f, arg -> a,
+ arg -> lda, arg -> b, *(arg -> beta), arg -> c);
+ }
+
+ return 0;
+}
+
+int gemv_threading_driver(blas_t *arg) {
+
+ dim_t nthr_m, nthr_n = 1;
+ dim_t MB, NB, UM = 16, UN = 64;
+ dim_t BLOCKM = 192, BLOCKN = 3072;
+ int status;
+ dim_t i;
+
+ dim_t nthr = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
+
+ uint8_t *new_x = NULL;
+ int32_t *tmp_y = NULL, *new_y = NULL;
+
+ dim_t m = arg -> m, n = arg -> n;
+
+ blas_t arg_seq = *arg;
+ float zero = 0.0f;
+
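+    // 2-D thread decomposition: split m into blocks of roughly BLOCKM rows
+    // (rounded up to a multiple of UM), then give extra threads to the n
+    // dimension only while every n-block keeps at least BLOCKN columns.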
+ nthr_m = std::min(std::max(m / BLOCKM, (dim_t) 1), nthr);
+ MB = m / nthr_m;
+ MB = (((MB / UM) * UM) == MB) ? MB : (MB / UM) * UM + UM;
+ nthr_m = (((m / MB) * MB) == m) ? m / MB : m / MB + 1;
+ nthr_m = std::min(std::max(nthr_m, (dim_t) 1), nthr);
+
+ while ((nthr_m * (nthr_n + 1) <= nthr) && ((n / (nthr_n + 1)) >= BLOCKN)) {
+ nthr_n++;
+ }
+
+ NB = n / nthr_n;
+ NB = (((NB / UN) * UN) == NB) ? NB : (NB / UN) * UN + UN;
+ nthr_n = (((n / NB) * NB) == n) ? n / NB : n / NB + 1;
+ nthr_n = std::min(std::max(nthr_n, (dim_t) 1), nthr / nthr_m);
+
+ nthr = nthr_m * nthr_n;
+
+ if (arg -> ldb != 1) {
+ new_x = (uint8_t *)malloc(n, 64);
+ if (new_x == NULL)
+ return 1;
+ for (i = 0; i < n; i++) {
+ new_x[i] = (arg -> b)[i * arg -> ldb];
+ }
+ arg_seq.b = new_x;
+ arg_seq.ldb = 1;
+ }
+ else new_x = (uint8_t *) arg -> b;
+
+ if (arg -> ldc != 1) {
+ new_y = (int32_t *) malloc(nthr_m * PADD_BYTESIZE_ONPAGE(MB, sizeof(int32_t)), 64);
+ if (new_y == NULL) {
+ if (arg -> ldb != 1) {
+ free(new_x);
+ }
+ return 1;
+ }
+ }
+
+ // GEMV computation
+ if (nthr == 1) {
+
+ if (arg -> ldc != 1) {
+ if (*(arg -> beta) != 0.0f) {
+ for (i = 0; i < m; i++) {
+ new_y[i] = arg -> c[i * arg -> ldc];
+ }
+ }
+ }
+
+ status = gemv_kernel_driver(&arg_seq);
+
+ if (arg -> ldc != 1) {
+ for (i = 0; i < m; i++) {
+ arg -> c[i * arg -> ldc] = new_y[i];
+ }
+ }
+
+ if (arg -> ldb != 1) {
+ free(new_x);
+ }
+ if (arg -> ldc != 1) {
+ free(new_y);
+ }
+ return status;
+ }
+
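+    // When more than one thread works along n, threads with n_id > 0
+    // accumulate into per-thread buffers in tmp_y and their partial results
+    // are reduced into C after the parallel section.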
+ if (nthr_n > 1) {
+ tmp_y = (int32_t *) malloc((nthr_n - 1) * PADD_BYTESIZE_ONPAGE(m, sizeof(int32_t)), PAGESIZE);
+ if (tmp_y == NULL) {
+ if (arg -> ldb != 1) {
+ free(new_x);
+ }
+ return 1;
+ }
+ }
+
+ parallel_nd((int) nthr, [&](const dim_t ithr) {
+
+ dim_t m_from, m_to, myM;
+ dim_t n_from, n_to, myN;
+
+ dim_t n_id, m_id;
+ dim_t loc_incy = 1;
+ int32_t *loc_y;
+
+ blas_t arg_loc = arg_seq;
+ int j;
+
+ m_id = ithr / nthr_n;
+ n_id = ithr % nthr_n;
+
+ m_from = MB * m_id;
+ m_to = MB * (m_id + 1);
+ if ((m_to > m) || (m_id == nthr_m - 1))
+ m_to = m;
+
+ myM = m_to - m_from;
+
+ n_from = NB * n_id;
+ n_to = NB * (n_id + 1);
+ if ((n_to > n) || (n_id == nthr_n - 1))
+ n_to = n;
+
+ myN = n_to - n_from;
+
+ if (n_id != 0) {
+ arg_loc.beta = &zero;
+ loc_y = tmp_y + (NEXT_THR_STRIDE(m, sizeof(int32_t))) * (n_id - 1) + m_from;
+ }
+ else {
+ if (arg -> ldc == 1) {
+ loc_y = arg_seq.c + m_from;
+ }
+ else {
+                // need to copy the block of C into new_y
+ loc_y = new_y + m_id * NEXT_THR_STRIDE(MB, sizeof(int32_t));
+ if (*(arg -> beta) != 0.0f) {
+ for (j = 0; j < myM; j++) {
+ loc_y[j] = arg -> c[(m_from + j) * arg -> ldc];
+ }
+ }
+ }
+ }
+
+ arg_loc.m = myM;
+ arg_loc.n = myN;
+ arg_loc.a = arg_seq.a + m_from * arg_seq.lda + n_from;
+ arg_loc.b = arg_seq.b + n_from;
+ arg_loc.c = loc_y;
+ arg_loc.ldc = loc_incy;
+
+ gemv_kernel_driver(&arg_loc);
+
+ if ((n_id == 0) && (arg -> ldc != 1)) {
+ for (j = 0; j < myM; j++) {
+ arg -> c[(m_from + j) * arg -> ldc] = loc_y[j];
+ }
+ }
+
+ });
+
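+    // Reduce the partial results produced by the extra n-blocks into C.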
+ if (nthr_n > 1) {
+ parallel_nd((int) nthr_m, [&](const dim_t ithr) {
+
+ dim_t j, j_from, j_to, ii;
+ int32_t acc;
+
+ j_from = MB * ithr;
+ j_to = MB * (ithr + 1);
+ if ((j_to > m) || (ithr == nthr - 1))
+ j_to = m;
+
+ for (j = j_from; j < j_to; j++) {
+ acc = 0;
+ for (ii = 0; ii < nthr_n - 1; ii++) {
+ acc += tmp_y[ii * NEXT_THR_STRIDE(m, sizeof(int32_t)) + j];
+ }
+ (arg -> c)[j * arg -> ldc] += acc;
+ }
+ });
+ free(tmp_y);
+ }
+
+ if (arg -> ldb != 1) {
+ free(new_x);
+ }
+
+ if (arg -> ldc != 1) {
+ free(new_y);
+ }
+
+ return 0;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp
new file mode 100644
index 0000000000..c57a8c1d12
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.cpp
@@ -0,0 +1,411 @@
+/*******************************************************************************
+ * Copyright 2019 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include "jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp"
+
+#ifdef _WIN32
+#define is_windows 1
+#else
+#define is_windows 0
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+void jit_avx512_core_gemv_s8u8s32_kern::vnni(Xbyak::Zmm acc, Xbyak::Zmm b,
+ Xbyak::Zmm a, Xbyak::Zmm tmp,
+ Xbyak::Zmm one, bool swap,
+ int use_vnni) {
+
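+    // acc += b * a as 4-byte dot products per 32-bit lane: use vpdpbusd when
+    // VNNI is available, otherwise emulate it with vpmaddubsw followed by
+    // vpmaddwd against a vector of 16-bit ones.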
+ if (use_vnni) {
+ if (swap)
+ vpdpbusd(acc, a, b);
+ else
+ vpdpbusd(acc, b, a);
+ }
+
+ else {
+ if (swap)
+ vpmaddubsw(tmp, a, b);
+ else
+ vpmaddubsw(tmp, b, a);
+ vpmaddwd(tmp, tmp, one);
+ vpaddd(acc, tmp, acc);
+ }
+
+}
+
+void jit_avx512_core_gemv_s8u8s32_kern::n_loop_body(int start_a_idx, int start_acc_idx,
+ int b_idx, int nreg_acc,
+ Xbyak::Reg64 A, Xbyak::Reg64 lda,
+ Xbyak::Reg64 X, Xbyak::Zmm tmp,
+ Xbyak::Zmm one, bool swap, int use_vnni,
+ int use_mask, Xbyak::Opmask mask_n) {
+
+ int i;
+ int nreg_A = nreg_acc / 2 + (nreg_acc % 2);
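+    // Only nreg_A vector registers are used for the A rows: the first half
+    // of the accumulators is updated, then the same registers are reloaded
+    // with the next rows for the second half.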
+
+ // load X + j
+ if (use_mask)
+ vmovdqu8(Xbyak::Zmm(b_idx) | mask_n | T_z, ptr[X]);
+ else
+ vmovdqu8(Xbyak::Zmm(b_idx), ptr[X]);
+
+ xor_(r14, r14);
+ // load values of A
+ for (i = 0; i < nreg_A; i++) {
+ if (use_mask)
+ vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
+ else
+ vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
+ add(r14, lda);
+ }
+
+ for (i = 0; i < nreg_A; i++) {
+ // vnni (acc, b, a, tmp, one, swap, use_vnni)
+ vnni(Xbyak::Zmm(start_acc_idx + i), Xbyak::Zmm(b_idx),
+ Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
+ }
+
+ for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
+ if (use_mask)
+ vmovdqu8(Xbyak::Zmm(start_a_idx + i) | mask_n | T_z, ptr[A + r14]);
+ else
+ vmovdqu8(Xbyak::Zmm(start_a_idx + i), ptr[A + r14]);
+ add(r14, lda);
+ }
+
+ for (i = 0; i < nreg_A - (nreg_acc % 2); i++) {
+ vnni(Xbyak::Zmm(start_acc_idx + i + nreg_A), Xbyak::Zmm(b_idx),
+ Xbyak::Zmm(start_a_idx + i), tmp, one, swap, use_vnni);
+ }
+
+}
+
+void jit_avx512_core_gemv_s8u8s32_kern::shuffle_and_add(Xbyak::Zmm dest, Xbyak::Zmm A,
+ Xbyak::Zmm B, Xbyak::Zmm C,
+ Xbyak::Zmm D) {
+
+ vshufi32x4(dest, A, C, 0x44);
+ vshufi32x4(A, A, C, 0xEE);
+ vpaddd(C, dest, A); // C = A0 + A2|A1 + A3|C0 + C2|C1 + C3
+
+ vshufi32x4(dest, B, D, 0x44);
+ vshufi32x4(B, B, D, 0xEE);
+ vpaddd(D, dest, B); // D = B0 + B2|B1 + B3|D0 + D2|D1 + D3
+
+ vshufi32x4(A, C, D, 0x88);
+ vshufi32x4(B, C, D, 0xDD);
+ vpaddd(dest, A, B); // dest = SAi|SBi|SCi|SDi
+
+}
+
+void jit_avx512_core_gemv_s8u8s32_kern::update_c(int nreg_acc, Xbyak::Reg64 Y,
+ int start_a_idx, int start_acc_idx,
+ Xbyak::Xmm beta, int use_mask,
+ Xbyak::Opmask mask_m) {
+
+ int l, i, k, j, last_it;
+ Xbyak::Label store_label;
+
+ l = 0;
+ for (k = 0; k < nreg_acc; k += 8) {
+ for (i = 0, j = k; i < 8; i += 4, j += 2) {
+ if (j < nreg_acc) {
+ // shuffle per block of 4 registers
+ shuffle_and_add(Xbyak::Zmm(start_a_idx + l), // dest
+ Xbyak::Zmm(start_acc_idx + j), // A = acc0
+ Xbyak::Zmm(start_acc_idx + 1 + j), // B = acc1
+ Xbyak::Zmm(start_acc_idx + 4 + j), // C = acc4
+ Xbyak::Zmm(start_acc_idx + 5 + j)); // D = acc5
+
+ // extract low and high from dest and hadd
+ vextracti32x8(Xbyak::Ymm(start_a_idx + l + 1), Xbyak::Zmm(start_a_idx + l), 0);
+ vextracti32x8(Xbyak::Ymm(start_a_idx + l + 2), Xbyak::Zmm(start_a_idx + l), 1);
+ vphaddd(Xbyak::Ymm(start_a_idx + l),
+ Xbyak::Ymm(start_a_idx + l + 1),
+ Xbyak::Ymm(start_a_idx + l + 2));
+ }
+ l++;
+ }
+
+ vphaddd(Xbyak::Ymm(start_a_idx + l),
+ Xbyak::Ymm(start_a_idx + l - 2),
+ Xbyak::Ymm(start_a_idx + l - 1));
+
+ l++;
+ }
+
+    // when beta != 0 (here beta == 1), add the existing C values before storing
+ vxorps(Xbyak::Ymm(start_a_idx),
+ Xbyak::Ymm(start_a_idx),
+ Xbyak::Ymm(start_a_idx));
+ vucomiss(beta, Xbyak::Ymm(start_a_idx));
+ je(store_label, T_NEAR);
+
+ // beta = 1
+ for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
+ // load Y and add
+ last_it = (k + 8) > nreg_acc;
+ if (use_mask && last_it)
+ vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8) | mask_m | T_z, ptr[Y + (k / 8) * 32]);
+ else
+ vmovdqu32(Xbyak::Ymm(start_a_idx + k / 8), ptr[Y + (k / 8) * 32]);
+
+ vpaddd(Xbyak::Ymm(start_a_idx + l),
+ Xbyak::Ymm(start_a_idx + l),
+ Xbyak::Ymm(start_a_idx + k / 8));
+ }
+
+ // store
+ aligned_label(store_label);
+ for (k = 0, l = 2; k < nreg_acc; k += 8, l += 3) {
+ last_it = (k + 8) > nreg_acc;
+ if (use_mask && last_it)
+ vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l) | mask_m);
+ else
+ vmovdqu32(ptr[Y + (k / 8) * 32], Xbyak::Ymm(start_a_idx + l));
+ }
+
+}
+
+template <typename T>
+T jit_avx512_core_gemv_s8u8s32_kern::generate(int use_vnni) {
+
+ Xbyak::Opmask mask_n = k1, mask_m = k2;
+ Xbyak::Label one_label, m_tail_label, m_loop_label, n_loop_label;
+ Xbyak::Label n_tail_label, update_c_label, end_label;
+ constexpr unsigned int n_labels = (1 << unroll_m) - 1;
+ Xbyak::Label m_tail_label_case[n_labels];
+ Xbyak::Label n_loop_label_case[n_labels];
+ Xbyak::Label n_tail_label_case[n_labels];
+ Xbyak::Label update_c_label_case[n_labels];
+
+ int i, ii;
+
+ Xbyak::Zmm one, tmp;
+ Xbyak::Reg64 n = abi_param2, m = abi_param1;
+ Xbyak::Reg64 A = is_windows ? abi_param4 : abi_param3;
+ Xbyak::Reg64 lda = is_windows ? abi_param3 : abi_param4;
+ Xbyak::Reg64 X = is_windows ? rdi : r8;
+ Xbyak::Xmm beta = xmm1;
+ Xbyak::Reg64 Y = is_windows ? rsi : r9;
+
+ bool swap = !std::is_same<T, gemv_s8u8s32_kernel_t>::value;
+
+    // On Windows, lda, X, beta and Y are passed on the stack and read below
+
+ int zmm_idx = 1;
+ int nreg_acc = 1 << unroll_m;
+ int nreg_A = 1 << (unroll_m - 1);
+ int nreg_A_acc = nreg_acc + nreg_A;
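+    // Register budget for unroll_m = 4: 16 accumulators, 8 A registers and
+    // one register for X, plus 'one'/'tmp' for the non-VNNI emulation; this
+    // fits in the 32 available Zmm registers.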
+
+ if (!use_vnni) {
+        // reserve zmm registers for the vector of ones and a temporary
+ tmp = Xbyak::Zmm(0);
+ one = Xbyak::Zmm(zmm_idx + 1);
+ zmm_idx += 2; // one + tmp
+ }
+ else {
+ beta = xmm0;
+ }
+
+ preamble();
+
+ if (is_windows) {
+ mov(lda, ptr[rsp + get_size_of_abi_save_regs() + 40]);
+ mov(X, ptr[rsp + get_size_of_abi_save_regs() + 48]);
+ movss(beta, ptr[rsp + get_size_of_abi_save_regs() + 56]);
+ mov(Y, ptr[rsp + get_size_of_abi_save_regs() + 64]);
+ }
+
+ if (use_vnni && !is_windows) {
+ movaps(beta, xmm1);
+ }
+
+ mov(rax, (1 << unroll_n) - 1);
+ kmovq(k3, rax);
+
+ and_(rax, n); // rax contains n & ((1 << unroll_n) - 1)
+ mov(rbx, 1);
+ shlx(rbx, rbx, rax);
+ sub(rbx, 1);
+ kmovq(mask_n, rbx);
+ // mask_n set (AVX512 only), can use rax and rbx again
+
+ // set mask_m for update of the C matrix
+    // loads/stores on the C matrix use Ymm registers, so the tail mask is
+    // based on the Ymm width of 8 int32 elements (8 * 32 bits = 256 bits)
+    mov(rax, 7);
+ and_(rax, m); // rax contains m & 7
+ mov(rbx, 1);
+ shlx(rbx, rbx, rax);
+ sub(rbx, 1);
+ kmovq(mask_m, rbx);
+ // mask_m set (AVX512 only), can use rax and rbx again
+
+    // set up the register of ones when VNNI instructions are not available
+ if (!use_vnni) {
+ vmovdqu16(one, ptr[rip + one_label]);
+ }
+
+ // M loop
+ // base pointer for A rax contains a + i * lda
+ // Loop stop when rax >= a + (m & mask_um) * lda = rbx
+ // loop increment r10 = um * lda
+ // rbp = Y + i
+ mov(rax, A); // i = 0
+ mov(rbx, m);
+ and_(rbx, mask_um);
+ imul(rbx, lda);
+ add(rbx, A);
+ mov(r10, lda);
+ sal(r10, unroll_m);
+ mov(rbp, Y);
+
+ // N loop
+ // base pointer for X r11 contains x + j
+    // Loop stops when r11 >= x + (n & mask_un) = r12
+ // loop increment un
+ // r13 = rax + j = A + i * lda + j
+ mov(r12, n);
+ and_(r12, mask_un);
+ add(r12, X);
+
+ // M loop
+ aligned_label(m_loop_label);
+ cmp(rax, rbx);
+ jge(m_tail_label, T_NEAR);
+
+ // enter M loop
+ for(i = 0; i < nreg_acc; i++) {
+ vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
+ Xbyak::Zmm(i + zmm_idx + nreg_A),
+ Xbyak::Zmm(i + zmm_idx + nreg_A));
+ }
+
+ // N loop
+ mov(r11, X); // j = 0
+ mov(r13, rax);
+ aligned_label(n_loop_label);
+ cmp(r11, r12);
+ jge(n_tail_label, T_NEAR);
+
+ // enter N loop
+
+ n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
+ r13, lda, r11, tmp, one, swap, use_vnni, 0, mask_n);
+
+    // increment r11 (X pointer) and r13 (A pointer) by un
+ add(r11, 1 << unroll_n);
+ add(r13, 1 << unroll_n);
+ jmp(n_loop_label, T_NEAR);
+ // end N loop
+
+ // N tail
+ aligned_label(n_tail_label);
+
+ ktestq(mask_n, k3);
+ je(update_c_label, T_NEAR);
+ n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, nreg_acc,
+ r13, lda, r11, tmp, one, swap, use_vnni, 1, mask_n);
+
+ // update C matrix
+ aligned_label(update_c_label);
+
+ update_c(nreg_acc, rbp, zmm_idx, zmm_idx + nreg_A, beta, 0, mask_m);
+
+    // increment rax by um * lda and rbp (Y pointer) by um * sizeof(int32_t)
+ add(rax, r10);
+ add(rbp, 1 << (unroll_m + 2));
+ jmp(m_loop_label, T_NEAR);
+ // end M loop
+
+ // M tail
+ aligned_label(m_tail_label);
+
+    // r10 will contain m_tail = m % (1 << unroll_m) = m & ((1 << unroll_m) - 1)
+ mov(r10, m);
+ and_(r10, (1 << unroll_m) - 1);
+ for (ii = 1; ii < 1 << unroll_m; ii++) {
+ aligned_label(m_tail_label_case[ii-1]);
+ cmp(r10, ii);
+ if (ii == (1 << unroll_m) - 1)
+ jne(end_label, T_NEAR);
+ else
+ jne(m_tail_label_case[ii], T_NEAR);
+
+        // m_tail == ii: use ii accumulators
+
+ for(i = 0; i < ii; i++) {
+ vpxorq(Xbyak::Zmm(i + zmm_idx + nreg_A),
+ Xbyak::Zmm(i + zmm_idx + nreg_A),
+ Xbyak::Zmm(i + zmm_idx + nreg_A));
+ }
+
+ // N loop
+ mov(r11, X); // j = 0
+ mov(r13, rax);
+ aligned_label(n_loop_label_case[ii - 1]);
+ cmp(r11, r12);
+ jge(n_tail_label_case[ii - 1], T_NEAR);
+
+ n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
+ lda, r11, tmp, one, swap, use_vnni, 0, mask_n);
+
+        // increment r11 (X pointer) and r13 (A pointer) by un
+ add(r11, 1 << unroll_n);
+ add(r13, 1 << unroll_n);
+ jmp(n_loop_label_case[ii - 1], T_NEAR);
+ // end N loop
+
+ // N tail
+ aligned_label(n_tail_label_case[ii - 1]);
+ ktestq(mask_n, k3);
+ je(update_c_label_case[ii - 1], T_NEAR);
+ n_loop_body(zmm_idx, zmm_idx + nreg_A, zmm_idx + nreg_A_acc, ii, r13,
+ lda, r11, tmp, one, swap, use_vnni, 1, mask_n);
+
+ // update C matrix
+ aligned_label(update_c_label_case[ii - 1]);
+ update_c(ii, rbp, zmm_idx, zmm_idx + nreg_A, beta, 1, mask_m);
+
+ if (ii < ((1 << unroll_m) - 1))
+ jmp(end_label, T_NEAR);
+ }
+
+ aligned_label(end_label);
+
+ postamble();
+
+ if (!use_vnni) {
+ aligned_label(one_label);
+ for (i = 0; i < size_vec_reg/8; i++)
+ dq(0x0001000100010001);
+ }
+
+ return (T) getCode();
+}
+
+template jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t
+jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_s8u8s32_kernel_t>(int);
+
+template jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t
+jit_avx512_core_gemv_s8u8s32_kern::generate<jit_avx512_core_gemv_s8u8s32_kern::gemv_u8s8s32_kernel_t>(int);
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp
new file mode 100644
index 0000000000..9ea23a5f56
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_kernel_gemv_s8u8s32_kern.hpp
@@ -0,0 +1,64 @@
+/*******************************************************************************
+ * Copyright 2019 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+class jit_avx512_core_gemv_s8u8s32_kern : jit_generator {
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_gemv_s8u8s32_kern);
+
+    // assumes unroll_{m,n} are powers of 2
+ static constexpr unsigned int unroll_m = 4; // real unrolling factor is 2^unroll_m
+ const int mask_um = 0xFFFFFFF0;
+ static constexpr unsigned int unroll_n = 6; // real unrolling factor is 2^unroll_n
+ const int mask_un = 0xFFFFFFC0;
+ const int size_vec_reg = 64; // bytes
+
+ void aligned_label(Xbyak::Label &label, int alignment = 16) {
+ align(alignment);
+ L(label);
+ }
+
+ void vnni(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, bool, int);
+ void n_loop_body(int, int, int, int, Xbyak::Reg64, Xbyak::Reg64,
+ Xbyak::Reg64, Xbyak::Zmm, Xbyak::Zmm, bool, int, int, Xbyak::Opmask);
+ void shuffle_and_add(Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm, Xbyak::Zmm);
+ void update_c(int, Xbyak::Reg64, int, int, Xbyak::Xmm, int, Xbyak::Opmask);
+
+public:
+ jit_avx512_core_gemv_s8u8s32_kern() : jit_generator(nullptr, GEMM_CODE_SIZE) {};
+
+ // m, n, alpha, a, lda, x, beta, y
+ typedef void (*gemv_s8u8s32_kernel_t)(const dim_t, const dim_t, const float,
+ const int8_t*, const dim_t, const uint8_t*,
+ const float, int32_t*);
+ typedef void (*gemv_u8s8s32_kernel_t)(const dim_t, const dim_t, const float,
+ const uint8_t*, const dim_t, const int8_t*,
+ const float, int32_t*);
+
+ template <typename T>
+ T generate(int use_vnni);
+
+};
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp
new file mode 100644
index 0000000000..544cd2ff25
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_an_kern.cpp
@@ -0,0 +1,819 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
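+// Packing kernel for the non-transposed ("an") case. It appears to copy
+// panels of the 8-bit source matrix into the contiguous buffer B, handling
+// the panel width in decreasing chunks of 48, 32, 16, 8, 4, 2 and 1 and
+// interleaving elements from four consecutive LDA-strided lines so that the
+// compute kernel can read its dot-product operands contiguously.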
+jit_avx512_core_u8_copy_an_kern::jit_avx512_core_u8_copy_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l170;
+Xbyak::Label l1f0;
+Xbyak::Label l20;
+Xbyak::Label l224;
+Xbyak::Label l234;
+Xbyak::Label l240;
+Xbyak::Label l254;
+Xbyak::Label l32c;
+Xbyak::Label l34;
+Xbyak::Label l388;
+Xbyak::Label l3b0;
+Xbyak::Label l3c0;
+Xbyak::Label l3cc;
+Xbyak::Label l3dc;
+Xbyak::Label l454;
+Xbyak::Label l48c;
+Xbyak::Label l4a8;
+Xbyak::Label l4b8;
+Xbyak::Label l4c4;
+Xbyak::Label l4d8;
+Xbyak::Label l570;
+Xbyak::Label l5c4;
+Xbyak::Label l5f0;
+Xbyak::Label l60c;
+Xbyak::Label l61c;
+Xbyak::Label l628;
+Xbyak::Label l638;
+Xbyak::Label l6b0;
+Xbyak::Label l6f4;
+Xbyak::Label l720;
+Xbyak::Label l73c;
+Xbyak::Label l74c;
+Xbyak::Label l758;
+Xbyak::Label l76c;
+Xbyak::Label l804;
+Xbyak::Label l858;
+Xbyak::Label l88c;
+Xbyak::Label l8a4;
+Xbyak::Label l8b2;
+Xbyak::Label l8bc;
+Xbyak::Label l8cc;
+Xbyak::Label l944;
+Xbyak::Label l98c;
+Xbyak::Label l9b0;
+Xbyak::Label l9c8;
+Xbyak::Label l9d8;
+
+ preamble();
+#ifdef _WIN32
+ auto stacksize = get_size_of_abi_save_regs();
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(LDA, qword[LDA]);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ sub(A, -128);
+ sub(B, -128);
+ cmp(N, 0x30);
+ jl(l234, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ add(A, 0x30);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l170, T_NEAR);
+ align(4);
+
+L(l34);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm2);
+ movdqu(xmm0, xword[A1-0x70]);
+ movdqu(xmm1, xword[A1+LDA*1-0x70]);
+ movdqu(xmm2, xword[A1+LDA*2-0x70]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x70]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B-0x30], xmm1);
+ movdqu(xword[B-0x20], xmm4);
+ movdqu(xword[B-0x10], xmm2);
+ movdqu(xmm0, xword[A1-0x60]);
+ movdqu(xmm1, xword[A1+LDA*1-0x60]);
+ movdqu(xmm2, xword[A1+LDA*2-0x60]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x60]);
+ lea(A1, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B], xmm0);
+ movdqu(xword[B+0x10], xmm1);
+ movdqu(xword[B+0x20], xmm4);
+ movdqu(xword[B+0x30], xmm2);
+ sub(B, -192);
+ dec(I);
+ jg(l34, T_NEAR);
+ align(4);
+
+L(l170);
+ test(M, 0x2);
+ jle(l1f0, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ movdqu(xmm2, xword[A1-0x60]);
+ add(A1, LDA);
+ movdqu(xmm3, xword[A1-0x80]);
+ movdqu(xmm4, xword[A1-0x70]);
+ movdqu(xmm5, xword[A1-0x60]);
+ add(A1, LDA);
+ movdqa(xmm6, xmm0);
+ punpcklbw(xmm0, xmm3);
+ punpckhbw(xmm6, xmm3);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm6);
+ movdqa(xmm6, xmm1);
+ punpcklbw(xmm1, xmm4);
+ punpckhbw(xmm6, xmm4);
+ movdqu(xword[B-0x60], xmm1);
+ movdqu(xword[B-0x50], xmm6);
+ movdqa(xmm6, xmm2);
+ punpcklbw(xmm2, xmm5);
+ punpckhbw(xmm6, xmm5);
+ movdqu(xword[B-0x40], xmm2);
+ movdqu(xword[B-0x30], xmm6);
+ sub(B, -96);
+ align(4);
+
+L(l1f0);
+ test(M, 0x1);
+ jle(l224, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ movdqu(xmm2, xword[A1-0x60]);
+ add(A1, LDA);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm2);
+ sub(B, -48);
+ align(4);
+
+L(l224);
+ sub(N, 0x30);
+ cmp(N, 0x30);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l234);
+ cmp(N, 0x20);
+ jl(l3c0, T_NEAR);
+ align(4);
+
+L(l240);
+ mov(A1, A);
+ add(A, 0x20);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l32c, T_NEAR);
+ align(4);
+
+L(l254);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm2);
+ movdqu(xmm0, xword[A1-0x70]);
+ movdqu(xmm1, xword[A1+LDA*1-0x70]);
+ movdqu(xmm2, xword[A1+LDA*2-0x70]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x70]);
+ lea(A1, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B-0x30], xmm1);
+ movdqu(xword[B-0x20], xmm4);
+ movdqu(xword[B-0x10], xmm2);
+ sub(B, -128);
+ dec(I);
+ jg(l254, T_NEAR);
+ align(4);
+
+L(l32c);
+ test(M, 0x2);
+ jle(l388, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqu(xmm2, xword[A1-0x80]);
+ movdqu(xmm3, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm2);
+ punpckhbw(xmm4, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm4);
+ movdqa(xmm4, xmm1);
+ punpcklbw(xmm1, xmm3);
+ punpckhbw(xmm4, xmm3);
+ movdqu(xword[B-0x60], xmm1);
+ movdqu(xword[B-0x50], xmm4);
+ sub(B, -64);
+ align(4);
+
+L(l388);
+ test(M, 0x1);
+ jle(l3b0, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l3b0);
+ sub(N, 0x20);
+ cmp(N, 0x20);
+ jge(l240, T_NEAR);
+ align(4);
+
+L(l3c0);
+ cmp(N, 0x10);
+ jl(l4b8, T_NEAR);
+ align(4);
+
+L(l3cc);
+ mov(A1, A);
+ add(A, 0x10);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l454, T_NEAR);
+ align(4);
+
+L(l3dc);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm1, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm2, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm3, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm1, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm1, xmm3);
+ movdqa(xmm3, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm3, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm1);
+ punpckhwd(xmm2, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm3);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm2);
+ sub(B, -64);
+ dec(I);
+ jg(l3dc, T_NEAR);
+ align(4);
+
+L(l454);
+ test(M, 0x2);
+ jle(l48c, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm1, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqa(xmm2, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm2, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ align(4);
+
+L(l48c);
+ test(M, 0x1);
+ jle(l4a8, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l4a8);
+ sub(N, 0x10);
+ cmp(N, 0x10);
+ jge(l3cc, T_NEAR);
+ align(4);
+
+L(l4b8);
+ cmp(N, 0x8);
+ jl(l61c, T_NEAR);
+ align(4);
+
+L(l4c4);
+ mov(A1, A);
+ add(A, 0x8);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l570, T_NEAR);
+ align(4);
+
+L(l4d8);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ dec(I);
+ jg(l4d8, T_NEAR);
+ align(4);
+
+L(l570);
+ test(M, 0x4);
+ jle(l5c4, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l5c4);
+ test(M, 0x2);
+ jle(l5f0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l5f0);
+ test(M, 0x1);
+ jle(l60c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l60c);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l4c4, T_NEAR);
+ align(4);
+
+L(l61c);
+ cmp(N, 0x4);
+ jl(l74c, T_NEAR);
+ align(4);
+
+L(l628);
+ mov(A1, A);
+ add(A, 0x4);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l6b0, T_NEAR);
+ align(4);
+
+L(l638);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ dec(I);
+ jg(l638, T_NEAR);
+ align(4);
+
+L(l6b0);
+ test(M, 0x4);
+ jle(l6f4, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l6f4);
+ test(M, 0x2);
+ jle(l720, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l720);
+ test(M, 0x1);
+ jle(l73c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l73c);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l628, T_NEAR);
+ align(4);
+
+L(l74c);
+ cmp(N, 0x2);
+ jl(l8b2, T_NEAR);
+ align(4);
+
+L(l758);
+ mov(A1, A);
+ add(A, 0x2);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l804, T_NEAR);
+ align(4);
+
+L(l76c);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm4, eax, 0x0);
+ punpcklbw(xmm1, xmm2);
+ punpcklbw(xmm3, xmm4);
+ punpcklwd(xmm1, xmm3);
+ punpcklqdq(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(LDA3);
+ jg(l76c, T_NEAR);
+ align(4);
+
+L(l804);
+ test(M, 0x4);
+ jle(l858, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l858);
+ test(M, 0x2);
+ jle(l88c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l88c);
+ test(M, 0x1);
+ jle(l8a4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ mov(word[B-0x80], ax);
+ sub(B, -2);
+ align(4);
+
+L(l8a4);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l758, T_NEAR);
+ align(4);
+
+L(l8b2);
+ cmp(N, 0x1);
+ jl(l9d8, T_NEAR);
+ align(4);
+
+L(l8bc);
+ mov(A1, A);
+ add(A, 0x1);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l944, T_NEAR);
+ align(4);
+
+L(l8cc);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x7);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ dec(LDA3);
+ jg(l8cc, T_NEAR);
+ align(4);
+
+L(l944);
+ test(M, 0x4);
+ jle(l98c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l98c);
+ test(M, 0x2);
+ jle(l9b0, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l9b0);
+ test(M, 0x1);
+ jle(l9c8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l9c8);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l8bc, T_NEAR);
+ align(4);
+
+L(l9d8);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp
new file mode 100644
index 0000000000..1c11fc6cef
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_at_kern.cpp
@@ -0,0 +1,2209 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
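+// Packing kernel for the transposed ("at") case. It appears to transpose
+// 32-bit groups gathered from LDA-strided lines into the contiguous buffer
+// B, again handling the panel width in decreasing chunks starting at 48.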
+jit_avx512_core_u8_copy_at_kern::jit_avx512_core_u8_copy_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l1014;
+Xbyak::Label l1390;
+Xbyak::Label l159c;
+Xbyak::Label l173c;
+Xbyak::Label l18e4;
+Xbyak::Label l1a7c;
+Xbyak::Label l1a8c;
+Xbyak::Label l1a98;
+Xbyak::Label l1ab4;
+Xbyak::Label l1c64;
+Xbyak::Label l1d74;
+Xbyak::Label l1e50;
+Xbyak::Label l1f2c;
+Xbyak::Label l1ffc;
+Xbyak::Label l20;
+Xbyak::Label l200c;
+Xbyak::Label l2018;
+Xbyak::Label l2034;
+Xbyak::Label l2110;
+Xbyak::Label l21a0;
+Xbyak::Label l2210;
+Xbyak::Label l2284;
+Xbyak::Label l22f0;
+Xbyak::Label l2300;
+Xbyak::Label l230c;
+Xbyak::Label l2324;
+Xbyak::Label l2398;
+Xbyak::Label l23e8;
+Xbyak::Label l242c;
+Xbyak::Label l2474;
+Xbyak::Label l24b4;
+Xbyak::Label l24c4;
+Xbyak::Label l24d0;
+Xbyak::Label l24e8;
+Xbyak::Label l2520;
+Xbyak::Label l254c;
+Xbyak::Label l2578;
+Xbyak::Label l25a8;
+Xbyak::Label l25c8;
+Xbyak::Label l25d6;
+Xbyak::Label l25e0;
+Xbyak::Label l25f0;
+Xbyak::Label l260c;
+Xbyak::Label l262c;
+Xbyak::Label l264c;
+Xbyak::Label l2668;
+Xbyak::Label l2680;
+Xbyak::Label l2690;
+Xbyak::Label l44;
+Xbyak::Label l58c;
+Xbyak::Label l8b0;
+Xbyak::Label lb14;
+Xbyak::Label ld84;
+Xbyak::Label lfdc;
+Xbyak::Label lfec;
+Xbyak::Label lff8;
+
+ preamble();
+#ifdef _WIN32
+ auto stacksize = get_size_of_abi_save_regs();
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(N, qword[N]);
+ mov(M, qword[M]);
+ mov(LDA, qword[LDA]);
+ sub(A, -128);
+ sub(B, -128);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ cmp(N, 0x30);
+ jl(lfec, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x5);
+ lea(I, ptr[I+LDA*8]);
+ lea(I, ptr[I+LDA*8]);
+ add(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l58c, T_NEAR);
+ align(4);
+
+L(l44);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B+0x40], xmm1);
+ movdqu(xword[B+0x100], xmm4);
+ movdqu(xword[B+0x1c0], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B+0x50], xmm1);
+ movdqu(xword[B+0x110], xmm4);
+ movdqu(xword[B+0x1d0], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B+0x60], xmm1);
+ movdqu(xword[B+0x120], xmm4);
+ movdqu(xword[B+0x1e0], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B+0x70], xmm1);
+ movdqu(xword[B+0x130], xmm4);
+ movdqu(xword[B+0x1f0], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B+0x80], xmm1);
+ movdqu(xword[B+0x140], xmm4);
+ movdqu(xword[B+0x200], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ movdqu(xword[B+0x90], xmm1);
+ movdqu(xword[B+0x150], xmm4);
+ movdqu(xword[B+0x210], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ movdqu(xword[B+0xa0], xmm1);
+ movdqu(xword[B+0x160], xmm4);
+ movdqu(xword[B+0x220], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ movdqu(xword[B+0xb0], xmm1);
+ movdqu(xword[B+0x170], xmm4);
+ movdqu(xword[B+0x230], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B], xmm0);
+ movdqu(xword[B+0xc0], xmm1);
+ movdqu(xword[B+0x180], xmm4);
+ movdqu(xword[B+0x240], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B+0x10], xmm0);
+ movdqu(xword[B+0xd0], xmm1);
+ movdqu(xword[B+0x190], xmm4);
+ movdqu(xword[B+0x250], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B+0x20], xmm0);
+ movdqu(xword[B+0xe0], xmm1);
+ movdqu(xword[B+0x1a0], xmm4);
+ movdqu(xword[B+0x260], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B+0x30], xmm0);
+ movdqu(xword[B+0xf0], xmm1);
+ movdqu(xword[B+0x1b0], xmm4);
+ movdqu(xword[B+0x270], xmm3);
+ sub(A1, -16);
+ sub(B, -768);
+ dec(I);
+ jg(l44, T_NEAR);
+ align(4);
+
+L(l58c);
+ test(M, 0x8);
+ jle(l8b0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B+0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B+0x50], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B+0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B+0x70], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B+0x80], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x30], xmm0);
+ movdqu(xword[B+0x90], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x20], xmm0);
+ movdqu(xword[B+0xa0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x10], xmm0);
+ movdqu(xword[B+0xb0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B], xmm0);
+ movdqu(xword[B+0xc0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B+0x10], xmm0);
+ movdqu(xword[B+0xd0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B+0x20], xmm0);
+ movdqu(xword[B+0xe0], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B+0x30], xmm0);
+ movdqu(xword[B+0xf0], xmm1);
+ sub(A1, -8);
+ sub(B, -384);
+ align(4);
+
+L(l8b0);
+ test(M, 0x4);
+ jle(lb14, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x40], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x30], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x20], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x10], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B+0x10], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B+0x20], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B+0x30], xmm0);
+ sub(A1, -4);
+ sub(B, -192);
+ align(4);
+
+L(lb14);
+ test(M, 0x2);
+ jle(ld84, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x70], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x60], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x50], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x40], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x30], xmm0);
+ sub(A1, -2);
+ sub(B, -96);
+ align(4);
+
+L(ld84);
+ test(M, 0x1);
+ jle(lfdc, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x80], xmm0);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x70], xmm0);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x60], xmm0);
+ sub(B, -48);
+ align(4);
+
+L(lfdc);
+ sub(N, 0x30);
+ cmp(N, 0x30);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(lfec);
+ cmp(N, 0x20);
+ jl(l1a8c, T_NEAR);
+ align(4);
+
+L(lff8);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x5);
+ add(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l1390, T_NEAR);
+ align(4);
+
+L(l1014);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B], xmm1);
+ movdqu(xword[B+0x80], xmm4);
+ movdqu(xword[B+0x100], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B+0x10], xmm1);
+ movdqu(xword[B+0x90], xmm4);
+ movdqu(xword[B+0x110], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B+0x20], xmm1);
+ movdqu(xword[B+0xa0], xmm4);
+ movdqu(xword[B+0x120], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B+0x30], xmm1);
+ movdqu(xword[B+0xb0], xmm4);
+ movdqu(xword[B+0x130], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B+0x40], xmm1);
+ movdqu(xword[B+0xc0], xmm4);
+ movdqu(xword[B+0x140], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ movdqu(xword[B+0x50], xmm1);
+ movdqu(xword[B+0xd0], xmm4);
+ movdqu(xword[B+0x150], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ movdqu(xword[B+0x60], xmm1);
+ movdqu(xword[B+0xe0], xmm4);
+ movdqu(xword[B+0x160], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ movdqu(xword[B+0x70], xmm1);
+ movdqu(xword[B+0xf0], xmm4);
+ movdqu(xword[B+0x170], xmm3);
+ sub(A1, -16);
+ sub(B, -512);
+ dec(I);
+ jg(l1014, T_NEAR);
+ align(4);
+
+L(l1390);
+ test(M, 0x8);
+ jle(l159c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B+0x10], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B+0x20], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B+0x30], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x40], xmm0);
+ movdqu(xword[B+0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x30], xmm0);
+ movdqu(xword[B+0x50], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x20], xmm0);
+ movdqu(xword[B+0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x10], xmm0);
+ movdqu(xword[B+0x70], xmm1);
+ sub(A1, -8);
+ sub(B, -256);
+ align(4);
+
+L(l159c);
+ test(M, 0x4);
+ jle(l173c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x40], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x30], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x20], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x10], xmm0);
+ sub(A1, -4);
+ sub(B, -128);
+ align(4);
+
+L(l173c);
+ test(M, 0x2);
+ jle(l18e4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x70], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x60], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqu(xword[B-0x50], xmm0);
+ sub(A1, -2);
+ sub(B, -64);
+ align(4);
+
+L(l18e4);
+ test(M, 0x1);
+ jle(l1a7c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x80], xmm0);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l1a7c);
+ sub(N, 0x20);
+ cmp(N, 0x20);
+ jge(lff8, T_NEAR);
+ align(4);
+
+L(l1a8c);
+ cmp(N, 0x10);
+ jl(l200c, T_NEAR);
+ align(4);
+
+L(l1a98);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x4);
+ add(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l1c64, T_NEAR);
+ align(4);
+
+L(l1ab4);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x40], xmm1);
+ movdqu(xword[B], xmm4);
+ movdqu(xword[B+0x40], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x30], xmm1);
+ movdqu(xword[B+0x10], xmm4);
+ movdqu(xword[B+0x50], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x20], xmm1);
+ movdqu(xword[B+0x20], xmm4);
+ movdqu(xword[B+0x60], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B-0x10], xmm1);
+ movdqu(xword[B+0x30], xmm4);
+ movdqu(xword[B+0x70], xmm3);
+ sub(A1, -16);
+ sub(B, -256);
+ dec(I);
+ jg(l1ab4, T_NEAR);
+ align(4);
+
+L(l1c64);
+ test(M, 0x8);
+ jle(l1d74, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x30], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x20], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ movdqu(xword[B-0x10], xmm1);
+ sub(A1, -8);
+ sub(B, -128);
+ align(4);
+
+L(l1d74);
+ test(M, 0x4);
+ jle(l1e50, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x50], xmm0);
+ sub(A1, -4);
+ sub(B, -64);
+ align(4);
+
+L(l1e50);
+ test(M, 0x2);
+ jle(l1f2c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x70], xmm0);
+ sub(A1, -2);
+ sub(B, -32);
+ align(4);
+
+L(l1f2c);
+ test(M, 0x1);
+ jle(l1ffc, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0xf);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l1ffc);
+ sub(N, 0x10);
+ cmp(N, 0x10);
+ jge(l1a98, T_NEAR);
+ align(4);
+
+L(l200c);
+ cmp(N, 0x8);
+ jl(l2300, T_NEAR);
+ align(4);
+
+L(l2018);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*4]);
+ lea(I, ptr[A1+LDA*8]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l2110, T_NEAR);
+ align(4);
+
+L(l2034);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ sub(A1, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x60], xmm1);
+ movdqu(xword[B-0x40], xmm4);
+ movdqu(xword[B-0x20], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ movdqu(xword[B-0x30], xmm4);
+ movdqu(xword[B-0x10], xmm3);
+ sub(B, -128);
+ dec(I);
+ jg(l2034, T_NEAR);
+ align(4);
+
+L(l2110);
+ test(M, 0x8);
+ jle(l21a0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ sub(A1, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ align(4);
+
+L(l21a0);
+ test(M, 0x4);
+ jle(l2210, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ sub(A1, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l2210);
+ test(M, 0x2);
+ jle(l2284, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l2284);
+ test(M, 0x1);
+ jle(l22f0, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x7);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l22f0);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l2018, T_NEAR);
+ align(4);
+
+L(l2300);
+ cmp(N, 0x4);
+ jl(l24c4, T_NEAR);
+ align(4);
+
+L(l230c);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*2]);
+ lea(I, ptr[A1+LDA*4]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l2398, T_NEAR);
+ align(4);
+
+L(l2324);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm2, xword[A2-0x80]);
+ movdqu(xmm3, xword[A2+LDA*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm3);
+ sub(B, -64);
+ dec(I);
+ jg(l2324, T_NEAR);
+ align(4);
+
+L(l2398);
+ test(M, 0x8);
+ jle(l23e8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ sub(A1, -8);
+ movq(xmm2, qword[A2-0x80]);
+ movq(xmm3, qword[A2+LDA*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l23e8);
+ test(M, 0x4);
+ jle(l242c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ sub(A1, -4);
+ movd(xmm2, dword[A2-0x80]);
+ movd(xmm3, dword[A2+LDA*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l242c);
+ test(M, 0x2);
+ jle(l2474, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x3);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l2474);
+ test(M, 0x1);
+ jle(l24b4, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l24b4);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l230c, T_NEAR);
+ align(4);
+
+L(l24c4);
+ cmp(N, 0x2);
+ jl(l25d6, T_NEAR);
+ align(4);
+
+L(l24d0);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*1]);
+ lea(I, ptr[A1+LDA*2]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l2520, T_NEAR);
+ align(4);
+
+L(l24e8);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm1, xword[A2-0x80]);
+ sub(A2, -16);
+ movdqa(xmm2, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm2, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ dec(I);
+ jg(l24e8, T_NEAR);
+ align(4);
+
+L(l2520);
+ test(M, 0x8);
+ jle(l254c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(xmm1, qword[A2-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l254c);
+ test(M, 0x4);
+ jle(l2578, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(xmm1, dword[A2-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l2578);
+ test(M, 0x2);
+ jle(l25a8, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x1);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l25a8);
+ test(M, 0x1);
+ jle(l25c8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A2-0x80]);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l25c8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l24d0, T_NEAR);
+ align(4);
+
+L(l25d6);
+ cmp(N, 0x1);
+ jl(l2690, T_NEAR);
+ align(4);
+
+L(l25e0);
+ mov(A1, A);
+ add(A, LDA);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l260c, T_NEAR);
+ align(4);
+
+L(l25f0);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(I);
+ jg(l25f0, T_NEAR);
+ align(4);
+
+L(l260c);
+ test(M, 0x8);
+ jle(l262c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l262c);
+ test(M, 0x4);
+ jle(l264c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l264c);
+ test(M, 0x2);
+ jle(l2668, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ mov(word[B-0x80], ax);
+ sub(A1, -2);
+ sub(B, -2);
+ align(4);
+
+L(l2668);
+ test(M, 0x1);
+ jle(l2680, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l2680);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l25e0, T_NEAR);
+ align(4);
+
+L(l2690);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+}
+
+}
+}
+}
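
For reference (not part of the patch): the long punpckldq/punpckhdq/punpcklqdq/punpckhqdq runs above all perform the same 4x4 transpose of 32-bit lanes on four consecutive source rows, which is how the copy kernel packs groups of four int8 elements per row into the contiguous panel layout the u8 GEMM kernel consumes. A minimal C++/SSE2 sketch of that one building block is shown below; the names (transpose_4x4_dwords, rows, out) are illustrative only, and the JIT code emits the equivalent instructions directly through Xbyak rather than calling intrinsics.

    #include <emmintrin.h>

    // Transpose a 4x4 block of 32-bit lanes: out[c] gathers lane c of rows 0..3.
    // This mirrors the movdqa + punpck{l,h}dq + punpck{l,h}qdq sequence in the kernel.
    static inline void transpose_4x4_dwords(const __m128i rows[4], __m128i out[4]) {
        __m128i lo01 = _mm_unpacklo_epi32(rows[0], rows[1]); // r0[0] r1[0] r0[1] r1[1]
        __m128i hi01 = _mm_unpackhi_epi32(rows[0], rows[1]); // r0[2] r1[2] r0[3] r1[3]
        __m128i lo23 = _mm_unpacklo_epi32(rows[2], rows[3]); // r2[0] r3[0] r2[1] r3[1]
        __m128i hi23 = _mm_unpackhi_epi32(rows[2], rows[3]); // r2[2] r3[2] r2[3] r3[3]
        out[0] = _mm_unpacklo_epi64(lo01, lo23); // lane 0 of r0..r3
        out[1] = _mm_unpackhi_epi64(lo01, lo23); // lane 1 of r0..r3
        out[2] = _mm_unpacklo_epi64(hi01, hi23); // lane 2 of r0..r3
        out[3] = _mm_unpackhi_epi64(hi01, hi23); // lane 3 of r0..r3
    }

In the generated kernels the four outputs land at strided offsets from B (e.g. B-0x80, B, B+0x80, B+0x100 in the 32-column path), and the narrower tail paths (M & 8, 4, 2, 1) repeat the same idea with qword/dword loads or pinsrw/pinsrb scalar inserts.
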
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp
new file mode 100644
index 0000000000..56c36ee14a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bn_kern.cpp
@@ -0,0 +1,564 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_bn_kern::jit_avx512_core_u8_copy_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l118;
+Xbyak::Label l1a8;
+Xbyak::Label l20;
+Xbyak::Label l218;
+Xbyak::Label l28c;
+Xbyak::Label l2f8;
+Xbyak::Label l308;
+Xbyak::Label l314;
+Xbyak::Label l32c;
+Xbyak::Label l3a0;
+Xbyak::Label l3c;
+Xbyak::Label l3f0;
+Xbyak::Label l434;
+Xbyak::Label l47c;
+Xbyak::Label l4bc;
+Xbyak::Label l4cc;
+Xbyak::Label l4d8;
+Xbyak::Label l4f0;
+Xbyak::Label l528;
+Xbyak::Label l554;
+Xbyak::Label l580;
+Xbyak::Label l5b0;
+Xbyak::Label l5d0;
+Xbyak::Label l5de;
+Xbyak::Label l5e8;
+Xbyak::Label l5f8;
+Xbyak::Label l614;
+Xbyak::Label l634;
+Xbyak::Label l654;
+Xbyak::Label l670;
+Xbyak::Label l688;
+Xbyak::Label l698;
+
+ preamble();
+#ifdef _WIN32
+ auto stacksize = get_size_of_abi_save_regs();
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(N, qword[N]);
+ mov(M, qword[M]);
+ mov(LDA, qword[LDA]);
+ sub(A, -128);
+ sub(B, -128);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ cmp(N, 0x8);
+ jl(l308, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*4]);
+ lea(I, ptr[A1+LDA*8]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l118, T_NEAR);
+ align(4);
+
+L(l3c);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ sub(A1, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x60], xmm1);
+ movdqu(xword[B-0x40], xmm4);
+ movdqu(xword[B-0x20], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ movdqu(xword[B-0x30], xmm4);
+ movdqu(xword[B-0x10], xmm3);
+ sub(B, -128);
+ dec(I);
+ jg(l3c, T_NEAR);
+ align(4);
+
+L(l118);
+ test(M, 0x8);
+ jle(l1a8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ sub(A1, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ align(4);
+
+L(l1a8);
+ test(M, 0x4);
+ jle(l218, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ sub(A1, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l218);
+ test(M, 0x2);
+ jle(l28c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x7);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l28c);
+ test(M, 0x1);
+ jle(l2f8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x7);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l2f8);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l308);
+ cmp(N, 0x4);
+ jl(l4cc, T_NEAR);
+ align(4);
+
+L(l314);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*2]);
+ lea(I, ptr[A1+LDA*4]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l3a0, T_NEAR);
+ align(4);
+
+L(l32c);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm2, xword[A2-0x80]);
+ movdqu(xmm3, xword[A2+LDA*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm3);
+ sub(B, -64);
+ dec(I);
+ jg(l32c, T_NEAR);
+ align(4);
+
+L(l3a0);
+ test(M, 0x8);
+ jle(l3f0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ sub(A1, -8);
+ movq(xmm2, qword[A2-0x80]);
+ movq(xmm3, qword[A2+LDA*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l3f0);
+ test(M, 0x4);
+ jle(l434, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ sub(A1, -4);
+ movd(xmm2, dword[A2-0x80]);
+ movd(xmm3, dword[A2+LDA*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l434);
+ test(M, 0x2);
+ jle(l47c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x3);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l47c);
+ test(M, 0x1);
+ jle(l4bc, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l4bc);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l314, T_NEAR);
+ align(4);
+
+L(l4cc);
+ cmp(N, 0x2);
+ jl(l5de, T_NEAR);
+ align(4);
+
+L(l4d8);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*1]);
+ lea(I, ptr[A1+LDA*2]);
+ mov(A, I);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l528, T_NEAR);
+ align(4);
+
+L(l4f0);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm1, xword[A2-0x80]);
+ sub(A2, -16);
+ movdqa(xmm2, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm2, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ dec(I);
+ jg(l4f0, T_NEAR);
+ align(4);
+
+L(l528);
+ test(M, 0x8);
+ jle(l554, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(xmm1, qword[A2-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l554);
+ test(M, 0x4);
+ jle(l580, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(xmm1, dword[A2-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l580);
+ test(M, 0x2);
+ jle(l5b0, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x1);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l5b0);
+ test(M, 0x1);
+ jle(l5d0, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A2-0x80]);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l5d0);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l4d8, T_NEAR);
+ align(4);
+
+L(l5de);
+ cmp(N, 0x1);
+ jl(l698, T_NEAR);
+ align(4);
+
+L(l5e8);
+ mov(A1, A);
+ add(A, LDA);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l614, T_NEAR);
+ align(4);
+
+L(l5f8);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(I);
+ jg(l5f8, T_NEAR);
+ align(4);
+
+L(l614);
+ test(M, 0x8);
+ jle(l634, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l634);
+ test(M, 0x4);
+ jle(l654, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l654);
+ test(M, 0x2);
+ jle(l670, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ mov(word[B-0x80], ax);
+ sub(A1, -2);
+ sub(B, -2);
+ align(4);
+
+L(l670);
+ test(M, 0x1);
+ jle(l688, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l688);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l5e8, T_NEAR);
+ align(4);
+
+L(l698);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp
new file mode 100644
index 0000000000..53e99d94de
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_bt_kern.cpp
@@ -0,0 +1,501 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_bt_kern::jit_avx512_core_u8_copy_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l120;
+Xbyak::Label l14c;
+Xbyak::Label l168;
+Xbyak::Label l178;
+Xbyak::Label l184;
+Xbyak::Label l194;
+Xbyak::Label l20;
+Xbyak::Label l20c;
+Xbyak::Label l250;
+Xbyak::Label l27c;
+Xbyak::Label l298;
+Xbyak::Label l2a8;
+Xbyak::Label l2b4;
+Xbyak::Label l2c8;
+Xbyak::Label l34;
+Xbyak::Label l360;
+Xbyak::Label l3b4;
+Xbyak::Label l3e8;
+Xbyak::Label l400;
+Xbyak::Label l40e;
+Xbyak::Label l418;
+Xbyak::Label l428;
+Xbyak::Label l4a0;
+Xbyak::Label l4e8;
+Xbyak::Label l50c;
+Xbyak::Label l524;
+Xbyak::Label l534;
+Xbyak::Label lcc;
+
+ preamble();
+#ifdef _WIN32
+ auto stacksize = get_size_of_abi_save_regs();
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(LDA, qword[LDA]);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ sub(A, -128);
+ sub(B, -128);
+ cmp(N, 0x8);
+ jl(l178, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ add(A, 0x8);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(lcc, T_NEAR);
+ align(4);
+
+L(l34);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ dec(I);
+ jg(l34, T_NEAR);
+ align(4);
+
+L(lcc);
+ test(M, 0x4);
+ jle(l120, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l120);
+ test(M, 0x2);
+ jle(l14c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l14c);
+ test(M, 0x1);
+ jle(l168, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l168);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l178);
+ cmp(N, 0x4);
+ jl(l2a8, T_NEAR);
+ align(4);
+
+L(l184);
+ mov(A1, A);
+ add(A, 0x4);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l20c, T_NEAR);
+ align(4);
+
+L(l194);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ dec(I);
+ jg(l194, T_NEAR);
+ align(4);
+
+L(l20c);
+ test(M, 0x4);
+ jle(l250, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l250);
+ test(M, 0x2);
+ jle(l27c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l27c);
+ test(M, 0x1);
+ jle(l298, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l298);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l184, T_NEAR);
+ align(4);
+
+L(l2a8);
+ cmp(N, 0x2);
+ jl(l40e, T_NEAR);
+ align(4);
+
+L(l2b4);
+ mov(A1, A);
+ add(A, 0x2);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l360, T_NEAR);
+ align(4);
+
+L(l2c8);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm4, eax, 0x0);
+ punpcklbw(xmm1, xmm2);
+ punpcklbw(xmm3, xmm4);
+ punpcklwd(xmm1, xmm3);
+ punpcklqdq(xmm0, xmm1);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(LDA3);
+ jg(l2c8, T_NEAR);
+ align(4);
+
+L(l360);
+ test(M, 0x4);
+ jle(l3b4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l3b4);
+ test(M, 0x2);
+ jle(l3e8, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l3e8);
+ test(M, 0x1);
+ jle(l400, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ mov(word[B-0x80], ax);
+ sub(B, -2);
+ align(4);
+
+L(l400);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l2b4, T_NEAR);
+ align(4);
+
+L(l40e);
+ cmp(N, 0x1);
+ jl(l534, T_NEAR);
+ align(4);
+
+L(l418);
+ mov(A1, A);
+ add(A, 0x1);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l4a0, T_NEAR);
+ align(4);
+
+L(l428);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x7);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ dec(LDA3);
+ jg(l428, T_NEAR);
+ align(4);
+
+L(l4a0);
+ test(M, 0x4);
+ jle(l4e8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l4e8);
+ test(M, 0x2);
+ jle(l50c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l50c);
+ test(M, 0x1);
+ jle(l524, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l524);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l418, T_NEAR);
+ align(4);
+
+L(l534);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp
new file mode 100644
index 0000000000..49a312fc88
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_an_kern.cpp
@@ -0,0 +1,1283 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_sum_an_kern::jit_avx512_core_u8_copy_sum_an_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#define ARG_BIAS 24+stacksize+rsp
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+#define ARG_BIAS 72+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l1024;
+Xbyak::Label l1090;
+Xbyak::Label l10d4;
+Xbyak::Label l10fc;
+Xbyak::Label l111a;
+Xbyak::Label l1124;
+Xbyak::Label l113c;
+Xbyak::Label l11d4;
+Xbyak::Label l1234;
+Xbyak::Label l1278;
+Xbyak::Label l129c;
+Xbyak::Label l12bc;
+Xbyak::Label l20;
+Xbyak::Label l2a0;
+Xbyak::Label l3c0;
+Xbyak::Label l438;
+Xbyak::Label l480;
+Xbyak::Label l48c;
+Xbyak::Label l4c8;
+Xbyak::Label l5c;
+Xbyak::Label l6a8;
+Xbyak::Label l7b4;
+Xbyak::Label l850;
+Xbyak::Label l89c;
+Xbyak::Label l8a8;
+Xbyak::Label l8d0;
+Xbyak::Label l9d0;
+Xbyak::Label la64;
+Xbyak::Label lab8;
+Xbyak::Label lae8;
+Xbyak::Label laf4;
+Xbyak::Label lb14;
+Xbyak::Label lc30;
+Xbyak::Label lcc8;
+Xbyak::Label ld1c;
+Xbyak::Label ld54;
+Xbyak::Label ld78;
+Xbyak::Label ld84;
+Xbyak::Label ld9c;
+Xbyak::Label le58;
+Xbyak::Label lebc;
+Xbyak::Label lef8;
+Xbyak::Label lf1c;
+Xbyak::Label lf3c;
+Xbyak::Label lf48;
+Xbyak::Label lf60;
+
+ preamble();
+ auto stacksize = get_size_of_abi_save_regs();
+#ifdef _WIN32
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(LDA, qword[LDA]);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ sub(A, -128);
+ sub(B, -128);
+ cmp(N, 0x30);
+ jl(l480, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ add(A, 0x30);
+ vxorps(ymm8, ymm8, ymm8);
+ vxorps(ymm9, ymm9, ymm9);
+ vxorps(ymm10, ymm10, ymm10);
+ vxorps(ymm11, ymm11, ymm11);
+ vxorps(ymm12, ymm12, ymm12);
+ vxorps(ymm13, ymm13, ymm13);
+ vxorps(ymm14, ymm14, ymm14);
+ vxorps(ymm15, ymm15, ymm15);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l2a0, T_NEAR);
+ align(4);
+
+L(l5c);
+ vmovdqu(xmm0, xword[A1-0x80]);
+ vmovdqu(xmm1, xword[A1+LDA*1-0x80]);
+ vmovdqu(xmm2, xword[A1+LDA*2-0x80]);
+ vmovdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ vpunpcklbw(xmm4, xmm0, xmm1);
+ vpunpckhbw(xmm5, xmm0, xmm1);
+ vpunpcklbw(xmm6, xmm2, xmm3);
+ vpunpckhbw(xmm7, xmm2, xmm3);
+ vpunpcklwd(xmm0, xmm4, xmm6);
+ vpunpckhwd(xmm1, xmm4, xmm6);
+ vpunpcklwd(xmm2, xmm5, xmm7);
+ vpunpckhwd(xmm3, xmm5, xmm7);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vmovdqu(xword[B-0x80], xmm0);
+ vmovdqu(xword[B-0x70], xmm1);
+ vpmovsxbw(ymm5, xmm2);
+ vmovhlps(xmm6, xmm2, xmm2);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vmovdqu(xword[B-0x60], xmm2);
+ vmovdqu(xword[B-0x50], xmm3);
+ vmovdqu(xmm0, xword[A1-0x70]);
+ vmovdqu(xmm1, xword[A1+LDA*1-0x70]);
+ vmovdqu(xmm2, xword[A1+LDA*2-0x70]);
+ vmovdqu(xmm3, xword[A1+LDA3*1-0x70]);
+ vpunpcklbw(xmm4, xmm0, xmm1);
+ vpunpckhbw(xmm5, xmm0, xmm1);
+ vpunpcklbw(xmm6, xmm2, xmm3);
+ vpunpckhbw(xmm7, xmm2, xmm3);
+ vpunpcklwd(xmm0, xmm4, xmm6);
+ vpunpckhwd(xmm1, xmm4, xmm6);
+ vpunpcklwd(xmm2, xmm5, xmm7);
+ vpunpckhwd(xmm3, xmm5, xmm7);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vmovdqu(xword[B-0x40], xmm0);
+ vmovdqu(xword[B-0x30], xmm1);
+ vpmovsxbw(ymm5, xmm2);
+ vmovhlps(xmm6, xmm2, xmm2);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vmovdqu(xword[B-0x20], xmm2);
+ vmovdqu(xword[B-0x10], xmm3);
+ vmovdqu(xmm0, xword[A1-0x60]);
+ vmovdqu(xmm1, xword[A1+LDA*1-0x60]);
+ vmovdqu(xmm2, xword[A1+LDA*2-0x60]);
+ vmovdqu(xmm3, xword[A1+LDA3*1-0x60]);
+ lea(A1, ptr[A1+LDA*4]);
+ vpunpcklbw(xmm4, xmm0, xmm1);
+ vpunpckhbw(xmm5, xmm0, xmm1);
+ vpunpcklbw(xmm6, xmm2, xmm3);
+ vpunpckhbw(xmm7, xmm2, xmm3);
+ vpunpcklwd(xmm0, xmm4, xmm6);
+ vpunpckhwd(xmm1, xmm4, xmm6);
+ vpunpcklwd(xmm2, xmm5, xmm7);
+ vpunpckhwd(xmm3, xmm5, xmm7);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vmovdqu(xword[B], xmm0);
+ vmovdqu(xword[B+0x10], xmm1);
+ vpmovsxbw(ymm5, xmm2);
+ vmovhlps(xmm6, xmm2, xmm2);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ vmovdqu(xword[B+0x20], xmm2);
+ vmovdqu(xword[B+0x30], xmm3);
+ sub(B, -192);
+ dec(I);
+ jg(l5c, T_NEAR);
+ align(4);
+
+L(l2a0);
+ test(M, 0x2);
+ jle(l3c0, T_NEAR);
+ vmovdqu(xmm0, xword[A1-0x80]);
+ vmovdqu(xmm1, xword[A1-0x70]);
+ vmovdqu(xmm2, xword[A1-0x60]);
+ add(A1, LDA);
+ vmovdqu(xmm6, xword[A1-0x80]);
+ vmovdqu(xmm4, xword[A1-0x70]);
+ vmovdqu(xmm5, xword[A1-0x60]);
+ add(A1, LDA);
+ vpunpcklbw(xmm3, xmm0, xmm6);
+ vpunpckhbw(xmm0, xmm0, xmm6);
+ vpmovsxbw(ymm7, xmm3);
+ vmovhlps(xmm6, xmm3, xmm3);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm8, ymm8, ymm7);
+ vmovdqu(xword[B-0x80], xmm3);
+ vpmovsxbw(ymm7, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm9, ymm9, ymm7);
+ vmovdqu(xword[B-0x70], xmm0);
+ vpunpcklbw(xmm3, xmm1, xmm4);
+ vpunpckhbw(xmm0, xmm1, xmm4);
+ vpmovsxbw(ymm7, xmm3);
+ vmovhlps(xmm6, xmm3, xmm3);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm10, ymm10, ymm7);
+ vmovdqu(xword[B-0x60], xmm3);
+ vpmovsxbw(ymm7, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm11, ymm11, ymm7);
+ vmovdqu(xword[B-0x50], xmm0);
+ vpunpcklbw(xmm3, xmm2, xmm5);
+ vpunpckhbw(xmm0, xmm2, xmm5);
+ vpmovsxbw(ymm7, xmm3);
+ vmovhlps(xmm6, xmm3, xmm3);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm12, ymm12, ymm7);
+ vmovdqu(xword[B-0x40], xmm3);
+ vpmovsxbw(ymm7, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm7, ymm7, ymm6);
+ vpmovsxwd(ymm7, xmm7);
+ vpaddd(ymm13, ymm13, ymm7);
+ vmovdqu(xword[B-0x30], xmm0);
+ sub(B, -96);
+ align(4);
+
+L(l3c0);
+ test(M, 0x1);
+ jle(l438, T_NEAR);
+ vmovdqu(xmm0, xword[A1-0x80]);
+ vmovdqu(xmm1, xword[A1-0x70]);
+ vmovdqu(xmm2, xword[A1-0x60]);
+ add(A1, LDA);
+ vpmovsxbd(ymm7, xmm0);
+ vpaddd(ymm8, ymm8, ymm7);
+ vmovhlps(xmm7, xmm0, xmm0);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm9, ymm9, ymm7);
+ vmovdqu(xword[B-0x80], xmm0);
+ vpmovsxbd(ymm7, xmm1);
+ vpaddd(ymm10, ymm10, ymm7);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm11, ymm11, ymm7);
+ vmovdqu(xword[B-0x70], xmm1);
+ vpmovsxbd(ymm7, xmm2);
+ vpaddd(ymm12, ymm12, ymm7);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm13, ymm13, ymm7);
+ vmovdqu(xword[B-0x60], xmm2);
+ sub(B, -48);
+ align(4);
+
+L(l438);
+ mov(A1, qword[ARG_BIAS]);
+ vmovdqu(yword[A1], ymm8);
+ vmovdqu(yword[A1+0x20], ymm9);
+ vmovdqu(yword[A1+0x40], ymm10);
+ vmovdqu(yword[A1+0x60], ymm11);
+ vmovdqu(yword[A1+0x80], ymm12);
+ vmovdqu(yword[A1+0xa0], ymm13);
+ add(qword[ARG_BIAS], 0xc0);
+ sub(N, 0x30);
+ cmp(N, 0x30);
+ jge(l20, T_NEAR);
+ vzeroupper();
+ align(4);
+
+L(l480);
+ cmp(N, 0x20);
+ jl(l89c, T_NEAR);
+ align(4);
+
+L(l48c);
+ mov(A1, A);
+ add(A, 0x20);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ pxor(xmm10, xmm10);
+ pxor(xmm11, xmm11);
+ pxor(xmm12, xmm12);
+ pxor(xmm13, xmm13);
+ pxor(xmm14, xmm14);
+ pxor(xmm15, xmm15);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l6a8, T_NEAR);
+ align(4);
+
+L(l4c8);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm4);
+ pmovsxbw(xmm5, xmm2);
+ movhlps(xmm6, xmm2);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm2);
+ movdqu(xmm0, xword[A1-0x70]);
+ movdqu(xmm1, xword[A1+LDA*1-0x70]);
+ movdqu(xmm2, xword[A1+LDA*2-0x70]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x70]);
+ lea(A1, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm5);
+ punpckhwd(xmm2, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B-0x30], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B-0x20], xmm4);
+ pmovsxbw(xmm5, xmm2);
+ movhlps(xmm6, xmm2);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B-0x10], xmm2);
+ sub(B, -128);
+ dec(I);
+ jg(l4c8, T_NEAR);
+ align(4);
+
+L(l6a8);
+ test(M, 0x2);
+ jle(l7b4, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqu(xmm2, xword[A1-0x80]);
+ movdqu(xmm3, xword[A1-0x70]);
+ add(A1, LDA);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm2);
+ punpckhbw(xmm4, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm4);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x70], xmm4);
+ movdqa(xmm4, xmm1);
+ punpcklbw(xmm1, xmm3);
+ punpckhbw(xmm4, xmm3);
+ pmovsxbw(xmm5, xmm1);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm13, xmm6);
+ movdqu(xword[B-0x60], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm15, xmm6);
+ movdqu(xword[B-0x50], xmm4);
+ sub(B, -64);
+ align(4);
+
+L(l7b4);
+ test(M, 0x1);
+ jle(l850, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1-0x70]);
+ add(A1, LDA);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm8, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbd(xmm5, xmm1);
+ paddd(xmm12, xmm5);
+ pshufd(xmm6, xmm1, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm13, xmm6);
+ pshufd(xmm5, xmm1, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ pshufd(xmm6, xmm1, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm15, xmm6);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l850);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ movdqu(xword[A1+0x20], xmm10);
+ movdqu(xword[A1+0x30], xmm11);
+ movdqu(xword[A1+0x40], xmm12);
+ movdqu(xword[A1+0x50], xmm13);
+ movdqu(xword[A1+0x60], xmm14);
+ movdqu(xword[A1+0x70], xmm15);
+ add(qword[ARG_BIAS], 0x80);
+ sub(N, 0x20);
+ cmp(N, 0x20);
+ jge(l48c, T_NEAR);
+ align(4);
+
+L(l89c);
+ cmp(N, 0x10);
+ jl(lae8, T_NEAR);
+ align(4);
+
+L(l8a8);
+ mov(A1, A);
+ add(A, 0x10);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ pxor(xmm10, xmm10);
+ pxor(xmm11, xmm11);
+ mov(I, M);
+ sar(I, 0x2);
+ jle(l9d0, T_NEAR);
+ align(4);
+
+L(l8d0);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm1, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm2, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm3, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqa(xmm4, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm4, xmm1);
+ movdqa(xmm1, xmm2);
+ punpcklbw(xmm2, xmm3);
+ punpckhbw(xmm1, xmm3);
+ movdqa(xmm3, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm3, xmm2);
+ movdqa(xmm2, xmm4);
+ punpcklwd(xmm4, xmm1);
+ punpckhwd(xmm2, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm3);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pmovsxbw(xmm5, xmm2);
+ movhlps(xmm6, xmm2);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x60], xmm4);
+ movdqu(xword[B-0x50], xmm2);
+ sub(B, -64);
+ dec(I);
+ jg(l8d0, T_NEAR);
+ align(4);
+
+L(l9d0);
+ test(M, 0x2);
+ jle(la64, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqu(xmm1, xword[A1-0x80]);
+ add(A1, LDA);
+ movdqa(xmm2, xmm0);
+ punpcklbw(xmm0, xmm1);
+ punpckhbw(xmm2, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pmovsxbw(xmm5, xmm2);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movhlps(xmm6, xmm2);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ align(4);
+
+L(la64);
+ test(M, 0x1);
+ jle(lab8, T_NEAR);
+ movdqu(xmm0, xword[A1-0x80]);
+ add(A1, LDA);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm8, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(lab8);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ movdqu(xword[A1+0x20], xmm10);
+ movdqu(xword[A1+0x30], xmm11);
+ add(qword[ARG_BIAS], 0x40);
+ sub(N, 0x10);
+ cmp(N, 0x10);
+ jge(l8a8, T_NEAR);
+ align(4);
+
+L(lae8);
+ cmp(N, 0x8);
+ jl(ld78, T_NEAR);
+ align(4);
+
+L(laf4);
+ mov(A1, A);
+ add(A, 0x8);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(lc30, T_NEAR);
+ align(4);
+
+L(lb14);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ dec(I);
+ jg(lb14, T_NEAR);
+ align(4);
+
+L(lc30);
+ test(M, 0x4);
+ jle(lcc8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(lcc8);
+ test(M, 0x2);
+ jle(ld1c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(ld1c);
+ test(M, 0x1);
+ jle(ld54, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ pmovsxbd(xmm5, xmm0);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm8, xmm5);
+ paddd(xmm9, xmm6);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(ld54);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ add(qword[ARG_BIAS], 0x20);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(laf4, T_NEAR);
+ align(4);
+
+L(ld78);
+ cmp(N, 0x4);
+ jl(lf3c, T_NEAR);
+ align(4);
+
+L(ld84);
+ mov(A1, A);
+ add(A, 0x4);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(le58, T_NEAR);
+ align(4);
+
+L(ld9c);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ dec(I);
+ jg(ld9c, T_NEAR);
+ align(4);
+
+L(le58);
+ test(M, 0x4);
+ jle(lebc, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(lebc);
+ test(M, 0x2);
+ jle(lef8, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(lef8);
+ test(M, 0x1);
+ jle(lf1c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(lf1c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x10);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(ld84, T_NEAR);
+ align(4);
+
+L(lf3c);
+ cmp(N, 0x2);
+ jl(l111a, T_NEAR);
+ align(4);
+
+L(lf48);
+ mov(A1, A);
+ add(A, 0x2);
+ pxor(xmm7, xmm7);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l1024, T_NEAR);
+ align(4);
+
+L(lf60);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm4, eax, 0x0);
+ punpcklbw(xmm1, xmm2);
+ punpcklbw(xmm3, xmm4);
+ punpcklwd(xmm1, xmm3);
+ punpcklqdq(xmm0, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(LDA3);
+ jg(lf60, T_NEAR);
+ align(4);
+
+L(l1024);
+ test(M, 0x4);
+ jle(l1090, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l1090);
+ test(M, 0x2);
+ jle(l10d4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l10d4);
+ test(M, 0x1);
+ jle(l10fc, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(word[B-0x80], ax);
+ sub(B, -2);
+ align(4);
+
+L(l10fc);
+ mov(A1, qword[ARG_BIAS]);
+ movq(qword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(lf48, T_NEAR);
+ align(4);
+
+L(l111a);
+ cmp(N, 0x1);
+ jl(l12bc, T_NEAR);
+ align(4);
+
+L(l1124);
+ mov(A1, A);
+ add(A, 0x1);
+ pxor(xmm7, xmm7);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l11d4, T_NEAR);
+ align(4);
+
+L(l113c);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ dec(LDA3);
+ jg(l113c, T_NEAR);
+ align(4);
+
+L(l11d4);
+ test(M, 0x4);
+ jle(l1234, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l1234);
+ test(M, 0x2);
+ jle(l1278, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l1278);
+ test(M, 0x1);
+ jle(l129c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l129c);
+ mov(A1, qword[ARG_BIAS]);
+ movd(dword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x4);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l1124, T_NEAR);
+ align(4);
+
+L(l12bc);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+#undef ARG_BIAS
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp
new file mode 100644
index 0000000000..a4f4ff09c6
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_at_kern.cpp
@@ -0,0 +1,3163 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_sum_at_kern::jit_avx512_core_u8_copy_sum_at_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#define ARG_BIAS 24+stacksize+rsp
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+#define ARG_BIAS 72+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l1750;
+Xbyak::Label l1b6c;
+Xbyak::Label l1e14;
+Xbyak::Label l20;
+Xbyak::Label l2068;
+Xbyak::Label l226c;
+Xbyak::Label l22b8;
+Xbyak::Label l22c4;
+Xbyak::Label l22f4;
+Xbyak::Label l26b4;
+Xbyak::Label l28cc;
+Xbyak::Label l2a2c;
+Xbyak::Label l2b5c;
+Xbyak::Label l2c64;
+Xbyak::Label l2c94;
+Xbyak::Label l2ca0;
+Xbyak::Label l2cc8;
+Xbyak::Label l2eac;
+Xbyak::Label l2fc0;
+Xbyak::Label l3078;
+Xbyak::Label l3118;
+Xbyak::Label l319c;
+Xbyak::Label l31c0;
+Xbyak::Label l31cc;
+Xbyak::Label l31ec;
+Xbyak::Label l32e4;
+Xbyak::Label l3378;
+Xbyak::Label l33dc;
+Xbyak::Label l3434;
+Xbyak::Label l347c;
+Xbyak::Label l349c;
+Xbyak::Label l34a8;
+Xbyak::Label l34c8;
+Xbyak::Label l3558;
+Xbyak::Label l35b0;
+Xbyak::Label l35f4;
+Xbyak::Label l3638;
+Xbyak::Label l366c;
+Xbyak::Label l368a;
+Xbyak::Label l3694;
+Xbyak::Label l36a8;
+Xbyak::Label l36ec;
+Xbyak::Label l3728;
+Xbyak::Label l3760;
+Xbyak::Label l3794;
+Xbyak::Label l37b8;
+Xbyak::Label l37d8;
+Xbyak::Label l5cc;
+Xbyak::Label l6c;
+Xbyak::Label l968;
+Xbyak::Label lc80;
+Xbyak::Label lf1c;
+Xbyak::Label lf64;
+Xbyak::Label lf70;
+Xbyak::Label lfb4;
+
+ preamble();
+ auto stacksize = get_size_of_abi_save_regs();
+#ifdef _WIN32
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(N, qword[N]);
+ mov(M, qword[M]);
+ mov(LDA, qword[LDA]);
+ sub(A, -128);
+ sub(B, -128);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ cmp(N, 0x30);
+ jl(lf64, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x5);
+ lea(I, ptr[I+LDA*8]);
+ lea(I, ptr[I+LDA*8]);
+ add(A, I);
+ vxorps(ymm8, ymm8, ymm8);
+ vxorps(ymm9, ymm9, ymm9);
+ vxorps(ymm10, ymm10, ymm10);
+ vxorps(ymm11, ymm11, ymm11);
+ vxorps(ymm12, ymm12, ymm12);
+ vxorps(ymm13, ymm13, ymm13);
+ vxorps(ymm14, ymm14, ymm14);
+ vxorps(ymm15, ymm15, ymm15);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l5cc, T_NEAR);
+ align(4);
+
+L(l6c);
+ vmovq(xmm0, qword[A1-0x80]);
+ vmovq(xmm1, qword[A1+LDA*1-0x80]);
+ vmovq(xmm2, qword[A1+LDA*2-0x80]);
+ vmovq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x80], xmm0);
+ vmovdqu(xword[B+0x40], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B-0x70], xmm2);
+ vmovdqu(xword[B+0x50], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x60], xmm0);
+ vmovdqu(xword[B+0x60], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B-0x50], xmm2);
+ vmovdqu(xword[B+0x70], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x40], xmm0);
+ vmovdqu(xword[B+0x80], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B-0x30], xmm2);
+ vmovdqu(xword[B+0x90], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x20], xmm0);
+ vmovdqu(xword[B+0xa0], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B-0x10], xmm2);
+ vmovdqu(xword[B+0xb0], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B], xmm0);
+ vmovdqu(xword[B+0xc0], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B+0x10], xmm2);
+ vmovdqu(xword[B+0xd0], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vmovq(xmm0, qword[A2-0x80]);
+ vmovq(xmm1, qword[A2+LDA*1-0x80]);
+ vmovq(xmm2, qword[A2+LDA*2-0x80]);
+ vmovq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm0, xmm1);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm1, xmm3);
+ vpunpckhqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B+0x20], xmm0);
+ vmovdqu(xword[B+0xe0], xmm1);
+ vmovq(xmm2, qword[A2-0x80]);
+ vmovq(xmm3, qword[A2+LDA*1-0x80]);
+ vmovq(xmm4, qword[A2+LDA*2-0x80]);
+ vmovq(xmm5, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm3, xmm2, xmm3);
+ vpunpckldq(xmm5, xmm4, xmm5);
+ vpunpcklqdq(xmm2, xmm3, xmm5);
+ vpunpckhqdq(xmm3, xmm3, xmm5);
+ vmovdqu(xword[B+0x30], xmm2);
+ vmovdqu(xword[B+0xf0], xmm3);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm2);
+ vmovhlps(xmm7, xmm2, xmm2);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ vpmovsxbw(ymm5, xmm1);
+ vmovhlps(xmm6, xmm1, xmm1);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm3);
+ vmovhlps(xmm7, xmm3, xmm3);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ sub(A1, -8);
+ sub(B, -384);
+ dec(I);
+ jg(l6c, T_NEAR);
+ align(4);
+
+L(l5cc);
+ test(M, 0x4);
+ jle(l968, T_NEAR);
+ vmovd(xmm0, dword[A1-0x80]);
+ vmovd(xmm1, dword[A1+LDA*1-0x80]);
+ vmovd(xmm2, dword[A1+LDA*2-0x80]);
+ vmovd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B-0x80], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x70], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B-0x60], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x50], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B-0x40], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x30], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B-0x20], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B-0x10], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B+0x10], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vmovd(xmm0, dword[A2-0x80]);
+ vmovd(xmm1, dword[A2+LDA*1-0x80]);
+ vmovd(xmm2, dword[A2+LDA*2-0x80]);
+ vmovd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm0, xmm0, xmm1);
+ vpunpckldq(xmm2, xmm2, xmm3);
+ vpunpcklqdq(xmm0, xmm0, xmm2);
+ vmovdqu(xword[B+0x20], xmm0);
+ vmovd(xmm1, dword[A2-0x80]);
+ vmovd(xmm2, dword[A2+LDA*1-0x80]);
+ vmovd(xmm3, dword[A2+LDA*2-0x80]);
+ vmovd(xmm4, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpunpckldq(xmm1, xmm1, xmm2);
+ vpunpckldq(xmm3, xmm3, xmm4);
+ vpunpcklqdq(xmm1, xmm1, xmm3);
+ vmovdqu(xword[B+0x30], xmm1);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxbw(ymm6, xmm1);
+ vmovhlps(xmm7, xmm1, xmm1);
+ vpmovsxbw(ymm7, xmm7);
+ vphaddw(ymm6, ymm6, ymm7);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ sub(A1, -4);
+ sub(B, -192);
+ align(4);
+
+L(l968);
+ test(M, 0x2);
+ jle(lc80, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm8, ymm8, ymm5);
+ vmovdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm9, ymm9, ymm5);
+ vmovdqu(xword[B-0x70], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm10, ymm10, ymm5);
+ vmovdqu(xword[B-0x60], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm11, ymm11, ymm5);
+ vmovdqu(xword[B-0x50], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm12, ymm12, ymm5);
+ vmovdqu(xword[B-0x40], xmm0);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrw(xmm0, xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ vpinsrw(xmm0, xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ vpmovsxbw(ymm5, xmm0);
+ vmovhlps(xmm6, xmm0, xmm0);
+ vpmovsxbw(ymm6, xmm6);
+ vphaddw(ymm5, ymm5, ymm6);
+ vpmovsxwd(ymm5, xmm5);
+ vpaddd(ymm13, ymm13, ymm5);
+ vmovdqu(xword[B-0x30], xmm0);
+ sub(A1, -2);
+ sub(B, -96);
+ align(4);
+
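+ // M & 1 remainder of the 48-wide N block: one byte per row, summed via vpmovsxbd.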
+L(lc80);
+ test(M, 0x1);
+ jle(lf1c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xf);
+ vpmovsxbd(ymm7, xmm0);
+ vpaddd(ymm8, ymm8, ymm7);
+ vmovhlps(xmm7, xmm0, xmm0);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm9, ymm9, ymm7);
+ vmovdqu(xword[B-0x80], xmm0);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xf);
+ vpmovsxbd(ymm7, xmm0);
+ vpaddd(ymm10, ymm10, ymm7);
+ vmovhlps(xmm7, xmm0, xmm0);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm11, ymm11, ymm7);
+ vmovdqu(xword[B-0x70], xmm0);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ vpinsrb(xmm0, xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ vpinsrb(xmm0, xmm0, eax, 0xf);
+ vpmovsxbd(ymm7, xmm0);
+ vpaddd(ymm12, ymm12, ymm7);
+ vmovhlps(xmm7, xmm0, xmm0);
+ vpmovsxbd(ymm7, xmm7);
+ vpaddd(ymm13, ymm13, ymm7);
+ vmovdqu(xword[B-0x60], xmm0);
+ sub(B, -48);
+ align(4);
+
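+ // Flush the 48 per-row sums (ymm8..ymm13) to the buffer at ARG_BIAS, advance it by
+ // 0xc0 bytes, and loop back while at least 48 rows (N) remain.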
+L(lf1c);
+ mov(A1, qword[ARG_BIAS]);
+ vmovdqu(yword[A1], ymm8);
+ vmovdqu(yword[A1+0x20], ymm9);
+ vmovdqu(yword[A1+0x40], ymm10);
+ vmovdqu(yword[A1+0x60], ymm11);
+ vmovdqu(yword[A1+0x80], ymm12);
+ vmovdqu(yword[A1+0xa0], ymm13);
+ add(qword[ARG_BIAS], 0xc0);
+ sub(N, 0x30);
+ cmp(N, 0x30);
+ jge(l20, T_NEAR);
+ vzeroupper();
+ align(4);
+
+L(lf64);
+ cmp(N, 0x20);
+ jl(l22b8, T_NEAR);
+ align(4);
+
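+ // 32-wide N block: legacy-SSE variant after vzeroupper, using xmm8..xmm15 as the
+ // per-row sum accumulators.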
+L(lf70);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x5);
+ add(A, I);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ pxor(xmm10, xmm10);
+ pxor(xmm11, xmm11);
+ pxor(xmm12, xmm12);
+ pxor(xmm13, xmm13);
+ pxor(xmm14, xmm14);
+ pxor(xmm15, xmm15);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l1750, T_NEAR);
+ align(4);
+
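+ // Main M loop (16 bytes per row per iteration): 4x4 dword transpose via punpck,
+ // store the interleaved panels to B and add each row's byte sums.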
+L(lfb4);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B+0x80], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B+0x100], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x10], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x90], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x110], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x20], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0xa0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x120], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x30], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0xb0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x130], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B+0x40], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B+0xc0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B+0x140], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B+0x50], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B+0xd0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B+0x150], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B+0x60], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B+0xe0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B+0x160], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B+0x70], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B+0xf0], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B+0x170], xmm3);
+ sub(A1, -16);
+ sub(B, -512);
+ dec(I);
+ jg(lfb4, T_NEAR);
+ align(4);
+
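+ // M & 8 remainder for the 32-wide block.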
+L(l1750);
+ test(M, 0x8);
+ jle(l1b6c, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x10], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x20], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x30], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B+0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B+0x50], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B+0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B+0x70], xmm1);
+ sub(A1, -8);
+ sub(B, -256);
+ align(4);
+
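+ // M & 4 remainder for the 32-wide block.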
+L(l1b6c);
+ test(M, 0x4);
+ jle(l1e14, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movdqu(xword[B-0x40], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm13, xmm5);
+ movdqu(xword[B-0x30], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movdqu(xword[B-0x20], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm15, xmm5);
+ movdqu(xword[B-0x10], xmm0);
+ sub(A1, -4);
+ sub(B, -128);
+ align(4);
+
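+ // M & 2 remainder for the 32-wide block.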
+L(l1e14);
+ test(M, 0x2);
+ jle(l2068, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x70], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm12, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm13, xmm6);
+ movdqu(xword[B-0x60], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ lea(A2, ptr[A2+LDA*4]);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm15, xmm6);
+ movdqu(xword[B-0x50], xmm0);
+ sub(A1, -2);
+ sub(B, -64);
+ align(4);
+
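+ // M & 1 remainder for the 32-wide block.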
+L(l2068);
+ test(M, 0x1);
+ jle(l226c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm8, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xf);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm12, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm13, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm14, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm15, xmm6);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
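+ // Flush the 32 per-row sums (xmm8..xmm15) to ARG_BIAS, advance it by 0x80 bytes,
+ // and loop back while N >= 0x20.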
+L(l226c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ movdqu(xword[A1+0x20], xmm10);
+ movdqu(xword[A1+0x30], xmm11);
+ movdqu(xword[A1+0x40], xmm12);
+ movdqu(xword[A1+0x50], xmm13);
+ movdqu(xword[A1+0x60], xmm14);
+ movdqu(xword[A1+0x70], xmm15);
+ add(qword[ARG_BIAS], 0x80);
+ sub(N, 0x20);
+ cmp(N, 0x20);
+ jge(lf70, T_NEAR);
+ align(4);
+
+L(l22b8);
+ cmp(N, 0x10);
+ jl(l2c94, T_NEAR);
+ align(4);
+
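+ // 16-wide N block: same transpose-and-sum pattern with four accumulators (xmm8..xmm11).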
+L(l22c4);
+ mov(A1, A);
+ mov(I, LDA);
+ shl(I, 0x4);
+ add(A, I);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ pxor(xmm10, xmm10);
+ pxor(xmm11, xmm11);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l26b4, T_NEAR);
+ align(4);
+
+L(l22f4);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x40], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B+0x40], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x30], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x10], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B+0x50], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x20], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x20], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B+0x60], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x10], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x30], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B+0x70], xmm3);
+ sub(A1, -16);
+ sub(B, -256);
+ dec(I);
+ jg(l22f4, T_NEAR);
+ align(4);
+
+L(l26b4);
+ test(M, 0x8);
+ jle(l28cc, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x40], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x30], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x20], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x10], xmm1);
+ sub(A1, -8);
+ sub(B, -128);
+ align(4);
+
+L(l28cc);
+ test(M, 0x4);
+ jle(l2a2c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm11, xmm5);
+ movdqu(xword[B-0x50], xmm0);
+ sub(A1, -4);
+ sub(B, -64);
+ align(4);
+
+L(l2a2c);
+ test(M, 0x2);
+ jle(l2b5c, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x70], xmm0);
+ sub(A1, -2);
+ sub(B, -32);
+ align(4);
+
+L(l2b5c);
+ test(M, 0x1);
+ jle(l2c64, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ lea(A2, ptr[A1+LDA*4]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0x7);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x8);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x9);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xa);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ lea(A2, ptr[A2+LDA*4]);
+ pinsrb(xmm0, eax, 0xb);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0xc);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0xd);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0xe);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0xf);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm8, xmm5);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ pshufd(xmm5, xmm0, 0xaa);
+ pmovsxbd(xmm5, xmm5);
+ paddd(xmm10, xmm5);
+ pshufd(xmm6, xmm0, 0xff);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm11, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l2c64);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ movdqu(xword[A1+0x20], xmm10);
+ movdqu(xword[A1+0x30], xmm11);
+ add(qword[ARG_BIAS], 0x40);
+ sub(N, 0x10);
+ cmp(N, 0x10);
+ jge(l22c4, T_NEAR);
+ align(4);
+
+L(l2c94);
+ cmp(N, 0x8);
+ jl(l31c0, T_NEAR);
+ align(4);
+
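+ // 8-wide N block: two accumulators (xmm8, xmm9).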
+L(l2ca0);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*4]);
+ lea(I, ptr[A1+LDA*8]);
+ mov(A, I);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l2eac, T_NEAR);
+ align(4);
+
+L(l2cc8);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ sub(A1, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x60], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x40], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x20], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x50], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x30], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x10], xmm3);
+ sub(B, -128);
+ dec(I);
+ jg(l2cc8, T_NEAR);
+ align(4);
+
+L(l2eac);
+ test(M, 0x8);
+ jle(l2fc0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ sub(A1, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ align(4);
+
+L(l2fc0);
+ test(M, 0x4);
+ jle(l3078, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ sub(A1, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l3078);
+ test(M, 0x2);
+ jle(l3118, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l3118);
+ test(M, 0x1);
+ jle(l319c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x7);
+ pmovsxbd(xmm5, xmm0);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm8, xmm5);
+ paddd(xmm9, xmm6);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l319c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ add(qword[ARG_BIAS], 0x20);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l2ca0, T_NEAR);
+ align(4);
+
+L(l31c0);
+ cmp(N, 0x4);
+ jl(l349c, T_NEAR);
+ align(4);
+
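+ // 4-wide N block: a single accumulator (xmm7) holds the four per-row sums.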
+L(l31cc);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*2]);
+ lea(I, ptr[A1+LDA*4]);
+ mov(A, I);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l32e4, T_NEAR);
+ align(4);
+
+L(l31ec);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm2, xword[A2-0x80]);
+ movdqu(xmm3, xword[A2+LDA*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x60], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x50], xmm3);
+ sub(B, -64);
+ dec(I);
+ jg(l31ec, T_NEAR);
+ align(4);
+
+L(l32e4);
+ test(M, 0x8);
+ jle(l3378, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ sub(A1, -8);
+ movq(xmm2, qword[A2-0x80]);
+ movq(xmm3, qword[A2+LDA*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l3378);
+ test(M, 0x4);
+ jle(l33dc, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ sub(A1, -4);
+ movd(xmm2, dword[A2-0x80]);
+ movd(xmm3, dword[A2+LDA*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l33dc);
+ test(M, 0x2);
+ jle(l3434, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x3);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l3434);
+ test(M, 0x1);
+ jle(l347c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l347c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x10);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l31cc, T_NEAR);
+ align(4);
+
+L(l349c);
+ cmp(N, 0x2);
+ jl(l368a, T_NEAR);
+ align(4);
+
+L(l34a8);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*1]);
+ lea(I, ptr[A1+LDA*2]);
+ mov(A, I);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l3558, T_NEAR);
+ align(4);
+
+L(l34c8);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm1, xword[A2-0x80]);
+ sub(A2, -16);
+ movdqa(xmm2, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm2, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pshufd(xmm6, xmm2, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ dec(I);
+ jg(l34c8, T_NEAR);
+ align(4);
+
+L(l3558);
+ test(M, 0x8);
+ jle(l35b0, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(xmm1, qword[A2-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l35b0);
+ test(M, 0x4);
+ jle(l35f4, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(xmm1, dword[A2-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l35f4);
+ test(M, 0x2);
+ jle(l3638, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l3638);
+ test(M, 0x1);
+ jle(l366c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ align(4);
+
+L(l366c);
+ mov(A1, qword[ARG_BIAS]);
+ movq(qword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l34a8, T_NEAR);
+ align(4);
+
+L(l368a);
+ cmp(N, 0x1);
+ jl(l37d8, T_NEAR);
+ align(4);
+
+L(l3694);
+ mov(A1, A);
+ add(A, LDA);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l36ec, T_NEAR);
+ align(4);
+
+L(l36a8);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(I);
+ jg(l36a8, T_NEAR);
+ align(4);
+
+L(l36ec);
+ test(M, 0x8);
+ jle(l3728, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l3728);
+ test(M, 0x4);
+ jle(l3760, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l3760);
+ test(M, 0x2);
+ jle(l3794, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ mov(word[B-0x80], ax);
+ sub(A1, -2);
+ sub(B, -2);
+ align(4);
+
+L(l3794);
+ test(M, 0x1);
+ jle(l37b8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l37b8);
+ mov(A1, qword[ARG_BIAS]);
+ movd(dword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x4);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l3694, T_NEAR);
+ align(4);
+
+L(l37d8);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+#undef ARG_BIAS
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp
new file mode 100644
index 0000000000..c7f1393c9d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bn_kern.cpp
@@ -0,0 +1,821 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_sum_bn_kern::jit_avx512_core_u8_copy_sum_bn_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#define ARG_BIAS 24+stacksize+rsp
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+#define ARG_BIAS 72+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l20;
+Xbyak::Label l22c;
+Xbyak::Label l340;
+Xbyak::Label l3f8;
+Xbyak::Label l48;
+Xbyak::Label l498;
+Xbyak::Label l51c;
+Xbyak::Label l540;
+Xbyak::Label l54c;
+Xbyak::Label l56c;
+Xbyak::Label l664;
+Xbyak::Label l6f8;
+Xbyak::Label l75c;
+Xbyak::Label l7b4;
+Xbyak::Label l7fc;
+Xbyak::Label l81c;
+Xbyak::Label l828;
+Xbyak::Label l848;
+Xbyak::Label l8d8;
+Xbyak::Label l930;
+Xbyak::Label l974;
+Xbyak::Label l9b8;
+Xbyak::Label l9ec;
+Xbyak::Label la0a;
+Xbyak::Label la14;
+Xbyak::Label la28;
+Xbyak::Label la6c;
+Xbyak::Label laa8;
+Xbyak::Label lae0;
+Xbyak::Label lb14;
+Xbyak::Label lb38;
+Xbyak::Label lb58;
+
+ preamble();
+ auto stacksize = get_size_of_abi_save_regs();
+#ifdef _WIN32
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(N, qword[N]);
+ mov(M, qword[M]);
+ mov(LDA, qword[LDA]);
+ sub(A, -128);
+ sub(B, -128);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ cmp(N, 0x8);
+ jl(l540, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*4]);
+ lea(I, ptr[A1+LDA*8]);
+ mov(A, I);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l22c, T_NEAR);
+ align(4);
+
+L(l48);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ movdqu(xmm2, xword[A1+LDA*2-0x80]);
+ movdqu(xmm3, xword[A1+LDA3*1-0x80]);
+ sub(A1, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x60], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x40], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x20], xmm3);
+ movdqu(xmm0, xword[A2-0x80]);
+ movdqu(xmm1, xword[A2+LDA*1-0x80]);
+ movdqu(xmm2, xword[A2+LDA*2-0x80]);
+ movdqu(xmm3, xword[A2+LDA3*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x50], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x30], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x10], xmm3);
+ sub(B, -128);
+ dec(I);
+ jg(l48, T_NEAR);
+ align(4);
+
+L(l22c);
+ test(M, 0x8);
+ jle(l340, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ movq(xmm2, qword[A1+LDA*2-0x80]);
+ movq(xmm3, qword[A1+LDA3*1-0x80]);
+ sub(A1, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x60], xmm1);
+ movq(xmm0, qword[A2-0x80]);
+ movq(xmm1, qword[A2+LDA*1-0x80]);
+ movq(xmm2, qword[A2+LDA*2-0x80]);
+ movq(xmm3, qword[A2+LDA3*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ align(4);
+
+L(l340);
+ test(M, 0x4);
+ jle(l3f8, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ movd(xmm2, dword[A1+LDA*2-0x80]);
+ movd(xmm3, dword[A1+LDA3*1-0x80]);
+ sub(A1, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A2-0x80]);
+ movd(xmm1, dword[A2+LDA*1-0x80]);
+ movd(xmm2, dword[A2+LDA*2-0x80]);
+ movd(xmm3, dword[A2+LDA3*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ align(4);
+
+L(l3f8);
+ test(M, 0x2);
+ jle(l498, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A1+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A1+LDA3*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x3);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x4);
+ mov(ax, word[A2+LDA*1-0x80]);
+ pinsrw(xmm0, eax, 0x5);
+ mov(ax, word[A2+LDA*2-0x80]);
+ pinsrw(xmm0, eax, 0x6);
+ mov(ax, word[A2+LDA3*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l498);
+ test(M, 0x1);
+ jle(l51c, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A2+LDA*2-0x80]);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A2+LDA3*1-0x80]);
+ pinsrb(xmm0, eax, 0x7);
+ pmovsxbd(xmm5, xmm0);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm8, xmm5);
+ paddd(xmm9, xmm6);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l51c);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ add(qword[ARG_BIAS], 0x20);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l540);
+ cmp(N, 0x4);
+ jl(l81c, T_NEAR);
+ align(4);
+
+L(l54c);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*2]);
+ lea(I, ptr[A1+LDA*4]);
+ mov(A, I);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l664, T_NEAR);
+ align(4);
+
+L(l56c);
+ movdqu(xmm0, xword[A1-0x80]);
+ movdqu(xmm1, xword[A1+LDA*1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm2, xword[A2-0x80]);
+ movdqu(xmm3, xword[A2+LDA*1-0x80]);
+ sub(A2, -16);
+ movdqa(xmm4, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm4, xmm1);
+ movdqa(xmm5, xmm2);
+ punpckldq(xmm2, xmm3);
+ punpckhdq(xmm5, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ movdqa(xmm3, xmm4);
+ punpcklqdq(xmm4, xmm5);
+ punpckhqdq(xmm3, xmm5);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ pmovsxbw(xmm5, xmm4);
+ movhlps(xmm6, xmm4);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x60], xmm4);
+ pmovsxbw(xmm5, xmm3);
+ movhlps(xmm6, xmm3);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x50], xmm3);
+ sub(B, -64);
+ dec(I);
+ jg(l56c, T_NEAR);
+ align(4);
+
+L(l664);
+ test(M, 0x8);
+ jle(l6f8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ movq(xmm1, qword[A1+LDA*1-0x80]);
+ sub(A1, -8);
+ movq(xmm2, qword[A2-0x80]);
+ movq(xmm3, qword[A2+LDA*1-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklqdq(xmm0, xmm2);
+ punpckhqdq(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l6f8);
+ test(M, 0x4);
+ jle(l75c, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ movd(xmm1, dword[A1+LDA*1-0x80]);
+ sub(A1, -4);
+ movd(xmm2, dword[A2-0x80]);
+ movd(xmm3, dword[A2+LDA*1-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ punpckldq(xmm2, xmm3);
+ punpcklqdq(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l75c);
+ test(M, 0x2);
+ jle(l7b4, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1+LDA*1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x1);
+ mov(ax, word[A2-0x80]);
+ pinsrw(xmm0, eax, 0x2);
+ mov(ax, word[A2+LDA*1-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x3);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l7b4);
+ test(M, 0x1);
+ jle(l7fc, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A2+LDA*1-0x80]);
+ pinsrb(xmm0, eax, 0x3);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l7fc);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x10);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l54c, T_NEAR);
+ align(4);
+
+L(l81c);
+ cmp(N, 0x2);
+ jl(la0a, T_NEAR);
+ align(4);
+
+L(l828);
+ mov(A1, A);
+ lea(A2, ptr[A1+LDA*1]);
+ lea(I, ptr[A1+LDA*2]);
+ mov(A, I);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(l8d8, T_NEAR);
+ align(4);
+
+L(l848);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ movdqu(xmm1, xword[A2-0x80]);
+ sub(A2, -16);
+ movdqa(xmm2, xmm0);
+ punpckldq(xmm0, xmm1);
+ punpckhdq(xmm2, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ pshufd(xmm6, xmm2, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm2);
+ sub(B, -32);
+ dec(I);
+ jg(l848, T_NEAR);
+ align(4);
+
+L(l8d8);
+ test(M, 0x8);
+ jle(l930, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ movq(xmm1, qword[A2-0x80]);
+ sub(A2, -8);
+ punpckldq(xmm0, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l930);
+ test(M, 0x4);
+ jle(l974, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ movd(xmm1, dword[A2-0x80]);
+ sub(A2, -4);
+ punpckldq(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l974);
+ test(M, 0x2);
+ jle(l9b8, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ sub(A1, -2);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A2-0x80]);
+ sub(A2, -2);
+ pinsrw(xmm0, eax, 0x1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l9b8);
+ test(M, 0x1);
+ jle(l9ec, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A2-0x80]);
+ pinsrb(xmm0, eax, 0x1);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ align(4);
+
+L(l9ec);
+ mov(A1, qword[ARG_BIAS]);
+ movq(qword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l828, T_NEAR);
+ align(4);
+
+L(la0a);
+ cmp(N, 0x1);
+ jl(lb58, T_NEAR);
+ align(4);
+
+L(la14);
+ mov(A1, A);
+ add(A, LDA);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x4);
+ jle(la6c, T_NEAR);
+ align(4);
+
+L(la28);
+ movdqu(xmm0, xword[A1-0x80]);
+ sub(A1, -16);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(I);
+ jg(la28, T_NEAR);
+ align(4);
+
+L(la6c);
+ test(M, 0x8);
+ jle(laa8, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ sub(A1, -8);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(laa8);
+ test(M, 0x4);
+ jle(lae0, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ sub(A1, -4);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(lae0);
+ test(M, 0x2);
+ jle(lb14, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ mov(word[B-0x80], ax);
+ sub(A1, -2);
+ sub(B, -2);
+ align(4);
+
+L(lb14);
+ test(M, 0x1);
+ jle(lb38, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrb(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(lb38);
+ mov(A1, qword[ARG_BIAS]);
+ movd(dword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x4);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(la14, T_NEAR);
+ align(4);
+
+L(lb58);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+#undef ARG_BIAS
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp
new file mode 100644
index 0000000000..afe4f1713e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/jit_avx512_core_u8_copy_sum_bt_kern.cpp
@@ -0,0 +1,647 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_generator.hpp"
+#include "common.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+jit_avx512_core_u8_copy_sum_bt_kern::jit_avx512_core_u8_copy_sum_bt_kern(): jit_generator(nullptr, GEMM_CODE_SIZE)
+{
+
+#ifndef _WIN32
+#define M rdi
+#define N rsi
+#define A rdx
+#define LDA rcx
+#define ALPHA r8
+#define B r9
+
+#define I rax
+#define A1 r10
+#define A2 r8
+#define LDA3 r11
+
+#define ARG_BIAS 24+stacksize+rsp
+
+#else
+
+#define M rcx
+#define N rdx
+#define A r8
+#define LDA r9
+#define ALPHA rax
+#define B rdi
+
+#define I rax
+#define A1 rsi
+#define A2 r10
+#define LDA3 r11
+
+#define ARG_ALPHA 40+stacksize+rsp
+#define ARG_B 48+stacksize+rsp
+#define ARG_BIAS 72+stacksize+rsp
+
+#endif
+
+inLocalLabel();
+{
+
+Xbyak::Label l15c;
+Xbyak::Label l1f4;
+Xbyak::Label l20;
+Xbyak::Label l248;
+Xbyak::Label l280;
+Xbyak::Label l2a4;
+Xbyak::Label l2b0;
+Xbyak::Label l2c8;
+Xbyak::Label l384;
+Xbyak::Label l3e8;
+Xbyak::Label l40;
+Xbyak::Label l424;
+Xbyak::Label l448;
+Xbyak::Label l468;
+Xbyak::Label l474;
+Xbyak::Label l48c;
+Xbyak::Label l550;
+Xbyak::Label l5bc;
+Xbyak::Label l600;
+Xbyak::Label l628;
+Xbyak::Label l646;
+Xbyak::Label l650;
+Xbyak::Label l668;
+Xbyak::Label l700;
+Xbyak::Label l760;
+Xbyak::Label l7a4;
+Xbyak::Label l7c8;
+Xbyak::Label l7e8;
+
+ preamble();
+ auto stacksize = get_size_of_abi_save_regs();
+#ifdef _WIN32
+ mov(ALPHA, ptr[ARG_ALPHA]);
+ mov(B, ptr[ARG_B]);
+#endif
+
+ mov(M, qword[M]);
+ mov(N, qword[N]);
+ mov(LDA, qword[LDA]);
+ lea(LDA3, ptr[LDA+LDA*2]);
+ sub(A, -128);
+ sub(B, -128);
+ cmp(N, 0x8);
+ jl(l2a4, T_NEAR);
+ align(4);
+
+L(l20);
+ mov(A1, A);
+ add(A, 0x8);
+ pxor(xmm8, xmm8);
+ pxor(xmm9, xmm9);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l15c, T_NEAR);
+ align(4);
+
+L(l40);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x60], xmm0);
+ movdqu(xword[B-0x50], xmm1);
+ sub(B, -64);
+ dec(I);
+ jg(l40, T_NEAR);
+ align(4);
+
+L(l15c);
+ test(M, 0x4);
+ jle(l1f4, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm2, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm3, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ movdqa(xmm1, xmm0);
+ punpcklwd(xmm0, xmm2);
+ punpckhwd(xmm1, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ pmovsxbw(xmm5, xmm1);
+ movhlps(xmm6, xmm1);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm9, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movdqu(xword[B-0x70], xmm1);
+ sub(B, -32);
+ align(4);
+
+L(l1f4);
+ test(M, 0x2);
+ jle(l248, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ movq(xmm1, qword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm8, xmm5);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm6, xmm6);
+ pmovsxwd(xmm6, xmm6);
+ paddd(xmm9, xmm6);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l248);
+ test(M, 0x1);
+ jle(l280, T_NEAR);
+ movq(xmm0, qword[A1-0x80]);
+ add(A1, LDA);
+ pmovsxbd(xmm5, xmm0);
+ pshufd(xmm6, xmm0, 0x55);
+ pmovsxbd(xmm6, xmm6);
+ paddd(xmm8, xmm5);
+ paddd(xmm9, xmm6);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l280);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm8);
+ movdqu(xword[A1+0x10], xmm9);
+ add(qword[ARG_BIAS], 0x20);
+ sub(N, 0x8);
+ cmp(N, 0x8);
+ jge(l20, T_NEAR);
+ align(4);
+
+L(l2a4);
+ cmp(N, 0x4);
+ jl(l468, T_NEAR);
+ align(4);
+
+L(l2b0);
+ mov(A1, A);
+ add(A, 0x4);
+ pxor(xmm7, xmm7);
+ mov(I, M);
+ sar(I, 0x3);
+ jle(l384, T_NEAR);
+ align(4);
+
+L(l2c8);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x70], xmm0);
+ sub(B, -32);
+ dec(I);
+ jg(l2c8, T_NEAR);
+ align(4);
+
+L(l384);
+ test(M, 0x4);
+ jle(l3e8, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm2, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm3, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ movhlps(xmm6, xmm0);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ align(4);
+
+L(l3e8);
+ test(M, 0x2);
+ jle(l424, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ add(A1, LDA);
+ movd(xmm1, dword[A1-0x80]);
+ add(A1, LDA);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l424);
+ test(M, 0x1);
+ jle(l448, T_NEAR);
+ movd(xmm0, dword[A1-0x80]);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l448);
+ mov(A1, qword[ARG_BIAS]);
+ movdqu(xword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x10);
+ sub(N, 0x4);
+ cmp(N, 0x4);
+ jge(l2b0, T_NEAR);
+ align(4);
+
+L(l468);
+ cmp(N, 0x2);
+ jl(l646, T_NEAR);
+ align(4);
+
+L(l474);
+ mov(A1, A);
+ add(A, 0x2);
+ pxor(xmm7, xmm7);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l550, T_NEAR);
+ align(4);
+
+L(l48c);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm4, eax, 0x0);
+ punpcklbw(xmm1, xmm2);
+ punpcklbw(xmm3, xmm4);
+ punpcklwd(xmm1, xmm3);
+ punpcklqdq(xmm0, xmm1);
+ pshufd(xmm6, xmm0, 0xd8);
+ pmovsxbw(xmm5, xmm6);
+ movhlps(xmm6, xmm6);
+ pmovsxbw(xmm6, xmm6);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movdqu(xword[B-0x80], xmm0);
+ sub(B, -16);
+ dec(LDA3);
+ jg(l48c, T_NEAR);
+ align(4);
+
+L(l550);
+ test(M, 0x4);
+ jle(l5bc, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm2, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm3, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ punpcklbw(xmm2, xmm3);
+ punpcklwd(xmm0, xmm2);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ align(4);
+
+L(l5bc);
+ test(M, 0x2);
+ jle(l600, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm0, eax, 0x0);
+ mov(ax, word[A1-0x80]);
+ add(A1, LDA);
+ pinsrw(xmm1, eax, 0x0);
+ punpcklbw(xmm0, xmm1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l600);
+ test(M, 0x1);
+ jle(l628, T_NEAR);
+ mov(ax, word[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(word[B-0x80], ax);
+ sub(B, -2);
+ align(4);
+
+L(l628);
+ mov(A1, qword[ARG_BIAS]);
+ movq(qword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x8);
+ sub(N, 0x2);
+ cmp(N, 0x2);
+ jge(l474, T_NEAR);
+ align(4);
+
+L(l646);
+ cmp(N, 0x1);
+ jl(l7e8, T_NEAR);
+ align(4);
+
+L(l650);
+ mov(A1, A);
+ add(A, 0x1);
+ pxor(xmm7, xmm7);
+ mov(LDA3, M);
+ sar(LDA3, 0x3);
+ jle(l700, T_NEAR);
+ align(4);
+
+L(l668);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x4);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x5);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x6);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x7);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm6);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movq(qword[B-0x80], xmm0);
+ sub(B, -8);
+ dec(LDA3);
+ jg(l668, T_NEAR);
+ align(4);
+
+L(l700);
+ test(M, 0x4);
+ jle(l760, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x2);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x3);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ movd(dword[B-0x80], xmm0);
+ sub(B, -4);
+ align(4);
+
+L(l760);
+ test(M, 0x2);
+ jle(l7a4, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x0);
+ mov(byte[B-0x80], al);
+ mov(al, byte[A1-0x80]);
+ add(A1, LDA);
+ pinsrb(xmm0, eax, 0x1);
+ pmovsxbw(xmm5, xmm0);
+ phaddw(xmm5, xmm5);
+ pmovsxwd(xmm5, xmm5);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x7f], al);
+ sub(B, -2);
+ align(4);
+
+L(l7a4);
+ test(M, 0x1);
+ jle(l7c8, T_NEAR);
+ mov(al, byte[A1-0x80]);
+ pinsrw(xmm0, eax, 0x0);
+ pmovsxbd(xmm5, xmm0);
+ paddd(xmm7, xmm5);
+ mov(byte[B-0x80], al);
+ sub(B, -1);
+ align(4);
+
+L(l7c8);
+ mov(A1, qword[ARG_BIAS]);
+ movd(dword[A1], xmm7);
+ add(qword[ARG_BIAS], 0x4);
+ sub(N, 0x1);
+ cmp(N, 0x1);
+ jge(l650, T_NEAR);
+ align(4);
+
+L(l7e8);
+
+ postamble();
+}
+outLocalLabel();
+
+#undef M
+#undef N
+#undef A
+#undef LDA
+#undef ALPHA
+#undef B
+#undef I
+#undef A1
+#undef A2
+#undef LDA3
+#ifdef _WIN32
+#undef ARG_ALPHA
+#undef ARG_B
+#endif
+#undef ARG_BIAS
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp
new file mode 100644
index 0000000000..4fc11afcbc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.cpp
@@ -0,0 +1,116 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <cstdint>
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "../f32/ref_gemm_f32.hpp"
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <typename b_dt>
+mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co) {
+
+ if (*M == 0 || *N == 0 || *K == 0)
+ return mkldnn_success;
+
+ bool OCisR = (*offsetc == 'R' || *offsetc == 'r');
+ bool OCisC = (*offsetc == 'C' || *offsetc == 'c');
+ bool AisN = (*transa == 'N' || *transa == 'n');
+ bool BisN = (*transb == 'N' || *transb == 'n');
+
+ int m = *M, n = *N, k = *K, lda = *LDA, ldb = *LDB, ldc = *LDC;
+ size_t sizeA = AisN ? lda * k : lda * m;
+ size_t sizeB = BisN ? ldb * n : ldb * k;
+ size_t sizeC = ldc * n;
+
+ double *dA = (double *)malloc(sizeA * sizeof(double), PAGE_4K);
+ double *dB = (double *)malloc(sizeB * sizeof(double), PAGE_4K);
+ double *dC = (double *)malloc(sizeC * sizeof(double), PAGE_4K);
+
+ if (utils::any_null(dA, dB, dC)) {
+ free(dA);
+ free(dB);
+ free(dC);
+ return mkldnn_out_of_memory;
+ }
+
+ auto da_setter = [=] (int i, int j, double v) { dA[j * lda + i] = v; };
+ auto db_setter = [=] (int i, int j, double v) { dB[j * ldb + i] = v; };
+
+ auto ia_accessor = [=] (int i, int j) { return A[j * lda + i]; };
+ auto ib_accessor = [=] (int i, int j) { return B[j * ldb + i]; };
+
+ const int a_rows = AisN ? m : k;
+ const int a_cols = AisN ? k : m;
+ mkldnn::impl::parallel_nd(a_cols, a_rows, [&](int j, int i) {
+ da_setter(i, j,
+ static_cast<double>(ia_accessor(i, j)) + static_cast<double>(ao[0]));
+ });
+
+ const int b_rows = BisN ? k : n;
+ const int b_cols = BisN ? n : k;
+ mkldnn::impl::parallel_nd(b_cols, b_rows, [&](int j, int i) {
+ db_setter(i, j,
+ static_cast<double>(ib_accessor(i, j)) + static_cast<double>(bo[0]));
+ });
+ double one = 1.0, zero = 0.0;
+ ref_gemm<double>(transa, transb, M, N, K, &one, dA, LDA, dB, LDB, &zero,
+ dC, LDC, nullptr);
+
+ auto i2d = [=] (int32_t v) { return static_cast<double>(v); };
+ auto f2d = [=] (float v) { return static_cast<double>(v); };
+
+ mkldnn::impl::parallel_nd(n, m, [&] (int j, int i) {
+ double coffset = OCisR ? i2d(co[j]) : OCisC ? i2d(co[i]) : i2d(co[0]);
+ double val = ((*beta == 0.0f) ? 0.0 : f2d(*beta) * i2d(C[i + j * ldc]))
+ + f2d(*alpha) * dC[i + j * ldc] + coffset;
+ C[i + j * ldc] = math::out_round<int32_t>(math::saturate<int32_t>(val));
+ });
+
+ free(dA);
+ free(dB);
+ free(dC);
+ return mkldnn_success;
+}
+
+template mkldnn_status_t ref_gemm_s8x8s32<uint8_t>(
+ const char *transa, const char *transb, const char *offsetc,
+ const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const uint8_t *B, const int *LDB, const int8_t *bo,
+ const float *beta, int32_t *C, const int *LDC, const int32_t *co);
+
+template mkldnn_status_t ref_gemm_s8x8s32<int8_t>(
+ const char *transa, const char *transb, const char *offsetc,
+ const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const int8_t *B, const int *LDB, const int8_t *bo,
+ const float *beta, int32_t *C, const int *LDC, const int32_t *co);
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp
new file mode 100644
index 0000000000..6c0370ae99
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/ref_gemm_s8x8s32.hpp
@@ -0,0 +1,38 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef REF_GEMM_S8X8S32_HPP
+#define REF_GEMM_S8X8S32_HPP
+
+#include <stdint.h>
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <typename b_dt>
+mkldnn_status_t ref_gemm_s8x8s32(const char *transa, const char *transb,
+ const char *offsetc, const int *M, const int *N, const int *K,
+ const float *alpha, const int8_t *A, const int *LDA, const int8_t *ao,
+ const b_dt *B, const int *LDB, const int8_t *bo, const float *beta,
+ int32_t *C, const int *LDC, const int32_t *co);
+
+}
+}
+}
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp
new file mode 100644
index 0000000000..de1035f3b2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.cpp
@@ -0,0 +1,180 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "common.hpp"
+#include "nstl.hpp"
+#include "math_utils.hpp"
+
+#include "../gemm.hpp"
+#include "jit_avx512_core_gemm_s8u8s32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+void compensation_init(const char *offsetC, int32_t *compensation, int len,
+ const int32_t *oc) {
+ bool OCisC = (*offsetC == 'C' || *offsetC == 'c');
+ bool OCisF = (*offsetC == 'F' || *offsetC == 'f');
+
+ if (OCisF && (*oc) != 0) {
+ for (int i = 0; i < len; i++)
+ compensation[i] = *oc;
+ } else if (OCisC) {
+ for (int i = 0; i < len; i++)
+ compensation[i] = oc[i];
+ } else {
+ parallel_nd(len, [=](int i) { compensation[i] = 0; });
+ }
+}
+
+void compensation_compute(bool transa, int m, int k, float alpha,
+ const int8_t *a, int lda, int32_t *compensation) {
+ if (!transa) {
+ const int L2_cache_size = get_cache_size(2, true);
+ const int blocking_factor = nstl::min(k, L2_cache_size / lda + 1);
+ const int npanels = k / blocking_factor;
+ const bool has_tile = k % blocking_factor > 0;
+
+ parallel_nd(npanels, m, [&](int j, int i) {
+ int32_t val = 0;
+ for (int jb = 0; jb < blocking_factor; jb++) {
+ val += a[(i + (ptrdiff_t)j * blocking_factor * lda)
+ + (ptrdiff_t)jb * lda];
+ }
+ if (alpha != 1.0f) {
+ val = math::out_round<int32_t>(math::saturate<int32_t>(
+ (double)val * alpha * -128.0));
+ } else {
+ val *= -128;
+ }
+ fetch_and_add(&compensation[i], val);
+ });
+
+ if (has_tile) {
+ parallel_nd(m, [=](int i) {
+ int32_t val = 0;
+ for (int j = npanels * blocking_factor; j < k; j++) {
+ val += a[i + (ptrdiff_t)j * lda];
+ }
+ if (alpha != 1.0f) {
+ val = math::out_round<int32_t>(math::saturate<int32_t>(
+ (double)val * alpha * -128.0));
+ } else {
+ val *= -128;
+ }
+ fetch_and_add(&compensation[i], val);
+ });
+ }
+ } else {
+ parallel_nd(m, [=](int i) {
+ int32_t val = 0;
+ for (int j = 0; j < k; j++) {
+ val += a[j + (ptrdiff_t)i * lda];
+ }
+ if (alpha != 1.0f) {
+ val = math::out_round<int32_t>(math::saturate<int32_t>(
+ (double)val * alpha * -128.0));
+ } else {
+ val *= -128;
+ }
+ compensation[i] += val;
+ });
+ }
+}
+
+void copy_and_shift_b(bool transb, int k, int n, uint8_t *b_u8, int ldb_u8,
+ const int8_t *b_s8, int ldb_s8) {
+ const int b_cols = transb ? k : n;
+
+ parallel_nd(b_cols, [=](int j) {
+ const int b_rows = transb ? n : k;
+
+ uint8_t *pb_u8 = b_u8 + j * ldb_u8;
+ const int8_t *pb_s8 = b_s8 + j * ldb_s8;
+
+ for (int i = 0; i < b_rows; i++) {
+ (*pb_u8) = (*pb_s8) + 128;
+ pb_u8++;
+ pb_s8++;
+ }
+ });
+}
+
+/**
+ * gemm_s8s8s32 operation is defined as follows:
+ * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset + compensation
+ *
+ * where
+ *  - compensation is a vector of length m that holds the computed
+ *   compensation term; it may also include C_offset where applicable. The
+ *   compensation is applied inside gemm_s8u8s32 as a C_offset
+ *  - B_shift is a k-by-n matrix whose every element is equal to 128
+ *
+ * What the compensation is:
+ * In order to prepare the matrix B for the gemm_s8u8s32 call, B_shift is
+ * applied:
+ * C = alpha * op(A) * (op(B) + B_shift) + beta * C + C_offset =
+ *  alpha * op(A) * op(B) + alpha * op(A) * B_shift + beta * C + C_offset
+ * compensation = -alpha * op(A) * B_shift
+ * Since every element of B_shift is equal to 128:
+ *  - if op(A) = A: compensation contains the sum of the elements in each row,
+ *   scaled by -128 * alpha
+ *  - if op(A) = A**T: compensation contains the sum of the elements in each
+ *   column, scaled by -128 * alpha
+ *
+ * The rest of the parameters are described in mkldnn.h
+ */
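+
+/* A minimal illustrative sketch (not part of the original sources): a scalar
+ * reference for the compensation term described above, assuming op(A) = A
+ * with A stored column-major (m rows, k columns, leading dimension lda),
+ * matching compensation_compute() above. The helper name compensation_ref is
+ * hypothetical. Each entry is the row sum of A scaled by -128 * alpha, i.e.
+ * -alpha * A * B_shift collapsed to a single column, since every element of
+ * B_shift equals 128. */
+static inline void compensation_ref(int m, int k, float alpha,
+        const int8_t *a, int lda, int32_t *compensation) {
+    for (int i = 0; i < m; i++) {
+        int32_t row_sum = 0;
+        for (int j = 0; j < k; j++)
+            row_sum += a[i + (ptrdiff_t)j * lda];
+        compensation[i] = math::out_round<int32_t>(math::saturate<int32_t>(
+                (double)row_sum * alpha * -128.0));
+    }
+}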
+mkldnn_status_t simple_gemm_s8s8s32(
+ const char *transA, const char *transB, const char *offsetC,
+ const int *m, const int *n, const int *k,
+ const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
+ const int8_t *b, const int *ldb, const int8_t *ob,
+ const float *beta, int32_t *c, const int *ldc, const int32_t *oc) {
+ if (*oa != 0 || *ob != 0) return mkldnn_unimplemented;
+
+ int M = *m, N = *n, K = *k;
+ bool transa = (*transA == 'T' || *transA == 't');
+ bool transb = (*transB == 'T' || *transB == 't');
+ int ld = transb ? N : K;
+
+ uint8_t *b_u8 = (uint8_t *)malloc(sizeof(uint8_t) * K * N, 64);
+ int32_t *compensation = (int32_t *)malloc(sizeof(int32_t) * M, 64);
+
+ if (utils::any_null(b_u8, compensation)) {
+ free(b_u8);
+ free(compensation);
+ return mkldnn_out_of_memory;
+ }
+
+ compensation_init(offsetC, compensation, M, oc);
+ compensation_compute(transa, M, K, *alpha, a, *lda, compensation);
+ copy_and_shift_b(transb, K, N, b_u8, ld, b, *ldb);
+
+ gemm_s8x8s32(transA, transB, "C", m, n, k, alpha, a, lda, oa, b_u8,
+ &ld, ob, beta, c, ldc, compensation);
+
+ if ((*offsetC == 'R' || *offsetC == 'r'))
+ parallel_nd(M, N,
+ [=](int i, int j) { c[i + (ptrdiff_t)j * *ldc] += oc[j]; });
+
+ free(b_u8);
+ free(compensation);
+
+ return mkldnn_success;
+}
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp
new file mode 100644
index 0000000000..03a3d2f7e0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm/s8x8s32/simple_gemm_s8s8s32.hpp
@@ -0,0 +1,37 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SIMPLE_GEMM_S8S8S32_HPP
+#define SIMPLE_GEMM_S8S8S32_HPP
+
+#include <stdint.h>
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+mkldnn_status_t simple_gemm_s8s8s32(
+ const char *transA, const char *transB, const char *offsetC,
+ const int *m, const int *n, const int *k,
+ const float *alpha, const int8_t *a, const int *lda, const int8_t *oa,
+ const int8_t *b, const int *ldb, const int8_t *ob,
+ const float *beta, int32_t *c, const int *ldc, const int32_t *oc);
+}
+}
+}
+
+#endif // SIMPLE_GEMM_S8S8S32_HPP
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp
new file mode 100644
index 0000000000..604a728b47
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.cpp
@@ -0,0 +1,307 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "gemm_convolution.hpp"
+#include "utils.hpp"
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+#include "ref_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+void gemm_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ auto col = scratchpad(ctx).get<data_t>(key_conv_gemm_col);
+
+ const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
+ const int M = jcp.os * jcp.od;
+ const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
+ const size_t dst_step = jcp.oc * M;
+ const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
+
+ assert(IMPLICATION(
+ jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
+ assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
+
+ const int K = jcp.ic * jcp.ks;
+ const int N = jcp.oc;
+
+ if (jcp.im2col_sz && jcp.id != 1)
+ parallel_nd(jcp.im2col_sz * jcp.nthr,
+ [&](ptrdiff_t i) { col[i] = (data_t)0; });
+
+ const int nb_oh = div_up(jcp.oh, jcp.oh_block);
+ const int nb_ow = div_up(jcp.ow, jcp.ow_block);
+ const size_t work_amount = jcp.ngroups * jcp.mb * jcp.od * nb_oh * nb_ow;
+ parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+ data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
+
+ int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 };
+ size_t start = 0, end = 0;
+
+ balance211(work_amount, nthr, ithr, start, end);
+ nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb,
+ nb_oh, owb, nb_ow);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ int oh = ohb * jcp.oh_block;
+ int ow = owb * jcp.ow_block;
+ const data_t *_src = src + (n * jcp.ngroups + g) * src_step;
+ const data_t *_weights = weights + g * weights_g_size;
+ data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step;
+ const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
+ const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
+ if (jcp.im2col_sz) {
+ if (jcp.id == 1)
+ jit_gemm_convolution_utils::im2col(
+ jcp, _src, _col, oh, h_step, ow, w_step);
+ else
+ jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od);
+ }
+
+ const data_t one = 1.0;
+
+ const int m = h_step * w_step;
+ const int LDA = jcp.im2col_sz ? m : M;
+ data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow;
+
+ extended_sgemm("N", "N", &m, &N, &K, &one,
+ jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K,
+ &this->beta_, _dst, &M);
+
+ data_t *d = _dst;
+ if (eltwise_) {
+ // fast branch for ReLU case
+ if (eltwise_->alg_ == alg_kind::eltwise_relu) {
+ parallel_nd(jcp.oc, [&](const int oc) {
+ data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0;
+ data_t *d_ = d + oc * M;
+ PRAGMA_OMP_SIMD()
+ for (int oS = 0; oS < m; ++oS) {
+ d_[oS] += b;
+ if (d_[oS] < 0) d_[oS] *= eltwise_->alpha_;
+ }
+ });
+ } else {
+ parallel_nd(jcp.oc, [&](const int oc) {
+ data_t b = jcp.with_bias ? bias[g * jcp.oc + oc] : 0;
+ data_t *d_ = d + oc * M;
+ PRAGMA_OMP_SIMD()
+ for (int oS = 0; oS < m; ++oS) {
+ d_[oS] += b;
+ d_[oS] = eltwise_->compute_scalar(d_[oS]);
+ }
+ });
+ }
+ } else if (jcp.with_bias) {
+ parallel_nd(jcp.oc, [&](const int oc) {
+ data_t b = bias[g * jcp.oc + oc];
+ data_t *d_ = d + oc * M;
+ PRAGMA_OMP_SIMD()
+ for (int oS = 0; oS < m; ++oS) {
+ d_[oS] += b;
+ }
+ });
+ }
+ nd_iterator_step(g, jcp.ngroups, n, jcp.mb, od, jcp.od, ohb, nb_oh,
+ owb, nb_ow);
+ }
+ });
+}
+
+void gemm_convolution_bwd_data_t::execute_backward_data(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ auto col = scratchpad(ctx).get<data_t>(key_conv_gemm_col);
+
+ const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
+ const int M = jcp.os * jcp.od;
+ const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
+ const size_t dst_step = jcp.oc * M;
+ const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
+
+ const int m = jcp.os;
+ const int K = jcp.oc;
+ const int N = jcp.ic * jcp.ks;
+ const int LDC = jcp.im2col_sz ? m : M;
+
+ const size_t work_amount = (size_t)jcp.ngroups * jcp.mb;
+
+ if (jcp.id > 1) {
+ const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step);
+ parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; });
+ }
+
+ parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+ data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
+
+ int g{0}, n{0};
+ size_t start = 0, end = 0;
+ balance211(work_amount, nthr, ithr, start, end);
+ nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+
+ data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step;
+ const data_t *_weights = weights + g * weights_g_size;
+ for (int od = 0; od < jcp.od; ++od) {
+ const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g)
+ *dst_step + od * m;
+
+ const data_t zero = 0.0, one = 1.0;
+ extended_sgemm("N", "T", &m, &N, &K, &one, _diff_dst, &M,
+ _weights, &N, &zero,
+ jcp.im2col_sz ? _col:_diff_src + od * m, &LDC);
+
+ if (jcp.im2col_sz) {
+ if (jcp.id == 1)
+ jit_gemm_convolution_utils::col2im(jcp, _col,
+ _diff_src);
+ else
+ jit_gemm_convolution_utils::col2im_3d(jcp, _col,
+ _diff_src, od);
+ }
+ }
+ nd_iterator_step(g, jcp.ngroups, n, jcp.mb);
+ }
+ });
+}
+
+void gemm_convolution_bwd_weights_t::execute_backward_weights(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ auto col = scratchpad(ctx).get<data_t>(key_conv_gemm_col);
+ auto wei_reduction = scratchpad(ctx).get<data_t>(key_conv_wei_reduction);
+
+ const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
+ const int K = jcp.os * jcp.od;
+ const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
+ const size_t dst_step = jcp.oc * K;
+ const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
+
+ const int k = jcp.os;
+ const int N = jcp.oc;
+ const int M = jcp.ic * jcp.ks;
+ const int LDA = jcp.im2col_sz ? k : K;
+
+ parallel_nd(jcp.im2col_sz * jcp.nthr,
+ [&](ptrdiff_t i) { col[i] = (data_t)0; });
+
+ parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+ int ithr_g, nthr_g, ithr_mb, nthr_mb;
+ size_t g_start{0}, g_end{0}, mb_start{0}, mb_end{0};
+
+ const int mb_for_balance = jcp.need_wei_reduction ? jcp.mb : 1;
+ jit_gemm_convolution_utils::bwd_weights_balance(ithr, nthr, jcp.ngroups,
+ mb_for_balance, ithr_g, nthr_g, ithr_mb, nthr_mb);
+
+ assert(IMPLICATION(!jcp.need_wei_reduction, nthr_mb == 1));
+ const int need_reduction = nthr_mb != 1;
+
+ if (ithr_g != -1 && ithr_mb != -1) {
+ balance211((size_t)jcp.ngroups, nthr_g, ithr_g, g_start, g_end);
+ balance211((size_t)jcp.mb, nthr_mb, ithr_mb, mb_start, mb_end);
+
+ assert(IMPLICATION((g_end - g_start) > 1, need_reduction == 0));
+
+ data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
+ data_t *weights_reduce_base = wei_reduction
+ + ithr_g * nthr_mb * weights_g_size;
+ data_t *weights_reduce = weights_reduce_base
+ + ithr_mb * weights_g_size;
+
+ for (size_t g = g_start; g < g_end; ++g) {
+ data_t *_diff_weights = need_reduction
+ ? weights_reduce : (diff_weights + g * weights_g_size);
+ for (size_t mb = mb_start; mb < mb_end; ++mb) {
+ const data_t *_src = src + (mb*jcp.ngroups+g)*src_step;
+ for (int od = 0; od < jcp.od; ++od) {
+ const data_t *_diff_dst = diff_dst
+ + (mb*jcp.ngroups+g)*dst_step + od * k;
+
+ if (jcp.im2col_sz) {
+ if (jcp.id == 1)
+ jit_gemm_convolution_utils::im2col(
+ jcp, _src, _col, 0, jcp.oh, 0, jcp.ow);
+ else
+ jit_gemm_convolution_utils::im2col_3d(jcp, _src,
+ _col, od);
+ }
+
+ const data_t zero = 0.0, one = 1.0;
+ extended_sgemm(
+ "T", "N", &M, &N, &k, &one,
+ jcp.im2col_sz ? _col : _src + od * k,
+ &LDA, _diff_dst, &K,
+ mb == mb_start && od == 0 ? &zero : &one,
+ _diff_weights, &M);
+ }
+ }
+ }
+ if (need_reduction) {
+ mkldnn_thr_barrier();
+ data_t *weights_base = diff_weights + g_start * weights_g_size;
+ jit_gemm_convolution_utils::bwd_weights_reduction_par(
+ ithr_mb, nthr_mb, jcp, weights_reduce_base, weights_base);
+ }
+ } else
+ if (need_reduction) { mkldnn_thr_barrier(); }
+ });
+
+ if (jcp.with_bias) {
+ parallel_nd(jcp.ngroups, jcp.oc, [&](int g, int oc) {
+ data_t db = 0;
+ size_t offset_ = (size_t)g * dst_step + (size_t)oc * K;
+ for (int mb = 0; mb < jcp.mb; ++mb)
+ {
+ size_t offset = offset_ + (size_t)mb * jcp.ngroups * dst_step;
+ for (int od = 0; od < jcp.od; ++od)
+ for (int oh = 0; oh < jcp.oh; ++oh)
+ PRAGMA_OMP_SIMD(reduction(+:db))
+ for (int ow = 0; ow < jcp.ow; ++ow) {
+ db += diff_dst[offset];
+ offset++;
+ }
+ }
+ diff_bias[g*jcp.oc+oc] = db;
+ });
+ }
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp
new file mode 100644
index 0000000000..302e46369a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution.hpp
@@ -0,0 +1,250 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_GEMM_CONVOLUTION_HPP
+#define CPU_JIT_GEMM_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "gemm_convolution_utils.hpp"
+#include "gemm/gemm.hpp"
+#include "ref_eltwise.hpp"
+
+#include "cpu_convolution_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct gemm_convolution_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc, const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats_common(dat_tag(), wei_tag(), dat_tag())
+ && post_ops_ok()
+ && memory_desc_matches_tag(*src_md(), dat_tag())
+ && memory_desc_matches_tag(*dst_md(), dat_tag())
+ && memory_desc_matches_tag(*weights_md(), wei_tag());
+ if (!ok) return status::unimplemented;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+ *desc(), src_md(), weights_md(0), dst_md(),
+ mkldnn_get_max_threads());
+ }
+
+ jit_gemm_conv_conf_t jcp_;
+
+ protected:
+ format_tag_t dat_tag() const {
+ using namespace format_tag;
+ return utils::pick(ndims() - 3, ncw, nchw, ncdhw);
+ }
+
+ format_tag_t wei_tag() const {
+ using namespace format_tag;
+ return with_groups()
+ ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
+ : utils::pick(ndims() - 3, oiw, oihw, oidhw);
+ }
+
+ bool post_ops_ok() const {
+ auto const &po = attr()->post_ops_;
+ auto is_eltwise = [&](int idx)
+ { return po.entry_[idx].is_eltwise(); };
+ auto is_sum = [&](int idx) { return po.entry_[idx].is_sum(); };
+
+ switch (po.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
+ case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
+ default: return false;
+ }
+ return false;
+ }
+ };
+
+ gemm_convolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true)
+ , eltwise_(nullptr)
+ {
+ const auto &post_ops = pd()->attr()->post_ops_;
+ const data_t one = 1.0, zero = 0.0;
+ beta_ = post_ops.find(primitive_kind::sum) >= 0 ? one : zero;
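+        // A sum post-op makes the convolution GEMM accumulate into dst (beta = 1)
+        // instead of overwriting it (beta = 0).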
+
+ const int entry_idx = post_ops.find(primitive_kind::eltwise);
+ if (entry_idx != -1) eltwise_ = new ref_eltwise_scalar_fwd_t(
+ post_ops.entry_[entry_idx].eltwise);
+ }
+
+ ~gemm_convolution_fwd_t() { delete eltwise_; }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ data_t beta_;
+
+ ref_eltwise_scalar_fwd_t* eltwise_;
+};
+
+struct gemm_convolution_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_data_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc, const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::undef, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats_common(dat_tag(), wei_tag(), dat_tag())
+ && memory_desc_matches_tag(*diff_src_md(), dat_tag())
+ && memory_desc_matches_tag(*diff_dst_md(), dat_tag())
+ && memory_desc_matches_tag(*weights_md(), wei_tag());
+ if (!ok) return status::unimplemented;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+ *desc(), diff_src_md(), weights_md(0), diff_dst_md(),
+ mkldnn_get_max_threads());
+ }
+
+ jit_gemm_conv_conf_t jcp_;
+
+ protected:
+ format_tag_t dat_tag() const {
+ using namespace format_tag;
+ return utils::pick(ndims() - 3, ncw, nchw, ncdhw);
+ }
+
+ format_tag_t wei_tag() const {
+ using namespace format_tag;
+ return with_groups()
+ ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
+ : utils::pick(ndims() - 3, oiw, oihw, oidhw);
+ }
+ };
+
+ gemm_convolution_bwd_data_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true) {}
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+struct gemm_convolution_bwd_weights_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_weights_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_convolution_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_weights
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats_common(dat_tag(), wei_tag(), dat_tag())
+ && memory_desc_matches_tag(*src_md(), dat_tag())
+ && memory_desc_matches_tag(*diff_dst_md(), dat_tag())
+ && memory_desc_matches_tag(*diff_weights_md(), wei_tag());
+ if (!ok) return status::unimplemented;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+ *desc(), src_md(), diff_weights_md(0), diff_dst_md(),
+ mkldnn_get_max_threads());
+ }
+
+ jit_gemm_conv_conf_t jcp_;
+
+ protected:
+ format_tag_t dat_tag() const {
+ using namespace format_tag;
+ return utils::pick(ndims() - 3, ncw, nchw, ncdhw);
+ }
+
+ format_tag_t wei_tag() const {
+ using namespace format_tag;
+ return with_groups()
+ ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
+ : utils::pick(ndims() - 3, oiw, oihw, oidhw);
+ }
+ };
+
+ gemm_convolution_bwd_weights_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true) {}
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_weights(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_weights(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp
new file mode 100644
index 0000000000..f133b1e62b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.cpp
@@ -0,0 +1,771 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+#include "cpu_isa_traits.hpp"
+
+#include "gemm_convolution_utils.hpp"
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+using namespace prop_kind;
+using namespace data_type;
+
+namespace jit_gemm_convolution_utils {
+
+void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
+ int od)
+{
+ const size_t OHW = jcp.oh * jcp.ow;
+ const size_t im_step = jcp.ih * jcp.iw * jcp.id;
+ const size_t col_step = jcp.ks * OHW;
+
+ parallel_nd(jcp.ic, [&](int ic) {
+ const float *__restrict im_loc = im + ic * im_step;
+ float *__restrict col_loc = col + ic * col_step;
+ int id = od * jcp.stride_d - jcp.f_pad;
+ for (int kd = 0; kd < jcp.kd; ++kd) {
+ float *__restrict col_ = col_loc + kd * jcp.kh * jcp.kw * OHW;
+ if (id < 0 || id >= jcp.id) {
+ int ih_ = -jcp.t_pad;
+ for (int kh = 0; kh < jcp.kh; ++kh) {
+ int ih = ih_;
+ for (int oh = 0; oh < jcp.oh; ++oh) {
+ if (ih < 0 || ih >= jcp.ih) {
+ ih += jcp.stride_h;
+ continue;
+ }
+ int iw_ = -jcp.l_pad;
+ for (int kw = 0; kw < jcp.kw; ++kw) {
+ int iw = iw_;
+ for (int ow = 0; ow < jcp.ow; ++ow) {
+ if (iw < 0 || iw >= jcp.iw) {
+ iw += jcp.stride_w;
+ continue;
+ }
+
+ const size_t col_idx = kw * OHW + oh * jcp.ow
+ + ow;
+
+ col_[col_idx] = 0;
+ iw += jcp.stride_w;
+ }
+ iw_ += (1 + jcp.dilate_w);
+ }
+ ih += jcp.stride_h;
+ }
+ ih_ += (1 + jcp.dilate_h);
+ col_ += jcp.kw * OHW;
+ }
+ } else {
+ const float *__restrict im_ = im_loc + id * jcp.ih * jcp.iw;
+ int ih_ = -jcp.t_pad;
+ for (int kh = 0; kh < jcp.kh; ++kh) {
+ int ih = ih_;
+ for (int oh = 0; oh < jcp.oh; ++oh) {
+ if (ih < 0 || ih >= jcp.ih) {
+ ih += jcp.stride_h;
+ continue;
+ }
+ int iw_ = -jcp.l_pad;
+ for (int kw = 0; kw < jcp.kw; ++kw) {
+ int iw = iw_;
+ for (int ow = 0; ow < jcp.ow; ++ow) {
+ if (iw < 0 || iw >= jcp.iw) {
+ iw += jcp.stride_w;
+ continue;
+ }
+
+ const size_t col_idx = kw * OHW + oh * jcp.ow
+ + ow;
+ const size_t im_idx = ih * jcp.iw + iw;
+
+ col_[col_idx] = im_[im_idx];
+ iw += jcp.stride_w;
+ }
+ iw_ += (1 + jcp.dilate_w);
+ }
+ ih += jcp.stride_h;
+ }
+ ih_ += (1 + jcp.dilate_h);
+ col_ += jcp.kw * OHW;
+ }
+ }
+ id += (1 + jcp.dilate_d);
+ }
+ });
+}
+
+/* col[ic][kh][kw][oh][ow] <-- im2col(im[ic][ih][iw]) */
+void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
+ float *__restrict col, int hs, int hb, int ws, int wb) {
+ const size_t im_step = jcp.is;
+ const size_t col_step = jcp.ks * hb * wb;
+ if (jcp.stride_w == 1) {
+        // The compiler generates better-optimized code for stride_w == 1
+        // because the innermost loop then iterates over width.
+ auto ker = [&](int ic, int kh, int kw, int oh) {
+ const float *__restrict im_ = im + ic * im_step;
+ float *__restrict col_
+ = col + ic * col_step + ((kh * jcp.kw + kw) * hb + oh) * wb;
+
+ const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
+ + kh * (1 + jcp.dilate_h);
+ if (ih < 0 || ih >= jcp.ih) {
+ for (int ow = 0; ow < wb; ++ow)
+ col_[ow] = 0.f;
+ } else {
+ for (int ow = 0; ow < wb; ++ow) {
+ const int iw = ow + ws - jcp.l_pad + kw * (1 + jcp.dilate_w);
+ if (iw < 0 || iw >= jcp.iw)
+ col_[ow] = 0.f;
+ else {
+ const size_t im_idx = ih * jcp.iw + iw;
+ col_[ow] = im_[im_idx];
+ }
+ }
+ }
+ };
+
+ if (jcp.outer_threading) {
+ for (int ic = 0; ic < jcp.ic; ic++)
+ for (int kh = 0; kh < jcp.kh; kh++)
+ for (int kw = 0; kw < jcp.kw; kw++)
+ for (int oh = 0; oh < hb; oh++)
+ ker(ic, kh, kw, oh);
+ }
+ else {
+ parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb, ker);
+ }
+ } else if (jcp.ic == 1) {
+ parallel_nd(jcp.kh, hb, [&](int kh, int oh) {
+ const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
+ + kh * (1 + jcp.dilate_h);
+ if (ih < 0 || ih >= jcp.ih)
+ for (int kw = 0; kw < jcp.kw; ++kw) {
+ for (int ow = 0; ow < wb; ++ow) {
+ const size_t col_idx
+ = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
+ col[col_idx] = 0;
+ }
+ }
+ else
+ for (int kw = 0; kw < jcp.kw; ++kw) {
+ for (int ow = 0; ow < wb; ++ow) {
+ const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
+ + kw * (1 + jcp.dilate_w);
+ const size_t col_idx
+ = ((kh * jcp.kw + kw) * hb + oh) * wb + ow;
+ const size_t im_idx = ih * jcp.iw + iw;
+ if (iw < 0 || iw >= jcp.iw)
+ col[col_idx] = 0;
+ else
+ col[col_idx] = im[im_idx];
+ }
+ }
+ });
+ } else {
+
+ parallel_nd(jcp.ic, jcp.kh, jcp.kw, hb,
+ [&](int ic, int kh, int kw, int oh) {
+ const float *__restrict im_ = im + ic * im_step;
+ float *__restrict col_ = col + ic * col_step
+ + ((kh * jcp.kw + kw) * hb + oh) * wb;
+
+ const int ih = (oh + hs) * jcp.stride_h - jcp.t_pad
+ + kh * (1 + jcp.dilate_h);
+ if (ih < 0 || ih >= jcp.ih) {
+ for (int ow = 0; ow < wb; ++ow)
+ col_[ow] = 0.f;
+ } else {
+ for (int ow = 0; ow < wb; ++ow) {
+ const int iw = (ow + ws) * jcp.stride_w - jcp.l_pad
+ + kw * (1 + jcp.dilate_w);
+ const size_t im_idx = ih * jcp.iw + iw;
+ if (iw < 0 || iw >= jcp.iw)
+ col_[ow] = 0.f;
+ else
+ col_[ow] = im_[im_idx];
+ }
+ }
+ });
+ }
+}
+
+inline int limit(int low, int upper, int value) {
+ return nstl::max(low, nstl::min(upper, value));
+}
+
+/* col[kh][kw][ic][oh][ow] <-- im2col_u8(im[ih][iw][ic]) */
+template <typename T>
+void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
+ T *__restrict imtr, uint8_t *__restrict col, int hs, int hb, int ws,
+ int wb) {
+ uint8_t shift = jcp.signed_input ? 128 : 0;
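+    // Signed (s8) inputs are shifted by 128 so they fit the unsigned im2col buffer;
+    // padded positions are filled with the same shift value.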
+ const int dh = 1 + jcp.dilate_h;
+ const int dw = 1 + jcp.dilate_w;
+ const int sh = jcp.stride_h;
+ const int sw = jcp.stride_w;
+ const int im_iw_stride = jcp.ic * jcp.ngroups;
+ const int im_ih_stride = jcp.iw * im_iw_stride;
+ const int tp = jcp.t_pad;
+ const int lp = jcp.l_pad;
+
+ if (jcp.outer_threading && sh == 1 && sw == 1 && dh == 1 && dw == 1) {
+ /* im[ih][iw][ic] --> imtr[ic][ih][iw] --> col[kh][kw][ic][oh][ow] */
+ const int hp = hs - tp;
+ const int wp = ws - lp;
+ const int ih_start = limit(0, jcp.ih, hp);
+ const int ih_end = limit(0, jcp.ih, hp + hb + jcp.kh);
+ const int iw_start = limit(0, jcp.iw, wp);
+ const int iw_end = limit(0, jcp.iw, wp + wb + jcp.kw);
+
+ const int ihb = ih_end - ih_start;
+ const int iwb = iw_end - iw_start;
+
+ const int imtr_ic_stride = ihb * iwb;
+ const ptrdiff_t imtr_idx_shift = ih_start * iwb + iw_start;
+ for (int ic = 0; ic < jcp.ic; ic++) {
+ const ptrdiff_t imtr_idx_ic = ic * imtr_ic_stride - imtr_idx_shift;
+ for (int ih = ih_start; ih < ih_end; ih++) {
+ const ptrdiff_t im_idx_ih = ic + ih * im_ih_stride;
+ const ptrdiff_t imtr_idx_ih = imtr_idx_ic + ih * iwb;
+ for (int iw = iw_start; iw < iw_end; iw++)
+ imtr[imtr_idx_ih + iw] = im[im_idx_ih + iw * im_iw_stride];
+ }
+ }
+
+ const int col_ic_str = hb * wb;
+ const int col_kw_stride = jcp.ic * col_ic_str;
+ const int col_kh_stride = jcp.kw * col_kw_stride;
+
+ const int oh_init = ih_start - hp;
+ const int ow_init = iw_start - wp;
+ for (int kh = 0; kh < jcp.kh; kh++) {
+ const ptrdiff_t col_idx_kh = kh * col_kh_stride;
+ const int oh_kh = oh_init - kh;
+ const int oh_start = limit(0, hb, oh_kh);
+ const int oh_end = limit(0, hb, oh_kh + ihb);
+ for (int kw = 0; kw < jcp.kw; kw++) {
+ const ptrdiff_t col_idx_kw
+ = col_idx_kh + kw * jcp.ic * col_ic_str;
+ const int ow_kw = ow_init - kw;
+ const int imtr_shift = oh_kh * iwb + ow_kw;
+ const int ow_start = limit(0, wb, ow_kw);
+ const int ow_end = limit(0, wb, ow_kw + iwb);
+ for (int ic = 0; ic < jcp.ic; ic++) {
+ const ptrdiff_t col_idx_ic = col_idx_kw + ic * col_ic_str;
+ const int imtr_idx_ic = ic * imtr_ic_stride - imtr_shift;
+ for (int oh = 0; oh < oh_start; oh++) {
+ const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
+ for (int ow = 0; ow < wb; ++ow)
+ col[col_idx_oh + ow] = shift;
+ }
+ for (int oh = oh_start; oh < oh_end; oh++) {
+ const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
+ const ptrdiff_t imtr_idx_oh = imtr_idx_ic + oh * iwb;
+ for (int ow = 0; ow < ow_start; ++ow)
+ col[col_idx_oh + ow] = shift;
+ for (int ow = ow_start; ow < ow_end; ++ow)
+ col[col_idx_oh + ow]
+ = imtr[imtr_idx_oh + ow] + shift;
+ for (int ow = ow_end; ow < wb; ++ow)
+ col[col_idx_oh + ow] = shift;
+ }
+ for (int oh = oh_end; oh < hb; oh++) {
+ const ptrdiff_t col_idx_oh = col_idx_ic + oh * wb;
+ for (int ow = 0; ow < wb; ++ow)
+ col[col_idx_oh + ow] = shift;
+ }
+ }
+ }
+ }
+ } else {
+ parallel_nd(jcp.kh, jcp.kw, jcp.ic, hb,
+ [&](int kh, int kw, int ic, int oh) {
+ const int hp = tp - kh * dh;
+ const int ih = (oh + hs) * sh - hp;
+ const ptrdiff_t col_idx_base
+ = (((kh * jcp.kw + kw) * jcp.ic + ic) * hb + oh) * wb;
+ if (ih < 0 || ih >= jcp.ih)
+ for (int ow = 0; ow < wb; ow++)
+ col[col_idx_base + ow] = shift;
+ else {
+ const int wp = lp - kw * dw;
+ const int ow_start = limit(0, wb, div_up(wp, sw) - ws);
+ const int ow_end
+ = limit(0, wb, div_up(jcp.iw + wp, sw) - ws);
+ for (int ow = 0; ow < ow_start; ow++)
+ col[col_idx_base + ow] = shift;
+ const int iw_base = ws * sw - wp;
+ const ptrdiff_t im_idx_base = ih * im_ih_stride + ic;
+ for (int ow = ow_start; ow < ow_end; ow++) {
+ const int iw = iw_base + ow * sw;
+ const ptrdiff_t im_idx
+ = im_idx_base + iw * im_iw_stride;
+ col[col_idx_base + ow] = im[im_idx] + shift;
+ }
+ for (int ow = ow_end; ow < wb; ow++)
+ col[col_idx_base + ow] = shift;
+ }
+ });
+ }
+}
+
+template void im2col_u8<int8_t>(const jit_gemm_conv_conf_t &jcp,
+ const int8_t *__restrict im, int8_t *__restrict imtr,
+ uint8_t *__restrict col, int hs, int hb, int ws, int wb);
+template void im2col_u8<uint8_t>(const jit_gemm_conv_conf_t &jcp,
+ const uint8_t *__restrict im, uint8_t *__restrict imtr,
+ uint8_t *__restrict col, int hs, int hb, int ws, int wb);
+
+/* im[ih][iw][ic] <-- col2im_s32(col[oh][ow][kh][kw][ic]) */
+void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col,
+ int32_t *__restrict im)
+{
+ parallel(0, [&](const int ithr, const int nthr) {
+ int h_nthr = nstl::min(jcp.ih, nthr);
+ int w_nthr = nstl::min(jcp.iw, nthr / h_nthr);
+ int h_ithr = 1, h_s = 0, h_e = 0, w_ithr = 1, w_s = 0, w_e = 0;
+ if (ithr < h_nthr * w_nthr) {
+ h_ithr = ithr / w_nthr;
+ w_ithr = ithr % w_nthr;
+ balance211(jcp.ih, h_nthr, h_ithr, h_s, h_e);
+ balance211(jcp.iw, w_nthr, w_ithr, w_s, w_e);
+ } else {
+ h_ithr = w_ithr = -ithr;
+ h_s = h_e = w_s = w_e = -1;
+ }
+
+ for (int ih = h_s; ih < h_e; ++ih) {
+ for (int iw = w_s; iw < w_e; ++iw) {
+ PRAGMA_OMP_SIMD()
+ for (int ic = 0; ic < jcp.ic; ++ic) {
+ im[(ih * jcp.iw + iw) * jcp.ic + ic] = 0;
+ }
+ }
+ }
+
+ // TODO: reduce region: [0.. oh] --> [h_s * sh .. h_e * sh]
+ for (int oh = 0; oh < jcp.oh; ++oh) {
+ for (int ow = 0; ow < jcp.ow; ++ow) {
+ for (int kh = 0; kh < jcp.kh; ++kh) {
+ const int ih = oh * jcp.stride_h
+ - jcp.t_pad + kh * (1 + jcp.dilate_h);
+ if (ih < h_s || ih >= h_e) continue;
+
+ for (int kw = 0; kw < jcp.kw; ++kw) {
+ const int iw = ow * jcp.stride_w
+ - jcp.l_pad + kw * (1 + jcp.dilate_w);
+ if (iw < w_s || iw >= w_e) continue;
+
+ const size_t col_idx = (((oh * jcp.ow + ow) * jcp.kh
+ + kh) * jcp.kw + kw) * jcp.ic;
+ const size_t im_idx
+ = (ih * jcp.iw + iw) * jcp.ic;
+ PRAGMA_OMP_SIMD()
+ for (int ic = 0; ic < jcp.ic; ++ic) {
+ im[im_idx + ic] += col[col_idx + ic];
+ }
+ }
+ }
+ }
+ }
+ });
+}
+
+void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im,
+ int od)
+{
+ parallel_nd(jcp.ic, [&](int ic) {
+ const float *__restrict col_ = col + (size_t)ic * jcp.ks * jcp.os;
+ float *__restrict im_ic = im + (size_t)ic * jcp.ih * jcp.iw * jcp.id;
+
+ int id = od * jcp.stride_d - jcp.f_pad;
+ for (int kd = 0; kd < jcp.kd; ++kd) {
+ if (id < 0 || id >= jcp.id) {
+ col_ += jcp.kh * jcp.kw * jcp.os;
+ id += (1 + jcp.dilate_d);
+ continue;
+ }
+
+ float *__restrict im_ = im_ic + id * jcp.ih * jcp.iw;
+
+ for (int oh = 0; oh < jcp.oh; ++oh) {
+ for (int kh = 0; kh < jcp.kh; ++kh) {
+ const int ih = oh * jcp.stride_h - jcp.t_pad
+ + kh * (1 + jcp.dilate_h);
+ if (ih < 0 || ih >= jcp.ih) continue;
+
+ for (int ow = 0; ow < jcp.ow; ++ow) {
+ for (int kw = 0; kw < jcp.kw; ++kw) {
+ const int iw = ow * jcp.stride_w - jcp.l_pad
+ + kw * (1 + jcp.dilate_w);
+ if (iw < 0 || iw >= jcp.iw) continue;
+
+ const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
+ const size_t im_idx = ih*jcp.iw + iw;
+ im_[im_idx] += col_[col_idx];
+ }}
+ }}
+
+ col_ += jcp.kh * jcp.kw * jcp.os;
+ id += (1 + jcp.dilate_d);
+ }
+ });
+}
+
+void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im) {
+ const size_t col_step = jcp.ks * jcp.os;
+ const size_t im_step = jcp.ih * jcp.iw;
+ const int iS = jcp.ih * jcp.iw;
+
+ parallel_nd(jcp.ic, [&](int ic) {
+ float *__restrict im_ = im + ic * im_step;
+ const float *__restrict col_ = col + ic * col_step;
+ PRAGMA_OMP_SIMD()
+ for (int is = 0; is < iS; ++is) im_[is] = 0.;
+
+ for (int kh = 0; kh < jcp.kh; ++kh) {
+ for (int oh = 0; oh < jcp.oh; ++oh) {
+ const int ih =
+ oh * jcp.stride_h - jcp.t_pad + kh * (1 + jcp.dilate_h);
+ if (ih < 0 || ih >= jcp.ih) continue;
+
+ for (int kw = 0; kw < jcp.kw; ++kw) {
+ for (int ow = 0; ow < jcp.ow; ++ow) {
+ const int iw =
+ ow * jcp.stride_w - jcp.l_pad + kw * (1 + jcp.dilate_w);
+ if (iw < 0 || iw >= jcp.iw) continue;
+
+ const size_t col_idx = ((kh*jcp.kw + kw)*jcp.oh+oh)*jcp.ow+ow;
+ const size_t im_idx = ih*jcp.iw + iw;
+ im_[im_idx] += col_[col_idx];
+ }
+ }
+ }
+ }
+ });
+}
+
+status_t init_conf(jit_gemm_conv_conf_t &jcp,
+ memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d, int max_threads) {
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ const int ndims = src_d.ndims();
+ const int is_1d = ndims == 3;
+ const int is_3d = ndims == 5;
+
+ jcp.prop_kind = cd.prop_kind;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.id = is_3d ? src_d.dims()[2] : 1;
+ jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
+ jcp.iw = src_d.dims()[ndims - 1];
+ jcp.od = is_3d ? dst_d.dims()[2] : 1;
+ jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
+ jcp.ow = dst_d.dims()[ndims - 1];
+
+ jcp.kd = is_3d ? weights_d.dims()[with_groups + 2] : 1;
+ jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
+ jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+
+ jcp.f_pad = is_3d ? cd.padding[0][0] : 0;
+ jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
+ jcp.l_pad = cd.padding[0][ndims - 3];
+
+ jcp.stride_d = is_3d ? cd.strides[0] : 1;
+ jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
+ jcp.stride_w = cd.strides[ndims - 3];
+
+ jcp.dilate_d = is_3d ? cd.dilates[0] : 0;
+ jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
+ jcp.dilate_w = cd.dilates[ndims - 3];
+
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef
+ || cd.diff_bias_desc.format_kind != format_kind::undef;
+
+ jcp.is = jcp.ih * jcp.iw;
+ jcp.os = jcp.oh * jcp.ow;
+ jcp.ks = jcp.kh * jcp.kw * jcp.kd;
+
+ jcp.signed_input = src_d.data_type() == data_type::s8;
+
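+    // im2col is skipped only for 1x1 convolutions with unit strides, matching
+    // input/output spatial sizes, and no signed-input shift.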
+ jcp.im2col_sz = !everyone_is(true,
+ jcp.ow == jcp.iw, jcp.oh == jcp.ih, jcp.od == jcp.id,
+ jcp.stride_w == 1, jcp.stride_h == 1, jcp.stride_d == 1,
+ jcp.ks == 1, !jcp.signed_input)
+ ? (ptrdiff_t)jcp.ic * jcp.ks * jcp.os : 0;
+
+ jcp.outer_threading = false;
+
+ bool is_int8_conv = utils::one_of(src_d.data_type(), s32, s8, u8)
+ && weights_d.data_type() == s8;
+
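+    // vlen is the SIMD register width in bytes; simd_w converts it to elements
+    // (1-byte int8 elements for int8 convolutions, 4-byte f32 otherwise).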
+ const int vlen = mayiuse(avx512_common)
+ ? cpu_isa_traits<avx512_common>::vlen
+ : mayiuse(avx)
+ ? cpu_isa_traits<avx>::vlen
+ : mayiuse(sse42) ? cpu_isa_traits<sse42>::vlen : 4;
+ const int simd_w = vlen / (is_int8_conv ? 1 : 4);
+
+ const bool is_bwd_d = jcp.prop_kind == backward_data;
+ const bool is_bwd_w = jcp.prop_kind == backward_weights;
+ const bool is_fwd = !is_bwd_d && !is_bwd_w;
+ jcp.oh_block = is_fwd ? jcp.oh : jcp.ih;
+ jcp.ow_block = is_fwd ? jcp.ow : jcp.iw;
+
+ using namespace memory_tracking::names;
+ bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
+
+ // TODO: maybe mitigate blocking restriction
+ const int wei_size = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
+ const int L2 = get_cache_size(2, true)
+ / (is_int8_conv ? sizeof(int8_t) : sizeof(float));
+ bool is_blocking_applicable = true
+ && is_fwd && jcp.im2col_sz
+ && jcp.id == 1 && jcp.od == 1
+ && jcp.dilate_h == 0 && jcp.dilate_w == 0
+ && !is_depthwise
+ && wei_size < L2/2;
+ if (is_blocking_applicable) {
+ // looking for oh and ow blocking
+ int h_block{ jcp.oh_block }, w_block{ jcp.ow_block };
+ const int ic = jcp.ic;
+ const int oc = jcp.oc;
+ const int iw = jcp.iw;
+ const int ow = jcp.ow;
+ const int oh = jcp.oh;
+ const int os = oh * ow;
+
+ // 1. cache requirement
+ int row_size = ic * ow * jcp.ks + 2 * (ic * iw + oc * ow);
+ if (is_int8_conv) {
+            // Heuristic: gemm needs a lot of memory for internal use
+ row_size *= 5;
+ // memory for accumulators
+ row_size += oc * ow * sizeof(uint32_t);
+ // memory for transposition
+ row_size += ic * iw;
+ }
+
+ h_block = nstl::max(1, nstl::min(oh, div_up(L2, row_size)));
+ if (h_block == 1) {
+ int col_size = ic * jcp.ks + 2 * (ic + oc);
+ if (is_int8_conv) {
+ col_size *= 5;
+ col_size += oc * sizeof(uint32_t);
+ col_size += ic;
+ }
+ w_block = nstl::max(1, nstl::min(ow, div_up(L2, col_size)));
+ }
+
+ // 2. threading requirement
+ if (h_block != oh)
+ h_block = nstl::max(1, rnd_dn(h_block, 4));
+ if (w_block != ow)
+ w_block = nstl::max(1, rnd_dn(w_block, simd_w));
+
+ float thr_eff = 0.f;
+        float thr_eff_threshold = 0.9f;
+ if (w_block == ow) {
+ do {
+ int nb_h = div_up(oh, h_block);
+ size_t work = jcp.ngroups * jcp.mb * jcp.od * nb_h;
+ float disb = (float)oh / rnd_up(oh, h_block);
+ thr_eff = (float)work / rnd_up(work, max_threads);
+ thr_eff = (thr_eff + disb) / 2.f;
+                if (thr_eff >= thr_eff_threshold)
+ break;
+ h_block = rnd_dn(h_block - 4, 4);
+ } while (h_block > 0);
+ }
+        if (thr_eff < thr_eff_threshold) // no suitable h_block was found
+ {
+ h_block = 1;
+ int nb_h = oh;
+ do {
+ int nb_w = div_up(ow, w_block);
+ size_t work_amount = jcp.ngroups * jcp.mb * nb_h * nb_w;
+ float disb = (float)ow / rnd_up(ow, w_block);
+ thr_eff = (float)work_amount / rnd_up(work_amount, max_threads);
+ thr_eff = (thr_eff + disb) / 2.f;
+                if (thr_eff > thr_eff_threshold)
+ break;
+ w_block = rnd_dn(w_block - simd_w, simd_w);
+ } while (w_block > 0);
+ }
+ h_block = nstl::max(1, h_block);
+ w_block = nstl::max(1, w_block);
+ const size_t inner_work = div_up(os, simd_w) * div_up(oc, simd_w);
+ const float inner_thr_eff
+ = (float)inner_work / rnd_up(inner_work, max_threads);
+ if (thr_eff >= inner_thr_eff / 2 && h_block > 0 && w_block > 0) {
+ jcp.oh_block = h_block;
+ jcp.ow_block = w_block;
+ jcp.outer_threading = true;
+ }
+ // updating jcp.im2col_sz
+ if (jcp.oh_block != 1)
+ jcp.ow_block = ow;
+ jcp.im2col_sz = (ptrdiff_t)ic * jcp.ks * jcp.oh_block * jcp.ow_block;
+ }
+    // Threading selection for bwd_d:
+    //  1. Roughly estimate the efficiency of inner vs. outer threading.
+    //  2. Estimate the gemm size, assuming gemm is not as efficient for
+    //     small sizes.
+    // 64K is the heuristic gemm size-per-thread threshold.
+ const int gemm_thrld = 64 * 1024;
+
+ if (is_int8_conv) {
+ if (is_fwd) {
+ if (!jcp.outer_threading) {
+ bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
+ const size_t outer_work = jcp.ngroups * jcp.mb;
+ const float outer_thr_eff
+ = (float)outer_work / rnd_up(outer_work, max_threads);
+ const size_t inner_work
+ = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
+ const float inner_thr_eff
+ = (float)inner_work / rnd_up(inner_work, max_threads);
+ jcp.outer_threading = (is_depthwise
+ || (jcp.is / max_threads < 64 && jcp.mb != 1))
+ && (outer_thr_eff / inner_thr_eff >= 1.f
+ || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
+ }
+ jcp.nthr = jcp.outer_threading ? max_threads : 1;
+ scratchpad.book(key_conv_gemm_col,
+ sizeof(int8_t) * jcp.nthr * jcp.im2col_sz);
+ scratchpad.book(key_conv_int_dat_in_acc_dt,
+ sizeof(int32_t) * jcp.nthr * jcp.oh_block * jcp.ow_block * jcp.oc);
+ scratchpad.book(key_conv_gemm_imtr,
+ sizeof(int8_t) * jcp.nthr * jcp.is * jcp.ic);
+ } else if (is_bwd_d) {
+ bool is_depthwise = jcp.ic == 1 && jcp.oc == 1 && jcp.ngroups != 1;
+ const size_t outer_work = jcp.ngroups * jcp.mb;
+ const float outer_thr_eff
+ = (float)outer_work / rnd_up(outer_work, max_threads);
+ const size_t inner_work
+ = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
+ const float inner_thr_eff
+ = (float)inner_work / rnd_up(inner_work, max_threads);
+ jcp.outer_threading = (is_depthwise
+ || (jcp.is / max_threads < 64 && jcp.mb != 1))
+ && (outer_thr_eff / inner_thr_eff >= 1.f
+ || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
+
+ jcp.nthr = jcp.outer_threading ? max_threads : 1;
+ scratchpad.book(key_conv_gemm_col,
+ sizeof(int32_t) * jcp.nthr * jcp.im2col_sz);
+ scratchpad.book(key_conv_int_dat_in_acc_dt,
+ sizeof(int32_t) * jcp.nthr * jcp.is * jcp.ic);
+ } else if (is_bwd_w) {
+ assert(!"unimplemented prop_kind");
+ return status::unimplemented;
+ }
+ } else {
+ if (is_fwd) {
+ if (!jcp.outer_threading) {
+ const size_t outer_work_amount = jcp.ngroups * jcp.mb * jcp.od;
+ const float outer_thr_eff = (float)outer_work_amount
+ / rnd_up(outer_work_amount, max_threads);
+ const size_t inner_work_amount
+ = div_up(jcp.os, simd_w) * div_up(jcp.oc, simd_w);
+ const float inner_thr_eff = (float)inner_work_amount
+ / rnd_up(inner_work_amount, max_threads);
+ jcp.outer_threading = jcp.os / max_threads < 512
+ && IMPLICATION(jcp.od == 1, jcp.mb != 1 || jcp.ngroups > 2)
+ && (outer_thr_eff / inner_thr_eff >= 1.f
+ || (jcp.os * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
+ }
+ } else if (is_bwd_d) {
+ const size_t outer_work_amount = jcp.ngroups * jcp.mb;
+ const float outer_thr_eff = (float)outer_work_amount
+ / rnd_up(outer_work_amount, max_threads);
+ const size_t inner_work
+ = div_up(jcp.is, simd_w) * div_up(jcp.ic, simd_w);
+ const float inner_thr_eff = (float)inner_work
+ / rnd_up(inner_work, max_threads);
+ jcp.outer_threading = (jcp.os / max_threads < 512 || jcp.ks < 64)
+ && (jcp.mb != 1 || jcp.ngroups > 2)
+ && (outer_thr_eff / inner_thr_eff >= 1.f
+ || (jcp.is * jcp.ic * jcp.oc) / max_threads < gemm_thrld);
+ } else if (is_bwd_w)
+ jcp.outer_threading = jcp.os / max_threads < 256
+ && (jcp.mb != 1 || jcp.ngroups > 2);
+
+ jcp.nthr = jcp.outer_threading ? max_threads : 1;
+ scratchpad.book(key_conv_gemm_col,
+ sizeof(float) * jcp.nthr * jcp.im2col_sz);
+
+ if (is_bwd_w) {
+ jcp.need_wei_reduction = mkldnn_thr_syncable()
+ ? jcp.mb != 1 && jcp.nthr != 1 : false;
+ scratchpad.book(key_conv_wei_reduction,
+ sizeof(float) * jcp.nthr * jcp.ngroups * weights_d.size());
+ }
+ }
+
+ return status::success;
+}
+
+void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb, int &ithr_g,
+ int &nthr_g, int &ithr_mb, int &nthr_mb) {
+ nthr_g = nstl::min(ngroups, nthr);
+ nthr_mb = nstl::min(mb, nthr / nthr_g);
+ if (ithr / nthr_mb >= ngroups) {
+ ithr_g = ithr_mb = -1;
+ } else {
+ ithr_g = ithr / nthr_mb;
+ ithr_mb = ithr % nthr_mb;
+ }
+}
+
+void bwd_weights_reduction_par(int ithr, int nthr,
+ const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws,
+ float *weights) {
+ const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
+
+ size_t weights_start{0}, weights_end{0};
+ balance211(weights_g_size, nthr, ithr, weights_start, weights_end);
+
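+    // Each thread reduces its [weights_start, weights_end) slice across all
+    // per-thread partial copies; the first copy overwrites, the rest accumulate.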
+ for (int i = 0; i < nthr; ++i) {
+ const float *ws_i = weights_reduce_ws + i * weights_g_size;
+ for (size_t s = weights_start; s < weights_end; ++s)
+ weights[s] = (i == 0 ? 0 : weights[s]) + ws_i[s];
+ }
+}
+
+};
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp
new file mode 100644
index 0000000000..e006789344
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_convolution_utils.hpp
@@ -0,0 +1,66 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP
+#define CPU_JIT_GEMM_CONVOLUTION_UTILS_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_engine.hpp"
+#include "jit_primitive_conf.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace jit_gemm_convolution_utils {
+
+void im2col_3d(const jit_gemm_conv_conf_t &jcp, const float *im, float *col,
+ int od);
+void im2col(const jit_gemm_conv_conf_t &jcp, const float *__restrict im,
+ float *__restrict col, int hs, int hb, int ws, int wb);
+template <typename T>
+void im2col_u8(const jit_gemm_conv_conf_t &jcp, const T *__restrict im,
+ T* __restrict imtr, uint8_t *__restrict col,
+ int hs, int hb, int ws, int wb);
+
+void col2im_s32(const jit_gemm_conv_conf_t &jcp, const int32_t *__restrict col,
+ int32_t *__restrict im);
+void col2im_3d(const jit_gemm_conv_conf_t &jcp, const float *col, float *im,
+ int od);
+void col2im(const jit_gemm_conv_conf_t &jcp, const float *col, float *im);
+
+status_t init_conf(jit_gemm_conv_conf_t &jcp,
+ memory_tracking::registrar_t &scratchpad, const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d, int max_threads);
+
+void bwd_weights_balance(int ithr, int nthr, int ngroups, int mb,
+ int &ithr_g, int &nthr_g, int &ithr_mb, int &nthr_mb);
+void bwd_weights_reduction_par(int ithr, int nthr,
+ const jit_gemm_conv_conf_t &jcp, const float *weights_reduce_ws,
+ float *weights);
+
+}
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp
new file mode 100644
index 0000000000..2872122f0d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.cpp
@@ -0,0 +1,156 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "gemm_inner_product.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::data_type;
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::primitive_kind;
+
+template <impl::data_type_t data_type>
+void gemm_inner_product_fwd_t<data_type>::execute_forward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC_total_padded();
+
+ bool wei_tr = !memory_desc_matches_one_of_tag(
+ *pd()->weights_md(), hwio, dhwio, io);
+
+ const auto &post_ops = pd()->attr()->post_ops_;
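+    // pd_t::init() admits at most one post-op and requires it to be relu,
+    // so a single entry here is necessarily a relu.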
+ const bool do_relu = post_ops.len_ == 1;
+
+ float alpha = 1.0, beta = 0.0;
+ extended_sgemm(wei_tr ? "T" : "N", "N", &OC, &MB, &IC, &alpha, weights,
+ wei_tr ? &IC : &OC, src, &IC, &beta, dst, &OC, bias);
+
+ if (do_relu) {
+ float nslope = post_ops.entry_[0].eltwise.alpha;
+ parallel_nd(MB, OC, [&](int mb, int oc) {
+ size_t dst_off = mb * OC + oc;
+ if (dst[dst_off] < 0)
+ dst[dst_off] *= nslope;
+ });
+ }
+}
+
+template <impl::data_type_t data_type>
+void gemm_inner_product_bwd_data_t<data_type>::execute_backward_data(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC_total_padded();
+
+ bool wei_tr = memory_desc_matches_one_of_tag(
+ *pd()->weights_md(), hwio, dhwio, io);
+
+ float alpha = 1.0, beta = 0.0;
+ extended_sgemm(wei_tr ? "T" : "N", "N", &IC, &MB, &OC, &alpha, weights,
+ wei_tr ? &OC : &IC, diff_dst, &OC, &beta, diff_src, &IC);
+}
+
+template <impl::data_type_t data_type>
+void gemm_inner_product_bwd_weights_t<data_type>::execute_backward_weights(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
+
+ diff_dst += diff_dst_d.offset0();
+
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC_total_padded();
+
+ bool wei_tr = memory_desc_matches_one_of_tag(
+ *pd()->diff_weights_md(), hwio, dhwio, io);
+
+ float alpha = 1.0, beta = 0.0;
+ if (wei_tr)
+ extended_sgemm("N", "T", &OC, &IC, &MB, &alpha, diff_dst, &OC, src, &IC,
+ &beta, diff_weights, &OC);
+ else
+ extended_sgemm("N", "T", &IC, &OC, &MB, &alpha, src, &IC, diff_dst, &OC,
+ &beta, diff_weights, &IC);
+
+ if (diff_bias) {
+ diff_bias += diff_bias_d.offset0();
+ constexpr int blksize = 8;
+ const int OC_blocks = OC / blksize;
+ const int rem_OC = OC % blksize;
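+        // diff_bias[oc] = sum over mb of diff_dst[mb][oc]; OC is split across
+        // threads in blocks of 8, the last thread also covering the OC % 8 tail.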
+ parallel(0, [&](const int ithr, const int nthr) {
+ int oc_st{0}, oc_e{0};
+ balance211(OC_blocks, nthr, ithr, oc_st, oc_e);
+ oc_st = oc_st * blksize;
+ oc_e = oc_e * blksize;
+
+ PRAGMA_OMP_SIMD()
+ for (int oc = oc_st; oc < oc_e; ++oc) {
+ diff_bias[oc] = diff_dst[oc];
+ }
+
+ for (int mb = 1; mb < MB; ++mb) {
+ PRAGMA_OMP_SIMD()
+ for (int oc = oc_st; oc < oc_e; ++oc) {
+ diff_bias[oc] += diff_dst[mb * OC + oc];
+ }
+ }
+
+ if (rem_OC != 0 && ithr == nthr-1) {
+ for (int oc = OC_blocks * blksize; oc < OC; oc++)
+ diff_bias[oc] = diff_dst[oc];
+ for (int mb = 1; mb < MB; ++mb) {
+ for (int oc = OC_blocks * blksize; oc < OC; oc++) {
+ diff_bias[oc] += diff_dst[mb * OC + oc];
+ }
+ }
+ }
+ });
+ }
+}
+
+template struct gemm_inner_product_fwd_t<data_type::f32>;
+template struct gemm_inner_product_bwd_data_t<data_type::f32>;
+template struct gemm_inner_product_bwd_weights_t<data_type::f32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp
new file mode 100644
index 0000000000..acf0a49b9a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_inner_product.hpp
@@ -0,0 +1,157 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_GEMM_INNER_PRODUCT_HPP
+#define CPU_GEMM_INNER_PRODUCT_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "gemm/gemm.hpp"
+
+#include "cpu_inner_product_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type>
+struct gemm_inner_product_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_inner_product_fwd_pd_t {
+ using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_fwd_t);
+
+ status_t init() {
+ using namespace utils;
+
+ bool ok = true
+ && set_default_params() == status::success
+ && is_fwd()
+ && !has_zero_dim_memory()
+ && everyone_is(data_type,
+ src_md()->data_type,
+ weights_md()->data_type,
+ dst_md()->data_type,
+ with_bias() ? weights_md(1)->data_type : data_type)
+ && attr()->output_scales_.has_default_values()
+ && attr()->post_ops_.len_ <= 1
+ && IMPLICATION(attr()->post_ops_.len_ == 1,
+ attr()->post_ops_.entry_[0].is_relu(true, false))
+ && dense_gemm_consitency_check(src_md(), weights_md(),
+ dst_md());
+ return ok ? status::success : status::unimplemented;
+ }
+ };
+
+ gemm_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t data_type>
+struct gemm_inner_product_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_inner_product_bwd_data_pd_t {
+ using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t;
+
+ DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && set_default_params() == status::success
+ && desc()->prop_kind == prop_kind::backward_data
+ && !has_zero_dim_memory()
+ && utils::everyone_is(data_type,
+ diff_src_md()->data_type,
+ weights_md()->data_type,
+ diff_dst_md()->data_type)
+ && attr()->has_default_values()
+ && dense_gemm_consitency_check(diff_src_md(), weights_md(),
+ diff_dst_md());
+ return ok ? status::success : status::unimplemented;
+ }
+ };
+
+ gemm_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t data_type>
+struct gemm_inner_product_bwd_weights_t: public cpu_primitive_t {
+ struct pd_t: public cpu_inner_product_bwd_weights_pd_t {
+ using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t;
+
+ DECLARE_COMMON_PD_T(GEMM_IMPL_STR, gemm_inner_product_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && set_default_params() == status::success
+ && desc()->prop_kind == prop_kind::backward_weights
+ && !has_zero_dim_memory()
+ && utils::everyone_is(data_type,
+ src_md()->data_type,
+ diff_weights_md()->data_type,
+ diff_dst_md()->data_type,
+ with_bias() ? diff_weights_md(1)->data_type : data_type)
+ && attr()->has_default_values()
+ && dense_gemm_consitency_check(src_md(), diff_weights_md(),
+ diff_dst_md());
+
+ return ok ? status::success : status::unimplemented;
+ }
+ };
+
+ gemm_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_weights(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_weights(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp
new file mode 100644
index 0000000000..fed7e4d693
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.cpp
@@ -0,0 +1,740 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "utils.hpp"
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+#include "math_utils.hpp"
+
+#include "simple_q10n.hpp"
+
+#include "gemm/gemm.hpp"
+#include "gemm_x8s8s32x_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace mkldnn::impl::memory_tracking::names;
+
+template <data_type_t src_type, data_type_t dst_type>
+void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
+execute_forward(const exec_ctx_t &ctx) const {
+ auto src_base = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst_base = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
+ assert(IMPLICATION(
+ jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
+ assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
+
+ parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+ execute_forward_thr(ithr, nthr, src_base, wei_base, bia_base, dst_base,
+ scratchpad);
+ });
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+_gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::pp_ker_t(
+ const pd_t *pd)
+ : ker_(nullptr)
+ , jcp_(pd->jcp_)
+ , OC_(pd->jcp_.oc)
+ , OS_(pd->jcp_.os)
+ , bias_data_type_(data_type::undef)
+ , bias_data_type_size_(0)
+ , scale_idx_mult_(0)
+ , do_bias_(false)
+ , do_relu_(false)
+ , do_sum_(false)
+{
+ using namespace types;
+
+ const auto dst_md = memory_desc_wrapper(pd->dst_md());
+ dst_os_stride_ = dst_md.blk_off(0, 0, 0, 1);
+
+ scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
+
+ auto &post_ops = pd->attr()->post_ops_;
+
+ int entry_idx = -1;
+ for (int idx = 0; idx < post_ops.len_; ++idx) {
+ const auto &e = post_ops.entry_[idx];
+ if (e.is_relu(true, false)) {
+ entry_idx = idx;
+ break;
+ }
+ }
+ do_relu_ = entry_idx >= 0;
+
+ do_signed_scaling_ = jcp_.signed_input;
+
+ do_sum_ = post_ops.contain(primitive_kind::sum, 0);
+ do_bias_ = pd->with_bias();
+ bias_data_type_ = pd->desc()->bias_desc.data_type;
+ if (do_bias_) {
+ assert(bias_data_type_ != data_type::undef);
+ bias_data_type_size_ = data_type_size(bias_data_type_);
+ }
+ const size_t vlen_start
+ = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+
+ for (size_t i = vlen_start; i > 0; i--) {
+ if (OC_ % i == 0) {
+ vlen_ = i;
+ break;
+ }
+ }
+
+ if (!mayiuse(avx512_core))
+ // use fallback code for older CPUs
+ return;
+ else
+ generate();
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::generate()
+{
+ using namespace Xbyak;
+ using namespace utils;
+
+ // TODO: clean-up
+ Reg64 reg_param = abi_param1;
+ Reg64 reg_dst = rdx;
+ Reg64 reg_acc = rax;
+ Reg64 reg_bias = rbx;
+ Reg64 reg_scales = rsi;
+
+ Reg64 reg_len = r8;
+    Reg64 reg_tmp = rcx; // must be rcx: the variable shift below uses cl
+ Reg64 reg_oc_offset = r9;
+ Reg64 reg_rem_mask_short = r10;
+ Reg64 reg_rem_mask_vlen = r11;
+ Opmask kreg_rem_mask_short = k1;
+ Opmask kreg_rem_mask_vlen = k3;
+ Opmask kreg_relu_cmp = k2;
+
+ const size_t vlen = vlen_;
+
+ Zmm vreg_zero = Zmm(0);
+ Zmm vreg_scale = Zmm(1);
+ Zmm vreg_nslope = Zmm(2);
+ Zmm vreg_sum_scale = Zmm(3);
+ Zmm vreg_signed_scale = Zmm(4);
+
+ size_t def_unroll = 4;
+ size_t max_unroll = 12;
+ size_t zmm_step = 2;
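+    // A sum post-op needs a third zmm per unrolled iteration to hold the
+    // previous dst values, which lowers the maximum unroll factor.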
+ if (do_sum_) {
+ max_unroll = 8;
+ zmm_step = 3;
+ }
+
+ auto vreg_dst = [&](int idx) {
+ return Zmm(5 + idx * zmm_step + 0);
+ };
+ auto vreg_bias = [&](int idx) {
+ return Zmm(5 + idx * zmm_step + 1);
+ };
+ auto vreg_prev_dst = [&](int idx) {
+ return Zmm(5 + idx * zmm_step + 2);
+ };
+
+ preamble();
+
+#define PARAM_OFF(x) offsetof(ker_args, x)
+ mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
+ mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
+ mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
+ mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
+ mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
+ mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
+ vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]);
+ vbroadcastss(vreg_sum_scale, ptr[reg_param + PARAM_OFF(sum_scale)]);
+ vbroadcastss(vreg_signed_scale, ptr[reg_param + PARAM_OFF(signed_scale)]);
+ if (scale_idx_mult_ == 0)
+ vbroadcastss(vreg_scale, dword[reg_scales]);
+
+#undef PARAM_OFF
+
+ mov(reg_rem_mask_vlen, 1);
+ shl(reg_rem_mask_vlen, vlen);
+ sub(reg_rem_mask_vlen, 1);
+ kmovq(kreg_rem_mask_vlen, reg_rem_mask_vlen);
+
+ if (do_relu_ || dst_type == data_type::u8)
+ vxorps(vreg_zero, vreg_zero, vreg_zero);
+
+ // Load accumulated value, convert to float, apply sum (if any),
+ // bias (if any), scaling, and relu (if any);
+ // then convert to destination type and store
+ auto compute = [&](size_t offset, int idx, bool apply_mask) {
+ auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
+
+ if (scale_idx_mult_ > 0) {
+ assert(scale_idx_mult_ == 1);
+ auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
+ auto vreg_scale_ = vreg_scale;
+ if (apply_mask)
+ vreg_scale_ = vreg_scale_ | kreg_rem_mask_short;
+ else
+ vreg_scale_ = vreg_scale_ | kreg_rem_mask_vlen;
+ vmovups(vreg_scale_, scale_addr);
+ }
+
+ auto vreg_dst_ = vreg_dst(idx);
+ if (apply_mask)
+ vreg_dst_ = vreg_dst_ | kreg_rem_mask_short;
+ else
+ vreg_dst_ = vreg_dst_ | kreg_rem_mask_vlen;
+ vcvtdq2ps(vreg_dst_, acc_addr);
+
+ if (do_signed_scaling_)
+ vmulps(vreg_dst(idx), vreg_dst(idx), vreg_signed_scale);
+
+ if (do_bias_) {
+ auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
+ auto vreg_bias_ = vreg_bias(idx);
+ if (apply_mask)
+ vreg_bias_ = vreg_bias_ | kreg_rem_mask_short;
+ else
+ vreg_bias_ = vreg_bias_ | kreg_rem_mask_vlen;
+
+ switch (bias_data_type_) {
+ case data_type::s8:
+ vpmovsxbd(vreg_bias_, bias_addr);
+ break;
+ case data_type::u8:
+ vpmovzxbd(vreg_bias_, bias_addr);
+ break;
+ case data_type::s32:
+ case data_type::f32:
+ vmovups(vreg_bias_, bias_addr);
+ break;
+ default: assert(!"unimplemented");
+ }
+ if (bias_data_type_ != data_type::f32)
+ vcvtdq2ps(vreg_bias(idx), vreg_bias(idx));
+ vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
+ }
+
+ vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
+
+ auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
+
+ if (do_sum_)
+ {
+ auto vreg_prev_dst_ = vreg_prev_dst(idx);
+ if (apply_mask)
+ vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_short;
+ else
+ vreg_prev_dst_ = vreg_prev_dst_ | kreg_rem_mask_vlen;
+
+ switch (dst_type) {
+ case data_type::f32:
+ case data_type::s32: vmovups(vreg_prev_dst_, dst_addr); break;
+ case data_type::s8: vpmovsxbd(vreg_prev_dst_, dst_addr); break;
+ case data_type::u8: vpmovzxbd(vreg_prev_dst_, dst_addr); break;
+ default: assert(!"unsupported data type");
+ }
+ if (dst_type != data_type::f32)
+ vcvtdq2ps(vreg_prev_dst(idx), vreg_prev_dst(idx));
+
+ vfmadd231ps(vreg_dst(idx), vreg_prev_dst(idx), vreg_sum_scale);
+ }
+
+ if (do_relu_) {
+ vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os);
+ vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope);
+ }
+
+ if (dst_type != data_type::f32) {
+ vcvtps2dq(vreg_dst(idx), vreg_dst(idx));
+ }
+
+ if (dst_type == data_type::u8)
+ vpmaxsd(vreg_dst(idx), vreg_dst(idx), vreg_zero);
+
+ switch (dst_type) {
+ case data_type::s8:
+ vpmovsdb(dst_addr, vreg_dst_);
+ break;
+ case data_type::u8:
+ vpmovusdb(dst_addr, vreg_dst_);
+ break;
+ case data_type::f32:
+ case data_type::s32:
+ vmovups(dst_addr, vreg_dst_);
+ break;
+ default: assert(!"unimplemented");
+ }
+ };
+
+ // Advance all pointers by an immediate
+ auto advance_ptrs_imm = [&](size_t offset) {
+ add(reg_dst, offset * sizeof(dst_data_t));
+ add(reg_acc, offset * sizeof(acc_data_t));
+ if (scale_idx_mult_) {
+ assert(scale_idx_mult_ == 1);
+ add(reg_scales, offset * sizeof(float));
+ }
+ if (do_bias_)
+ add(reg_bias, offset * bias_data_type_size_);
+ };
+
+ // Advance all pointers by a value stored in a register
+ auto advance_ptrs_reg = [&](Reg64 offset) {
+ lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
+ lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
+ if (scale_idx_mult_) {
+ assert(scale_idx_mult_ == 1);
+ lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
+ }
+ if (do_bias_)
+ lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
+ };
+
+ // Rewind pointers that point to data that is indexed by output channel
+ // (bias or per-oc scaling factors)
+ auto rewind_ptrs = [&]() {
+ if (do_bias_)
+ sub(reg_bias, OC_ * bias_data_type_size_);
+ if (scale_idx_mult_) {
+ assert(scale_idx_mult_ == 1);
+ sub(reg_scales, OC_ * sizeof(float));
+ }
+ add(reg_dst, (dst_os_stride_ - OC_) * sizeof(dst_data_t));
+ };
+
+ // <--------- OC --------------->
+ //
+ // ^ ................+..............+-------------+.......................
+ // | . : not accessed |Prologue loop| .
+ // | . +--------------+-------------+ .
+ // . | | .
+ // O . | Main loop (unrolled) | .
+ // S . | | .
+ // . +--------------+-------------+ .
+ // | . | Epilogue loop|not accessed : .
+ // v ................+--------------+.............+.......................
+
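+ // Three phases: the prologue finishes a partially started OC row (oc_offset),
+ // the main loop handles whole OC rows, and the epilogue covers the remainder.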
+ Label prologue_end;
+ cmp(reg_oc_offset, 0);
+ je(prologue_end, T_NEAR);
+
+ // Prologue loop
+ {
+ mov(reg_tmp, OC_);
+ sub(reg_tmp, reg_oc_offset);
+ cmp(reg_tmp, reg_len);
+ cmovg(reg_tmp, reg_len);
+ sub(reg_len, reg_tmp);
+
+ Label prologue_loop, prologue_loop_tail, prologue_loop_end;
+ cmp(reg_tmp, vlen);
+ jle(prologue_loop_tail, T_NEAR);
+ L(prologue_loop); {
+ compute(0, 0, false);
+ advance_ptrs_imm(vlen);
+ sub(reg_tmp, vlen);
+ cmp(reg_tmp, vlen);
+ jge(prologue_loop, T_NEAR);
+ }
+
+ L(prologue_loop_tail);
+ mov(reg_rem_mask_short, 1);
+ // cl == reg_tmp because reg_tmp <= vlen here
+ shl(reg_rem_mask_short, cl);
+ sub(reg_rem_mask_short, 1);
+ jz(prologue_loop_end, T_NEAR);
+
+ kmovq(kreg_rem_mask_short, reg_rem_mask_short);
+ compute(0, 0, true);
+ advance_ptrs_reg(reg_tmp);
+
+ L(prologue_loop_end);
+ rewind_ptrs();
+ }
+ L(prologue_end);
+
+ // Main loop
+ Label main_loop_end;
+ {
+ cmp(reg_len, OC_);
+ jle(main_loop_end, T_NEAR);
+
+ Label main_loop;
+ L(main_loop); {
+ size_t OC_loop, OC_tail;
+ if (OC_ < max_unroll * vlen) {
+ // Fully unroll small loops
+ OC_loop = 0;
+ OC_tail = OC_;
+ }
+ else {
+ OC_loop = vlen * def_unroll;
+ OC_tail = OC_ % OC_loop;
+ }
+
+ assert(!!OC_loop || !!OC_tail);
+
+ if (OC_tail % vlen) {
+ int vlen_tail = OC_tail % vlen;
+ unsigned tail_mask = (1 << vlen_tail) - 1;
+ mov(reg_tmp, tail_mask);
+ kmovq(kreg_rem_mask_short, reg_tmp);
+ }
+
+ if (OC_loop) {
+ mov(reg_tmp, rnd_dn(OC_, OC_loop));
+ Label oc_loop;
+ L(oc_loop); {
+ for (size_t offset = 0; offset < OC_loop; offset += vlen)
+ compute(offset, offset / vlen, false);
+ advance_ptrs_imm(OC_loop);
+ sub(reg_tmp, OC_loop);
+ jnz(oc_loop);
+ }
+ }
+
+ if (OC_tail) {
+ for (size_t offset = 0; offset < OC_tail; offset += vlen) {
+ bool use_mask = (offset + vlen) > OC_tail;
+ compute(offset, offset / vlen, use_mask);
+ }
+ advance_ptrs_imm(OC_tail);
+ }
+
+ rewind_ptrs();
+ sub(reg_len, OC_);
+ cmp(reg_len, OC_);
+ jge(main_loop, T_NEAR);
+ }
+ }
+ L(main_loop_end);
+
+ // Epilogue loop
+ Label epilogue_end;
+ {
+ cmp(reg_len, 0);
+ je(epilogue_end, T_NEAR);
+
+ Label epilogue_loop, epilogue_loop_tail;
+ cmp(reg_len, vlen);
+ jle(epilogue_loop_tail, T_NEAR);
+ L(epilogue_loop); {
+ compute(0, 0, false);
+ sub(reg_len, vlen);
+ advance_ptrs_imm(vlen);
+ cmp(reg_len, vlen);
+ jge(epilogue_loop, T_NEAR);
+ }
+
+ L(epilogue_loop_tail);
+ mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
+ mov(reg_rem_mask_short, 1);
+ shl(reg_rem_mask_short, cl); // reg_tmp == rcx and reg_tail < vlen
+ sub(reg_rem_mask_short, 1);
+ jz(epilogue_end, T_NEAR);
+ kmovq(kreg_rem_mask_short, reg_rem_mask_short);
+ compute(0, 0, true);
+ }
+
+ L(epilogue_end);
+
+ postamble();
+
+ ker_ = getCode<decltype(ker_)>();
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::pp_ker_t::operator ()
+ (dst_data_t *dst, const acc_data_t *acc, const char *bias,
+ const float *scales, float nslope, float sum_scale, float signed_scale,
+ int g, size_t start, size_t end)
+{
+ using math::get_bias;
+
+ if (end <= start)
+ return;
+
+ if (ker_) {
+ // JIT
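+ // start/end index the output flattened as os * OC_ + oc, so the kernel
+ // arguments are recovered from the flat offset here.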
+ ker_args args;
+ size_t oc_offset = start % OC_;
+ size_t os_offset = start / OC_;
+ args.acc = acc + start;
+ args.dst = dst + os_offset * dst_os_stride_ + oc_offset;
+ args.bias = bias + (g * jcp_.oc + oc_offset) * bias_data_type_size_;
+ args.scales = scales + scale_idx_mult_ * (g * jcp_.oc + oc_offset);
+ args.nslope = nslope;
+ args.sum_scale = sum_scale;
+ args.signed_scale = signed_scale;
+ args.len = end - start;
+ args.oc_offset = oc_offset;
+ ker_(&args);
+ }
+ else {
+ // Fallback
+ const size_t first_oc = start % OC_;
+ const size_t last_oc = (end - 1) % OC_;
+ const size_t first_os = start / OC_;
+ const size_t last_os = (end - 1) / OC_;
+ for (size_t os = first_os; os <= last_os; os++) {
+ const size_t start_oc = (os == first_os) ? first_oc : 0;
+ const size_t end_oc = (os == last_os) ? last_oc : OC_ - 1;
+ for (size_t oc = start_oc; oc <= end_oc; oc++) {
+ const size_t acc_off = os * jcp_.oc + oc;
+ const size_t dst_off = os * dst_os_stride_ + oc;
+
+ float d = (float)(acc[acc_off]);
+ if (jcp_.signed_input)
+ d *= signed_scale;
+
+ if (do_bias_)
+ d += get_bias(bias, g * jcp_.oc + oc,
+ bias_data_type_);
+
+ d *= scales[(g * jcp_.oc + oc) * scale_idx_mult_];
+ if (do_sum_)
+ d += sum_scale * dst[dst_off];
+ if (do_relu_ && d < 0)
+ d *= nslope;
+ dst[dst_off] = qz_a1b0<float, dst_data_t>()(d);
+ }
+ }
+ }
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>::
+execute_forward_thr(const int ithr, const int nthr, const src_data_t *src_base,
+ const wei_data_t *wei_base, const char *bia_base, dst_data_t *dst_base,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
+ const auto src_md = memory_desc_wrapper(pd()->src_md());
+ const size_t src_mb_stride = src_md.blk_off(1);
+ const size_t src_g_stride = src_md.blk_off(0, 1) * jcp.ic;
+
+ const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
+ const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
+
+ const auto dst_md = memory_desc_wrapper(pd()->dst_md());
+ const size_t dst_mb_stride = dst_md.blk_off(1);
+ const size_t dst_g_stride = dst_md.blk_off(0, 1) * jcp.oc;
+
+ const float *scales = pd()->attr()->output_scales_.scales_;
+
+ const auto &post_ops = pd()->attr()->post_ops_;
+ const bool do_sum = post_ops.contain(primitive_kind::sum, 0);
+ const float sum_scale = do_sum ? post_ops.entry_[0].sum.scale : 0;
+
+ float nslope = 0;
+ for (int idx = 0; idx < post_ops.len_; ++idx) {
+ const auto &e = post_ops.entry_[idx];
+ if (e.is_relu(true, false)) {
+ nslope = e.eltwise.alpha;
+ break;
+ }
+ }
+
+ auto col = scratchpad.get<uint8_t>(key_conv_gemm_col)
+ + (ptrdiff_t)ithr * jcp.im2col_sz;
+ src_data_t *__restrict imtr = scratchpad.get<src_data_t>(key_conv_gemm_imtr)
+ + (ptrdiff_t)ithr * jcp.is * jcp.ic;
+ auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
+ + (ptrdiff_t)ithr * jcp.oh_block * jcp.ow_block * jcp.oc;
+
+ const ptrdiff_t offset = (ptrdiff_t)jcp.ngroups * jcp.ks * jcp.ic * jcp.oc;
+ const int32_t *_wei_comp = (const int32_t *)(wei_base + offset);
+
+ int g{ 0 }, n{ 0 }, ohb{ 0 }, owb{ 0 };
+ size_t start = 0, end = 0;
+
+ const int nb_oh = div_up(jcp.oh, jcp.oh_block);
+ const int nb_ow = div_up(jcp.ow, jcp.ow_block);
+ const size_t work_amount = jcp.ngroups * jcp.mb * nb_oh * nb_ow;
+ balance211(work_amount, nthr, ithr, start, end);
+ nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ohb,
+ nb_oh, owb, nb_ow);
+
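+ // Each thread processes a contiguous range of (mb, group, oh-block, ow-block)
+ // tiles; for each tile it runs im2col, the integer GEMM and the post-op kernel.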
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ int oh = ohb * jcp.oh_block;
+ int ow = owb * jcp.ow_block;
+ const src_data_t *__restrict src = src_base + n * src_mb_stride
+ + g * src_g_stride;
+ const wei_data_t *__restrict wei = wei_base + g * wei_g_stride;
+ dst_data_t *__restrict dst =
+ dst_base + n * dst_mb_stride + g * dst_g_stride;
+ const int32_t *wei_comp = _wei_comp + g * jcp.oc;
+ const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
+ const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
+
+ if (jcp.im2col_sz)
+ jit_gemm_convolution_utils::im2col_u8<src_data_t>(
+ jcp, src, imtr, col, oh, h_step, ow, w_step);
+
+ const int M = jcp.oc;
+ const int K = jcp.ks * jcp.ic;
+ const int N = h_step * w_step;
+ const int LDA = M * jcp.ngroups;
+ const int LDB = jcp.im2col_sz ? N : K;
+ const char *BT = jcp.im2col_sz ? "T" : "N";
+ const int8_t off_a = 0, off_b = 0;
+ const int32_t off_c = 0;
+ const float onef = 1.0, zerof = 0.0;
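+ // offsetc mode "C" adds the precomputed per-oc weight compensation (wei_comp)
+ // required for signed (s8) sources; "F" uses a single zero offset.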
+ gemm_s8x8s32("N", BT, jcp.signed_input ? "C" : "F",
+ &M, &N, &K, &onef, wei, &LDA, &off_a,
+ jcp.im2col_sz ? col : (uint8_t *)src, &LDB, &off_b,
+ &zerof, acc, &M, jcp.signed_input ? wei_comp : &off_c);
+
+ auto wei_adj_scale =
+ (wei_md.extra().flags & memory_extra_flags::scale_adjust)
+ ? wei_md.extra().scale_adjust : 1.f;
+
+ parallel(0, [&](int ithr, int nthr) {
+ size_t start, end;
+ balance211((size_t)N * jcp.oc, nthr, ithr, start, end);
+ (*pp_ker_)(dst + (oh * jcp.ow + ow) * pp_ker_->dst_os_stride_,
+ acc, bia_base, scales, nslope, sum_scale,
+ 1.f / wei_adj_scale, g, start, end);
+ });
+
+ nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ohb, nb_oh,
+ owb, nb_ow);
+ }
+}
+
+template <data_type_t dst_type>
+void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
+execute_backward_data(const exec_ctx_t &ctx) const {
+ auto diff_dst_base = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
+ auto wei_base = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bia_base = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto diff_src_base = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
+ parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+ execute_backward_data_thr(ithr, nthr, diff_dst_base, wei_base,
+ bia_base, diff_src_base, scratchpad);
+ });
+}
+
+template <data_type_t dst_type>
+void _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>::
+execute_backward_data_thr(const int ithr, const int nthr,
+ const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
+ const char *bia_base, diff_src_data_t *diff_src_base,
+ const memory_tracking::grantor_t &scratchpad) const
+{
+ const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
+ const auto diff_dst_md = memory_desc_wrapper(pd()->diff_dst_md());
+ const size_t diff_dst_mb_stride = diff_dst_md.blk_off(1);
+ const size_t diff_dst_g_stride = diff_dst_md.blk_off(0, 1) * jcp.oc;
+
+ const auto wei_md = memory_desc_wrapper(pd()->weights_md(0));
+ const size_t wei_g_stride = pd()->with_groups() ? wei_md.blk_off(1) : 0;
+
+ const auto diff_src_md = memory_desc_wrapper(pd()->diff_src_md());
+ const size_t diff_src_mb_stride = diff_src_md.blk_off(1);
+ const size_t diff_src_g_stride = diff_src_md.blk_off(0, 1) * jcp.ic;
+ const size_t diff_src_os_stride = diff_src_md.blk_off(0, 0, 0, 1);
+
+ /* scale_idx_mult = 1 for per_oc scales and 0 otherwise */
+ const int scale_idx_mult = pd()->attr()->output_scales_.mask_ == (1 << 1);
+ const float *scales = pd()->attr()->output_scales_.scales_;
+ const size_t work_amount = jcp.ngroups * jcp.mb;
+
+ auto col = scratchpad.get<acc_data_t>(key_conv_gemm_col)
+ + (ptrdiff_t)ithr * jcp.im2col_sz;
+ auto acc = scratchpad.get<acc_data_t>(key_conv_int_dat_in_acc_dt)
+ + (ptrdiff_t)ithr * jcp.is * jcp.ic;
+
+ int n{0}, g{0};
+ size_t start = 0, end = 0;
+
+ balance211(work_amount, nthr, ithr, start, end);
+ nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups);
+
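+ // Each thread owns a contiguous range of (mb, group) pairs; per pair a single
+ // integer GEMM produces columns, followed by col2im and dequantization.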
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ const diff_dst_data_t *diff_dst = diff_dst_base
+ + n * diff_dst_mb_stride + g * diff_dst_g_stride;
+ const wei_data_t *wei = wei_base + g * wei_g_stride;
+ diff_src_data_t *diff_src = diff_src_base + n * diff_src_mb_stride
+ + g * diff_src_g_stride;
+
+ const int M = jcp.ks * jcp.ic;
+ const int N = jcp.os;
+ const int K = jcp.oc;
+ const int8_t off_a = 0, off_b = 0;
+ const int32_t off_c = 0;
+ const float onef = 1.0, zerof = 0.0;
+ const int LD = K * jcp.ngroups;
+
+ gemm_s8x8s32("T", "N", "F", &M, &N, &K, &onef,
+ wei, &LD, &off_a, diff_dst, &LD, &off_b,
+ &zerof, jcp.im2col_sz ? col : acc, &M, &off_c);
+
+ if (jcp.im2col_sz)
+ jit_gemm_convolution_utils::col2im_s32(jcp, col, acc);
+
+ parallel_nd(jcp.is, jcp.ic, [&](int is, int ic) {
+ float d = (float)acc[is * jcp.ic + ic];
+ if (jcp.with_bias)
+ d += get_bias(bia_base, g * jcp.ic + ic,
+ pd()->desc()->bias_desc.data_type);
+ d *= scales[(g * jcp.ic + ic) * scale_idx_mult];
+ const size_t diff_src_off = is * diff_src_os_stride + ic;
+ diff_src[diff_src_off] =
+ qz_a1b0<float, diff_src_data_t>()(d);
+ });
+ nd_iterator_step(n, jcp.mb, g, jcp.ngroups);
+ }
+}
+
+using namespace data_type;
+
+template struct _gemm_x8s8s32x_convolution_fwd_t<u8, f32>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s32>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<u8, s8>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<u8, u8>;
+
+template struct _gemm_x8s8s32x_convolution_fwd_t<s8, f32>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s32>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<s8, s8>;
+template struct _gemm_x8s8s32x_convolution_fwd_t<s8, u8>;
+
+template struct _gemm_u8s8s32x_convolution_bwd_data_t<f32>;
+template struct _gemm_u8s8s32x_convolution_bwd_data_t<s32>;
+template struct _gemm_u8s8s32x_convolution_bwd_data_t<s8>;
+template struct _gemm_u8s8s32x_convolution_bwd_data_t<u8>;
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp
new file mode 100644
index 0000000000..9e77b890d5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_convolution.hpp
@@ -0,0 +1,266 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GEMM_X8S8S32X_CONVOLUTION_HPP
+#define GEMM_X8S8S32X_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_primitive_conf.hpp"
+#include "jit_generator.hpp"
+#include "gemm_convolution_utils.hpp"
+
+#include "gemm/gemm.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t src_type, data_type_t dst_type>
+struct _gemm_x8s8s32x_convolution_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR,
+ _gemm_x8s8s32x_convolution_fwd_t<src_type, dst_type>);
+
+ status_t init() {
+ using namespace data_type;
+
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(src_type, s8, data_type::undef, dst_type,
+ s32)
+ && IMPLICATION(with_bias(), utils::one_of(
+ desc()->bias_desc.data_type, f32, s32, s8, u8))
+ && !has_zero_dim_memory()
+ && set_default_formats_common(
+ dat_tag(), format_tag::any, dat_tag())
+ && post_ops_ok()
+ && memory_desc_matches_tag(*src_md(), dat_tag())
+ && memory_desc_matches_tag(*dst_md(), dat_tag())
+ && set_or_check_wei_format();
+ if (!ok) return status::unimplemented;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+ *desc(), src_md(), weights_md(0), dst_md(),
+ mkldnn_get_max_threads());
+ }
+
+ jit_gemm_conv_conf_t jcp_;
+
+ protected:
+ format_tag_t dat_tag() const { return format_tag::nhwc; }
+
+ bool set_or_check_wei_format() {
+ using namespace format_tag;
+
+ const bool is_src_s8 = src_md_.data_type == data_type::s8;
+
+ memory_desc_t want_wei_md = weights_md_;
+ memory_desc_init_by_tag(want_wei_md, with_groups() ? hwigo : hwio);
+
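+ // For s8 sources the weights additionally carry a per-oc s8s8 compensation
+ // term and, on pre-VNNI hardware, a 0.5 scale adjustment to avoid overflow
+ // in the intermediate products.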
+ if (is_src_s8) {
+ want_wei_md.extra.flags = 0
+ | memory_extra_flags::compensation_conv_s8s8
+ | memory_extra_flags::scale_adjust;
+ want_wei_md.extra.compensation_mask = (1 << 0)
+ + (with_groups() ? (1 << 1) : 0);
+ want_wei_md.extra.scale_adjust =
+ mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
+ }
+
+ if (weights_md_.format_kind == format_kind::any) {
+ weights_md_ = want_wei_md;
+ return true;
+ }
+
+ return weights_md_ == want_wei_md;
+ }
+
+ bool post_ops_ok() const {
+ using namespace mkldnn::impl::primitive_kind;
+ auto const &po = attr()->post_ops_;
+ auto is_relu = [&](int idx) {
+ return po.entry_[idx].is_relu(true, false); };
+
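+ // Accepted chains: none, a single relu, a single sum, or sum followed by
+ // relu (this mirrors what pp_ker_t applies).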
+ switch (po.len_) {
+ case 0: return true;
+ case 1: return is_relu(0) || po.contain(sum, 0);
+ case 2: return po.contain(sum, 0) && is_relu(1);
+ default: return false;
+ }
+ return false;
+ }
+ };
+
+ _gemm_x8s8s32x_convolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true), pp_ker_(nullptr)
+ { pp_ker_ = new pp_ker_t(pd()); }
+ ~_gemm_x8s8s32x_convolution_fwd_t() { delete pp_ker_; }
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<data_type::s8>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+ typedef typename prec_traits<data_type::s32>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ // XXX: this is throwaway code that will become unnecessary when we have a
+ // sufficiently advanced igemm jit generator that supports quantization,
+ // relu, and whatnot
+ class pp_ker_t : jit_generator {
+ public:
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(
+ _gemm_x8s8s32x_convolution_fwd_t::pp_kernel);
+ pp_ker_t(const pd_t *pd);
+
+ void operator()(dst_data_t *dst, const acc_data_t *acc,
+ const char *bias, const float *scales,
+ float nslope, float sum_scale, float signed_scale,
+ int g, size_t start, size_t end);
+
+ size_t dst_os_stride_;
+
+ private:
+ void generate();
+
+ struct ker_args {
+ dst_data_t *dst;
+ const acc_data_t *acc;
+ const char *bias;
+ const float *scales;
+ float nslope;
+ float sum_scale;
+ float signed_scale;
+ size_t len;
+ size_t oc_offset;
+ };
+ void(*ker_)(const ker_args *args);
+
+ const jit_gemm_conv_conf_t &jcp_;
+ size_t OC_;
+ size_t OS_;
+ data_type_t bias_data_type_;
+ size_t bias_data_type_size_;
+ size_t scale_idx_mult_;
+ bool do_bias_;
+ bool do_relu_;
+ bool do_sum_;
+ bool do_signed_scaling_;
+ size_t vlen_;
+ };
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ void execute_forward(const exec_ctx_t &ctx) const;
+ void execute_forward_thr(const int ithr, const int nthr,
+ const src_data_t *src_base, const wei_data_t *wei_base,
+ const char *bia_base, dst_data_t *dst_base,
+ const memory_tracking::grantor_t &scratchpad) const;
+
+ int nthr_;
+ pp_ker_t *pp_ker_;
+
+};
+
+template <data_type_t dst_type>
+struct _gemm_u8s8s32x_convolution_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_data_pd_t{
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc, const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(IGEMM_S8U8S32_IMPL_STR,
+ _gemm_u8s8s32x_convolution_bwd_data_t<dst_type>);
+
+ status_t init() {
+ using namespace data_type;
+
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(dst_type, s8, data_type::undef, u8, s32)
+ && IMPLICATION(with_bias(), utils::one_of(
+ desc()->bias_desc.data_type, f32, s32, s8, u8))
+ && !has_zero_dim_memory()
+ && set_default_formats_common(dat_tag(), wei_tag(), dat_tag())
+ && attr()->post_ops_.has_default_values()
+ && memory_desc_matches_tag(*diff_src_md(), dat_tag())
+ && memory_desc_matches_tag(*diff_dst_md(), dat_tag())
+ && memory_desc_matches_tag(*weights_md(), wei_tag());
+ if (!ok) return status::unimplemented;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ return jit_gemm_convolution_utils::init_conf(jcp_, scratchpad,
+ *desc(), diff_src_md(), weights_md(), diff_dst_md(),
+ mkldnn_get_max_threads());
+ }
+
+ virtual bool support_bias() const override { return true; }
+
+ jit_gemm_conv_conf_t jcp_;
+
+ protected:
+ format_tag_t dat_tag() const { return format_tag::nhwc; }
+
+ format_tag_t wei_tag() const {
+ return with_groups() ? format_tag::hwigo : format_tag::hwio;
+ }
+ };
+
+ _gemm_u8s8s32x_convolution_bwd_data_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true) {}
+
+ typedef typename prec_traits<data_type::u8>::type diff_dst_data_t;
+ typedef typename prec_traits<data_type::s8>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type diff_src_data_t;
+ typedef typename prec_traits<data_type::s32>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ void execute_backward_data_thr(const int ithr, const int nthr,
+ const diff_dst_data_t *diff_dst_base, const wei_data_t *wei_base,
+ const char *bia_base, diff_src_data_t *diff_src_base,
+ const memory_tracking::grantor_t &scratchpad) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp
new file mode 100644
index 0000000000..1e435a233a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.cpp
@@ -0,0 +1,453 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+#include "simple_q10n.hpp"
+
+#include "gemm/gemm.hpp"
+#include "gemm_x8s8s32x_inner_product.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace math;
+using namespace format_tag;
+using namespace memory_tracking::names;
+
+template<data_type_t src_type, data_type_t dst_type>
+gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::pp_kernel_t(
+ const pd_t *pd, bool dst_is_acc)
+ : ker_(nullptr), OC_(pd->OC())
+ , bias_data_type_(data_type::undef), bias_data_type_size_(0)
+ , scale_idx_mult_(0), do_bias_(false), do_relu_(false)
+{
+ using namespace types;
+
+ scale_idx_mult_ = (pd->attr()->output_scales_.mask_ == (1 << 1));
+
+ auto &post_ops = pd->attr()->post_ops_;
+ do_relu_ = post_ops.len_ == 1;
+ do_bias_ = pd->with_bias();
+ bias_data_type_ = pd->desc()->bias_desc.data_type;
+ if (do_bias_) {
+ assert(bias_data_type_ != data_type::undef);
+ bias_data_type_size_ = data_type_size(bias_data_type_);
+ }
+
+ if (!mayiuse(avx512_core))
+ // use fallback code for older CPUs since they do not have optimized
+ // x8s8s32 GEMM anyway. The configuration variables above are used by
+ // the fallback code.
+ return;
+ else
+ generate();
+}
+
+template<data_type_t src_type, data_type_t dst_type>
+void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::generate()
+{
+ using namespace Xbyak;
+ using namespace utils;
+
+ // TODO: clean-up
+ Reg64 reg_param = abi_param1;
+ Reg64 reg_dst = rdx;
+ Reg64 reg_acc = rax;
+ Reg64 reg_bias = rbx;
+ Reg64 reg_scales = rsi;
+
+ Reg64 reg_len = r8;
+ Reg64 reg_tmp = rcx; // intentional for shifting purposes
+ Reg64 reg_oc_offset = r9;
+ Reg64 reg_rem_mask = r10;
+ Opmask kreg_rem_mask = k1;
+ Opmask kreg_relu_cmp = k2;
+
+ const size_t vlen = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
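+ // vlen = 16: number of f32 lanes in a zmm register.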
+
+ Zmm vreg_zero = Zmm(0);
+ Zmm vreg_scale = Zmm(1);
+ Zmm vreg_nslope = Zmm(2);
+
+ auto vreg_dst = [&](int idx) { return Zmm(3 + idx * 2 + 0); };
+ auto vreg_bias = [&](int idx) { return Zmm(3 + idx * 2 + 1); };
+
+ preamble();
+
+#define PARAM_OFF(x) offsetof(ker_args, x)
+ mov(reg_dst, ptr[reg_param + PARAM_OFF(dst)]);
+ mov(reg_acc, ptr[reg_param + PARAM_OFF(acc)]);
+ mov(reg_bias, ptr[reg_param + PARAM_OFF(bias)]);
+ mov(reg_scales, ptr[reg_param + PARAM_OFF(scales)]);
+ mov(reg_len, ptr[reg_param + PARAM_OFF(len)]);
+ mov(reg_oc_offset, ptr[reg_param + PARAM_OFF(oc_offset)]);
+ vbroadcastss(vreg_nslope, ptr[reg_param + PARAM_OFF(nslope)]);
+ if (scale_idx_mult_ == 0)
+ vbroadcastss(vreg_scale, dword[reg_scales]);
+#undef PARAM_OFF
+
+ if (do_relu_ || dst_type == data_type::u8)
+ vxorps(vreg_zero, vreg_zero, vreg_zero);
+
+ // Load accumulated value, convert to float, apply bias (if any), scaling,
+ // and relu (if any); then convert to destination type and store
+ auto compute = [&](size_t offset, int idx, bool apply_mask) {
+ auto acc_addr = ptr[reg_acc + offset * sizeof(acc_data_t)];
+
+ if (scale_idx_mult_ > 0) {
+ assert(scale_idx_mult_ == 1);
+ auto scale_addr = ptr[reg_scales + offset * sizeof(float)];
+ auto vreg_scale_ = vreg_scale;
+ if (apply_mask)
+ vreg_scale_ = vreg_scale_ | kreg_rem_mask;
+ vmovups(vreg_scale_, scale_addr);
+ }
+
+ auto vreg_dst_ = vreg_dst(idx);
+ if (apply_mask)
+ vreg_dst_ = vreg_dst_ | kreg_rem_mask;
+ vcvtdq2ps(vreg_dst_, acc_addr);
+
+ if (do_bias_) {
+ auto bias_addr = ptr[reg_bias + offset * bias_data_type_size_];
+ auto vreg_bias_ = vreg_bias(idx);
+ if (apply_mask)
+ vreg_bias_ = vreg_bias_ | kreg_rem_mask;
+
+ switch (bias_data_type_) {
+ case data_type::s8:
+ vpmovsxbd(vreg_bias_, bias_addr);
+ break;
+ case data_type::u8:
+ vpmovzxbd(vreg_bias_, bias_addr);
+ break;
+ case data_type::s32:
+ case data_type::f32:
+ vmovups(vreg_bias_, bias_addr);
+ break;
+ default: assert(!"unimplemented");
+ }
+ if (bias_data_type_ != data_type::f32)
+ vcvtdq2ps(vreg_bias(idx), vreg_bias(idx));
+ vaddps(vreg_dst(idx), vreg_dst(idx), vreg_bias(idx));
+ }
+
+ vmulps(vreg_dst(idx), vreg_dst(idx), vreg_scale);
+ if (do_relu_) {
+ vcmpps(kreg_relu_cmp, vreg_dst(idx), vreg_zero, _cmp_lt_os);
+ vmulps(vreg_dst(idx) | kreg_relu_cmp, vreg_dst(idx), vreg_nslope);
+ }
+
+ if (dst_type == data_type::u8)
+ vmaxps(vreg_dst(idx), vreg_dst(idx), vreg_zero);
+
+ if (dst_type != data_type::f32) {
+ vcvtps2dq(vreg_dst(idx), vreg_dst(idx));
+ }
+
+ auto dst_addr = ptr[reg_dst + offset * sizeof(dst_data_t)];
+ switch (dst_type) {
+ case data_type::s8:
+ vpmovsdb(dst_addr, vreg_dst_);
+ break;
+ case data_type::u8:
+ vpmovusdb(dst_addr, vreg_dst_);
+ break;
+ case data_type::f32:
+ case data_type::s32:
+ vmovups(dst_addr, vreg_dst_);
+ break;
+ default: assert(!"unimplemented");
+ }
+ };
+
+ // Advance all pointers by an immediate
+ auto advance_ptrs_imm = [&](size_t offset) {
+ add(reg_dst, offset * sizeof(dst_data_t));
+ add(reg_acc, offset * sizeof(acc_data_t));
+ if (scale_idx_mult_) {
+ assert(scale_idx_mult_ == 1);
+ add(reg_scales, offset * sizeof(float));
+ }
+ if (do_bias_)
+ add(reg_bias, offset * bias_data_type_size_);
+ };
+
+ // Advance all pointers by a value stored in a register
+ auto advance_ptrs_reg = [&](Reg64 offset) {
+ lea(reg_dst, ptr[reg_dst + offset * sizeof(dst_data_t)]);
+ lea(reg_acc, ptr[reg_acc + offset * sizeof(acc_data_t)]);
+ if (scale_idx_mult_) {
+ assert(scale_idx_mult_ == 1);
+ lea(reg_scales, ptr[reg_scales + offset * sizeof(float)]);
+ }
+ if (do_bias_)
+ lea(reg_bias, ptr[reg_bias + offset * bias_data_type_size_]);
+ };
+
+ // Rewind pointers that point to data that is indexed by output channel
+ // (bias or per-oc scaling factors)
+ auto rewind_ptrs = [&]() {
+ if (do_bias_)
+ sub(reg_bias, OC_ * bias_data_type_size_);
+ if (scale_idx_mult_) {
+ assert(scale_idx_mult_ == 1);
+ sub(reg_scales, OC_ * sizeof(float));
+ }
+ };
+
+ // <-------------------- OC ------------------------------->
+ //
+ // ^ +....................+----------------------------------+
+ // | : not accessed | Prologue loop |
+ // | +--------------------+----------------------------------+
+ // | |
+ // M | Main loop (unrolled) |
+ // B | |
+ // +--------------------------------+----------------------+
+ // | | Epilogue loop | not accessed :
+ // v +--------------------------------+......................+
+
+ Label prologue_end;
+ cmp(reg_oc_offset, 0);
+ je(prologue_end, T_NEAR);
+
+ // Prologue loop
+ {
+ mov(reg_tmp, OC_);
+ sub(reg_tmp, reg_oc_offset);
+ cmp(reg_tmp, reg_len);
+ cmovg(reg_tmp, reg_len);
+ sub(reg_len, reg_tmp);
+
+ Label prologue_loop, prologue_loop_tail, prologue_loop_end;
+ cmp(reg_tmp, vlen);
+ jle(prologue_loop_tail, T_NEAR); // Skips for reg_tmp == 16 too (?)
+ L(prologue_loop); {
+ compute(0, 0, false);
+ advance_ptrs_imm(vlen);
+ sub(reg_tmp, vlen);
+ cmp(reg_tmp, vlen);
+ jge(prologue_loop, T_NEAR);
+ }
+
+ L(prologue_loop_tail);
+ mov(reg_rem_mask, 1);
+ shl(reg_rem_mask, cl); // cl == reg_tmp because reg_tmp <= vlen here
+ sub(reg_rem_mask, 1);
+ jz(prologue_loop_end, T_NEAR);
+
+ kmovq(kreg_rem_mask, reg_rem_mask);
+ compute(0, 0, true);
+ advance_ptrs_reg(reg_tmp);
+
+ L(prologue_loop_end);
+ rewind_ptrs();
+ }
+ L(prologue_end);
+
+ // Main loop
+ Label main_loop_end;
+ {
+ cmp(reg_len, OC_);
+ jle(main_loop_end, T_NEAR);
+
+ Label main_loop;
+ L(main_loop); {
+ size_t def_unroll = 4;
+ size_t max_unroll = 13;
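+ // Register budget: vreg_dst/vreg_bias pairs start at zmm3, so an unroll of
+ // 13 uses registers zmm0..zmm28 and stays within the 32 zmm registers.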
+
+ size_t OC_loop, OC_tail;
+ if (OC_ < max_unroll * vlen) {
+ // Fully unroll small loops
+ OC_loop = 0;
+ OC_tail = OC_;
+ } else {
+ OC_loop = vlen * def_unroll;
+ OC_tail = OC_ % OC_loop;
+ }
+
+ assert(!!OC_loop || !!OC_tail);
+
+ if (OC_tail % vlen) {
+ int vlen_tail = OC_tail % vlen;
+ unsigned tail_mask = (1 << vlen_tail) - 1;
+ mov(reg_tmp, tail_mask);
+ kmovq(kreg_rem_mask, reg_tmp);
+ }
+
+ if (OC_loop) {
+ mov(reg_tmp, rnd_dn(OC_, OC_loop));
+ Label oc_loop;
+ L(oc_loop); {
+ for (size_t offset = 0; offset < OC_loop; offset += vlen)
+ compute(offset, offset / vlen, false);
+ advance_ptrs_imm(OC_loop);
+ sub(reg_tmp, OC_loop);
+ jnz(oc_loop);
+ }
+ }
+
+ if (OC_tail) {
+ for (size_t offset = 0; offset < OC_tail; offset += vlen) {
+ bool use_mask = (offset + vlen) > OC_tail;
+ compute(offset, offset / vlen, use_mask);
+ }
+ advance_ptrs_imm(OC_tail);
+ }
+
+ rewind_ptrs();
+ sub(reg_len, OC_);
+ cmp(reg_len, OC_);
+ jge(main_loop, T_NEAR);
+ }
+ }
+ L(main_loop_end);
+
+ // Epilogue loop
+ Label epilogue_end;
+ {
+ cmp(reg_len, 0);
+ je(epilogue_end, T_NEAR);
+
+ Label epilogue_loop, epilogue_loop_tail;
+ cmp(reg_len, vlen);
+ jle(epilogue_loop_tail, T_NEAR); // Skips for reg_len == 16 (?)
+ L(epilogue_loop); {
+ compute(0, 0, false);
+ sub(reg_len, vlen);
+ advance_ptrs_imm(vlen);
+ cmp(reg_len, vlen);
+ jge(epilogue_loop, T_NEAR);
+ }
+
+ L(epilogue_loop_tail);
+ mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
+ mov(reg_rem_mask, 1);
+ shl(reg_rem_mask, cl); // reg_tmp == rcx and reg_tail < vlen == 16
+ sub(reg_rem_mask, 1);
+ jz(epilogue_end, T_NEAR);
+ kmovq(kreg_rem_mask, reg_rem_mask);
+ compute(0, 0, true);
+ }
+
+ L(epilogue_end);
+
+ postamble();
+
+ ker_ = getCode<decltype(ker_)>();
+}
+
+template<data_type_t src_type, data_type_t dst_type>
+void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::pp_kernel_t::operator ()(
+ dst_data_t *dst, const acc_data_t *acc,
+ const char *bias, const float *scales, float nslope,
+ size_t start, size_t end)
+{
+ using math::get_bias;
+
+ if (end <= start)
+ return;
+
+ if (ker_) {
+ // JIT
+ ker_args args;
+ size_t oc_offset = start % OC_;
+ args.dst = dst + start;
+ args.acc = acc + start;
+ args.bias = bias + oc_offset * bias_data_type_size_;
+ args.scales = scales + scale_idx_mult_ * oc_offset;
+ args.nslope = nslope;
+ args.len = end - start;
+ args.oc_offset = oc_offset;
+ ker_(&args);
+ } else {
+ // Fallback
+ size_t oc = start % OC_;
+ for (size_t i = start; i < end; i++) {
+ float d = (float)acc[i];
+ float b = get_bias(bias, oc, bias_data_type_);
+ d = d + b;
+ d *= scales[oc * scale_idx_mult_];
+ if (do_relu_ && d < 0)
+ d *= nslope;
+ dst[i] = qz_a1b0<float, dst_data_t>()(d);
+ oc = (oc == OC_ - 1) ? 0 : oc + 1;
+ }
+ }
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void gemm_x8s8s32x_inner_product_fwd_t<src_type, dst_type>::execute_forward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+
+ bool wei_tr = memory_desc_matches_one_of_tag(
+ *pd()->weights_md(), oiw, oihw, oidhw, oi);
+
+ const int M = OC;
+ const int N = MB;
+ const int K = pd()->IC_total_padded();
+ const int8_t off_a = 0, off_b = 0;
+ const int32_t off_c = 0;
+
+ const float *scales = pd()->attr()->output_scales_.scales_;
+
+ const auto &post_ops = pd()->attr()->post_ops_;
+ const bool do_relu = post_ops.len_ == 1;
+ const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;
+
+ acc_data_t *acc = pd()->dst_is_acc_
+ ? (acc_data_t *)dst
+ : scratchpad(ctx).template get<acc_data_t>(key_iprod_int_dat_in_acc_dt);
+
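+ // Integer GEMM into s32: acc (OC x MB, column-major) = op(weights) * src.
+ // Bias, scaling, relu and downconversion are applied by pp_kernel_ below.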
+ const float onef = 1.0, zerof = 0.0;
+ gemm_s8x8s32(wei_tr ? "T" : "N", "N", "F", &M, &N, &K, &onef, weights,
+ wei_tr ? &K : &M, &off_a, src, &K, &off_b, &zerof, acc, &M, &off_c);
+
+ if (!pd()->attr()->has_default_values() || !pd()->dst_is_acc_
+ || pd()->with_bias()) {
+ const bool force_sequential = MB * OC < 2000;
+ parallel(force_sequential ? 1 : 0, [&](int ithr, int nthr) {
+ size_t start, end;
+ balance211((size_t)OC * MB, nthr, ithr, start, end);
+ (*pp_kernel_)(dst, acc, bias, scales, nslope, start, end);
+ });
+ }
+}
+
+using namespace data_type;
+
+template struct gemm_x8s8s32x_inner_product_fwd_t<u8, f32>;
+template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s32>;
+template struct gemm_x8s8s32x_inner_product_fwd_t<u8, s8>;
+template struct gemm_x8s8s32x_inner_product_fwd_t<u8, u8>;
+template struct gemm_x8s8s32x_inner_product_fwd_t<s8, f32>;
+template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s32>;
+template struct gemm_x8s8s32x_inner_product_fwd_t<s8, s8>;
+template struct gemm_x8s8s32x_inner_product_fwd_t<s8, u8>;
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp
new file mode 100644
index 0000000000..ac6a5c8f85
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/gemm_x8s8s32x_inner_product.hpp
@@ -0,0 +1,166 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef GEMM_X8S8S32X_INNER_PRODUCT_HPP
+#define GEMM_X8S8S32X_INNER_PRODUCT_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "gemm/gemm.hpp"
+#include "jit_generator.hpp"
+
+#include "cpu_inner_product_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t src_type, impl::data_type_t dst_type>
+struct gemm_x8s8s32x_inner_product_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_inner_product_fwd_pd_t {
+ using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T(src_type == data_type::u8
+ ? IGEMM_S8U8S32_IMPL_STR
+ : IGEMM_S8S8S32_IMPL_STR,
+ gemm_x8s8s32x_inner_product_fwd_t);
+
+ status_t init() {
+ using namespace data_type;
+
+ bool ok = true
+ && set_default_params() == status::success
+ && is_fwd()
+ && !has_zero_dim_memory()
+ && src_md()->data_type == src_type
+ && dst_md()->data_type == dst_type
+ && weights_md()->data_type == s8
+ && IMPLICATION(with_bias(), utils::one_of(
+ weights_md(1)->data_type, f32, s32, s8, u8))
+ && attr()->post_ops_.len_ <= 1
+ && IMPLICATION(attr()->post_ops_.len_,
+ attr()->post_ops_.entry_[0].is_relu(true, false))
+ && dense_gemm_consitency_check(src_md(), weights_md(),
+ dst_md());
+ if (!ok) return status::unimplemented;
+
+ dst_is_acc_ = utils::one_of(dst_type, s32, f32);
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ bool dst_is_acc_;
+
+ protected:
+ status_t set_default_params() {
+ using namespace format_tag;
+ if (src_md_.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(src_md_,
+ utils::pick(ndims() - 2, nc, nwc, nhwc, ndhwc)));
+ }
+ if (dst_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_md_, nc));
+ if (weights_md_.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(weights_md_,
+ utils::pick(ndims() - 2, io, wio, hwio, dhwio)));
+ }
+ return inner_product_fwd_pd_t::set_default_params();
+ }
+
+ private:
+ void init_scratchpad() {
+ if (!dst_is_acc_) {
+ auto scratchpad = scratchpad_registry().registrar();
+ scratchpad.book(
+ memory_tracking::names::key_iprod_int_dat_in_acc_dt,
+ sizeof(acc_data_t) * MB() * OC());
+ }
+ }
+ };
+
+ gemm_x8s8s32x_inner_product_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true)
+ { pp_kernel_ = new pp_kernel_t(apd, pd()->dst_is_acc_); }
+ ~gemm_x8s8s32x_inner_product_fwd_t() { delete pp_kernel_; }
+
+ typedef typename prec_traits<dst_type>::type data_t;
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<data_type::s8>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+ typedef typename prec_traits<data_type::s32>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ // XXX: this is throwaway code that will become unnecessary when we have a
+ // sufficiently advanced igemm jit generator that supports quantization,
+ // relu, and whatnot
+ class pp_kernel_t: jit_generator {
+ public:
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(
+ gemm_x8s8s32x_inner_product_fwd_t::pp_kernel);
+ pp_kernel_t(const pd_t *pd, bool dst_is_acc);
+
+ void operator()(dst_data_t *dst, const acc_data_t *acc,
+ const char *bias, const float *scales, float nslope,
+ size_t start, size_t end);
+ private:
+ void generate();
+
+ struct ker_args {
+ dst_data_t *dst;
+ const acc_data_t *acc;
+ const char *bias;
+ const float *scales;
+ float nslope;
+ size_t len;
+ size_t oc_offset;
+ };
+ void (*ker_)(const ker_args *args);
+
+ size_t OC_;
+ data_type_t bias_data_type_;
+ size_t bias_data_type_size_;
+ size_t scale_idx_mult_;
+ bool do_bias_;
+ bool do_relu_;
+ };
+
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ pp_kernel_t *pp_kernel_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp
new file mode 100644
index 0000000000..6fa251d465
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.cpp
@@ -0,0 +1,674 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+* Copyright 2018 YANDEX LLC
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_memory.hpp"
+
+#include "jit_avx2_1x1_conv_kernel_f32.hpp"
+
+#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::utils;
+
+using namespace Xbyak;
+
+void jit_avx2_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk)
+{
+ mov(aux1_reg_bcast_data, reg_bcast_data);
+ mov(aux_reg_output_data, reg_output_data);
+ mov(bcast_loop_iter, reg_bcast_loop_work);
+
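+ // Walk the broadcast (spatial) dimension in jcp.bcast_block chunks, each
+ // split into jcp.ur-sized substeps handled by generate_reduce_loop().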
+ Label bcast_loop, bcast_loop_tail;
+
+ cmp(bcast_loop_iter, jcp.ur);
+ jl(bcast_loop_tail, T_NEAR);
+
+ L(bcast_loop); {
+ assert(jcp.bcast_block % jcp.ur == 0);
+ int num_substeps = jcp.bcast_block / jcp.ur;
+ assert(num_substeps > 0 && num_substeps < 10);
+ for (int i = 0; i < num_substeps; i++) {
+ generate_reduce_loop(load_loop_blk, jcp.ur);
+ if (i < num_substeps - 1) {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
+ add(aux_reg_output_data, jcp.bcast_loop_output_substep);
+ } else {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
+ - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
+ add(aux_reg_output_data, jcp.bcast_loop_output_step
+ - (num_substeps - 1) * jcp.bcast_loop_output_substep);
+ }
+ }
+ sub(bcast_loop_iter, jcp.bcast_block);
+ cmp(bcast_loop_iter, jcp.bcast_block);
+ jge(bcast_loop, T_NEAR);
+ }
+
+ L(bcast_loop_tail);
+ if (jcp.ur_tail) {
+ Label bcast_loop_tail_out;
+ cmp(bcast_loop_iter, 0);
+ jz(bcast_loop_tail_out, T_NEAR);
+ generate_reduce_loop(load_loop_blk, jcp.ur_tail);
+ L(bcast_loop_tail_out);
+ }
+}
+
+void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
+ int load_loop_blk, int ur)
+{
+ auto vreg_load = [=](int i) {
+ return Ymm(ur * load_loop_blk + i);
+ };
+
+ auto vreg_accum = [=](int i, int j) {
+ return Ymm(j * load_loop_blk + i);
+ };
+
+ auto bias_ptr = [=](int i) {
+ return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i];
+ };
+
+ auto bcast_ptr = [=](int u, int j) {
+ assert(j < jcp.ur);
+ assert(u <= jcp.reduce_loop_unroll);
+ size_t offt;
+ if (one_of(jcp.prop_kind,
+ forward_training, forward_inference, backward_data))
+ {
+ assert(jcp.reduce_loop_unroll == ((jcp.prop_kind == backward_data)
+ ? jcp.oc_block : jcp.ic_block));
+ auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
+ offt = (u == jcp.reduce_loop_unroll)
+ ? (height + j) * jcp.reduce_loop_unroll
+ : j * jcp.reduce_loop_unroll + u;
+ } else
+ offt = u * jcp.ic_block + j;
+ return ptr[aux_reg_bcast_data + sizeof(float) * offt];
+ };
+
+ auto load_ptr = [=](int u, int i) {
+ size_t offt;
+ size_t u0 = u % jcp.reduce_loop_unroll;
+ size_t u1 = u / jcp.reduce_loop_unroll;
+ switch (jcp.prop_kind) {
+ case backward_data:
+ offt = (i * jcp.oc_block + u0) * jcp.ic_block;
+ break;
+ case backward_weights:
+ offt = (i * jcp.os + u0) * jcp.oc_block;
+ break;
+ default:
+ offt = (i * jcp.ic + u0) * jcp.oc_block;
+ }
+ return ptr[aux_reg_load_data
+ + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt];
+ };
+
+ auto output_ptr = [=](int i, int j) {
+ switch (jcp.prop_kind) {
+ case backward_data:
+ return ptr[aux_reg_output_data +
+ (i * jcp.is + j) * jcp.ic_block * sizeof(float)];
+ case backward_weights:
+ return ptr[aux_reg_output_data
+ + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
+ + sizeof(float) * jcp.oc_block * j];
+ default:
+ return ptr[aux_reg_output_data +
+ (i * jcp.os + j) * jcp.oc_block * sizeof(float)];
+ }
+ };
+
+ auto init = [=]() {
+ Label init_done, init_zero;
+
+ if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
+ forward_inference)) {
+ test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
+ jz(init_zero);
+
+ for (int i = 0; i < load_loop_blk; i++)
+ for (int j = 0; j < ur; ++j)
+ vmovups(vreg_accum(i, j), bias_ptr(i));
+ jmp(init_done);
+ }
+
+ L(init_zero);
+ for (int i = 0; i < load_loop_blk; ++i)
+ for (int j = 0; j < ur; ++j) {
+ auto r = vreg_accum(i, j);
+ vxorps(r, r, r);
+ }
+
+ L(init_done);
+ for (int i = 0; i < load_loop_blk; ++i)
+ vmovups(vreg_load(i), load_ptr(0, i));
+ vbroadcastss(vreg_bcast, bcast_ptr(0, 0));
+ };
+
+ auto store = [=]() {
+ Label store_noadd;
+
+ if (!jcp.with_sum) {
+ test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
+ jnz(store_noadd, T_NEAR);
+ }
+
+ for (int j = 0; j < ur; ++j)
+ for (int i = 0; i < load_loop_blk; ++i) {
+ auto r = vreg_accum(i, j);
+ vaddps(r, r, output_ptr(i, j));
+ }
+
+ L(store_noadd);
+
+ if (jcp.with_eltwise) {
+ assert(ur * load_loop_blk < 14);
+
+ Label store_norelu;
+ test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
+ jz(store_norelu, T_NEAR);
+
+ eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
+
+ L(store_norelu);
+ }
+
+ for (int j = 0; j < ur; ++j)
+ for (int i = 0; i < load_loop_blk; ++i) {
+ vmovups(output_ptr(i, j), vreg_accum(i, j));
+ }
+ };
+
+ auto fma_block = [=](bool last_block) {
+ for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
+ for (int j = 0; j < ur; ++j) {
+ for (int i = 0; i < load_loop_blk; ++i) {
+ if (mayiuse(avx2))
+ vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast);
+ else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
+ vmulps(vtmp, vreg_bcast, vreg_load(i));
+ vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp);
+ }
+ if (j == ur - 1 && !(last_block
+ && u == jcp.reduce_loop_unroll - 1))
+ vmovups(vreg_load(i), load_ptr(u + 1, i));
+ }
+ if (j < ur - 1)
+ vbroadcastss(vreg_bcast, bcast_ptr(u, j + 1));
+ }
+ if (!last_block || u < jcp.reduce_loop_unroll - 1)
+ vbroadcastss(vreg_bcast, bcast_ptr(u + 1, 0));
+ }
+ };
+
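+ // Accumulators are initialized from bias or zero, the reduce dimension is
+ // consumed in reduce_loop_unroll chunks (with a software-pipelined tail),
+ // and store() applies the optional sum/eltwise post-ops.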
+ Label reduce_loop, reduce_loop_tail;
+
+ mov(aux_reg_load_data, reg_load_data);
+ mov(aux_reg_bcast_data, aux1_reg_bcast_data);
+
+ init();
+
+ mov(reduce_loop_iter, reg_reduce_loop_work);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jle(reduce_loop_tail, T_NEAR);
+
+ L(reduce_loop); {
+ fma_block(false);
+ add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
+ add(aux_reg_load_data, jcp.reduce_loop_load_step);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jg(reduce_loop, T_NEAR);
+ }
+
+ L(reduce_loop_tail);
+ fma_block(true);
+
+ store();
+}
+
+void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
+{
+ if (!jcp.with_bias || jcp.prop_kind != backward_weights)
+ return;
+
+ Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
+ Label diff_bias_load;
+
+ auto diff_bias_ptr = [=](int i) {
+ return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)];
+ };
+
+ auto load_ptr = [=](int u, int i) {
+ return ptr[aux_reg_load_data
+ + (i * jcp.os + u) * jcp.oc_block * sizeof(float)];
+ };
+
+ auto diff_bias_reg = [=](int i) { return Ymm(i); };
+
+ mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
+ cmp(reg_diff_bias_data, 0);
+ je(diff_bias_loop_out, T_NEAR);
+
+ test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
+ jz(diff_bias_load, T_NEAR);
+
+ for (int i = 0; i < load_loop_blk; ++i) {
+ auto r = diff_bias_reg(i);
+ vxorps(r, r, r);
+ }
+ jmp(diff_bias_init_out, T_NEAR);
+
+ L(diff_bias_load);
+ for (int i = 0; i < load_loop_blk; ++i)
+ vmovups(diff_bias_reg(i), diff_bias_ptr(i));
+
+ L(diff_bias_init_out);
+ mov(aux_reg_load_data, reg_load_data);
+ mov(reduce_loop_iter, reg_reduce_loop_work);
+ L(diff_bias_loop); {
+ for(int u = 0; u < jcp.reduce_loop_unroll; ++u)
+ for (int i = 0; i < load_loop_blk; ++i)
+ vaddps(diff_bias_reg(i), diff_bias_reg(i), load_ptr(u, i));
+ assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
+ add(aux_reg_load_data, jcp.reduce_loop_load_step);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jnz(diff_bias_loop, T_NEAR);
+ }
+
+ for (int i = 0; i < load_loop_blk; i++)
+ vmovups(diff_bias_ptr(i), diff_bias_reg(i));
+ add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
+ mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
+
+ L(diff_bias_loop_out);
+}
+
+void jit_avx2_1x1_conv_kernel_f32::generate()
+{
+ preamble();
+
+ mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
+ mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
+ mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
+ if (jcp.with_bias) {
+ if (jcp.prop_kind == backward_weights) {
+ sub(rsp, stack_space_needed);
+ mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
+ mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
+ } else
+ mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
+ }
+
+ mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
+ mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
+ mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
+ mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
+ if (jcp.prop_kind == backward_weights)
+ mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
+
+ auto generate_load_loop_body = [=] (int load_loop_blk) {
+ generate_bcast_loop(load_loop_blk);
+ add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
+ switch (jcp.prop_kind) {
+ case forward_training:
+ case forward_inference:
+ add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
+ add(reg_output_data,
+ load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
+ break;
+ case backward_data:
+ add(reg_output_data,
+ load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
+ break;
+ case backward_weights:
+ for (int i = 0; i < load_loop_blk; i++)
+ add(reg_output_data, reg_output_stride);
+ break;
+ default:
+ assert(!"invalid prop_kind");
+ }
+ sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+ };
+
+ Label load_loop_blk_8;
+ Label load_loop_blk_16;
+ Label load_loop_blk_24;
+ Label load_loop_blk_end;
+
+ cmp(reg_load_loop_work, 8);
+ jle(load_loop_blk_8, T_NEAR);
+
+ cmp(reg_load_loop_work, 32);
+ je(load_loop_blk_16, T_NEAR);
+
+ cmp(reg_load_loop_work, 16);
+ jle(load_loop_blk_16, T_NEAR);
+
+ L(load_loop_blk_24); {
+ generate_diff_bias_loop(3);
+ generate_load_loop_body(3);
+ cmp(reg_load_loop_work, 32);
+ je(load_loop_blk_16);
+ cmp(reg_load_loop_work, 24);
+ jge(load_loop_blk_24);
+ }
+
+ cmp(reg_load_loop_work, 8);
+ jle(load_loop_blk_8, T_NEAR);
+
+ L(load_loop_blk_16); {
+ generate_diff_bias_loop(2);
+ generate_load_loop_body(2);
+ cmp(reg_load_loop_work, 16);
+ jge(load_loop_blk_16);
+ }
+
+ L(load_loop_blk_8); {
+ cmp(reg_load_loop_work, 0);
+ je(load_loop_blk_end, T_NEAR);
+ generate_diff_bias_loop(1);
+ generate_load_loop_body(1);
+ }
+
+ L(load_loop_blk_end);
+
+ if (jcp.with_bias && jcp.prop_kind == backward_weights)
+ add(rsp, 8);
+
+ postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+}
+
+bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok(
+ jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+ switch (p.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
+ case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
+ default: return false;
+ }
+
+ return false;
+}
+
+status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr)
+{
+ if (!mayiuse(avx)) return status::unimplemented;
+
+ // TODO (Roma): this code is duplicated from the generic kernel; maybe the
+ // configuration struct could handle some of the initialization below
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ const int ndims = src_d.ndims();
+
+ jcp.prop_kind = cd.prop_kind;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+
+ jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
+ jcp.iw = src_d.dims()[ndims - 1];
+ jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[ndims - 1];
+
+ jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
+ jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][ndims - 3];
+
+ jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
+ jcp.stride_w = cd.strides[ndims - 3];
+
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ jcp.os = jcp.oh * jcp.ow;
+ jcp.is = jcp.ih * jcp.iw;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ jcp.with_sum = p.find(primitive_kind::sum) != -1;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise) {
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+ if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu)
+ return status::unimplemented;
+ }
+
+ const int is_bwd_d = jcp.prop_kind == backward_data;
+
+ format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c;
+ format_tag_t wei_tag = with_groups
+ ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o,
+ gOIhw8o8i)
+ : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
+ OIhw8o8i);
+
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+ jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+
+ const int simd_w = 8;
+
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+
+ bool args_ok = true
+ && jcp.ngroups == 1
+ && jcp.src_tag == dat_tag
+ && jcp.wei_tag == wei_tag
+ && jcp.dst_tag == dat_tag;
+ if (!args_ok) return status::unimplemented;
+
+ args_ok = true
+ && jcp.ih == jcp.oh && jcp.iw == jcp.ow
+ && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
+ && jcp.t_pad == 0 && jcp.l_pad == 0
+ && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
+ && jcp.kh == 1 && jcp.kw == 1;
+ if (!args_ok) return status::unimplemented;
+
+ // TODO: remove this restriction
+ // optimized 1x1 bwd_w does not support Intel AVX
+ if (jcp.prop_kind == backward_weights && !mayiuse(avx2))
+ return status::unimplemented;
+
+ jcp.ic_block = jcp.oc_block = simd_w;
+
+ jcp.ur = mayiuse(avx2) ? 4 : 3; // Intel AVX support
+
+ int load_blocking{ 0 };
+ int load_blocking_max{ 0 };
+ int bcast_blocking{ 0 };
+ int bcast_blocking_max{ 0 };
+ int reduce_blocking{ 0 };
+
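+    /* The 1x1 convolution is organized as a blocked matrix multiply: the
+     * "reduce" dimension is accumulated over, a block of the "load"
+     * dimension stays resident in registers, and the "bcast" dimension is
+     * broadcast across the unrolled outputs. Which of ic / oc / spatial
+     * plays which role depends on the propagation kind (see the branches
+     * below). */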
+ if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
+ jcp.reduce_dim = jcp.ic;
+ jcp.reduce_block = jcp.ic_block;
+
+ jcp.load_dim = jcp.oc;
+ jcp.load_block = jcp.oc_block;
+
+ jcp.bcast_dim = jcp.is;
+ jcp.bcast_block = jcp.ur;
+
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
+ jcp.reduce_loop_load_step
+ = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
+
+ jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
+ jcp.bcast_loop_output_substep = -1; // unused
+ jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
+ jcp.bcast_loop_bcast_substep = -1; // unused
+
+ jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
+ jcp.load_loop_iter_step = jcp.oc_block;
+
+ load_blocking = 120; // assumes the kernel is jcp.ur x 3
+ load_blocking_max = 144;
+ bcast_blocking = 128; // affects load balancing across threads
+ bcast_blocking_max = 192;
+ reduce_blocking = 128; // affects L1$ utilization
+ } else if (jcp.prop_kind == backward_data) {
+ jcp.reduce_dim = jcp.oc;
+ jcp.reduce_block = jcp.oc_block;
+
+ jcp.load_dim = jcp.ic;
+ jcp.load_block = jcp.oc_block;
+
+ jcp.bcast_dim = jcp.os;
+ jcp.bcast_block = jcp.ur;
+
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
+ jcp.reduce_loop_load_step
+ = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);
+
+ jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
+ jcp.bcast_loop_output_substep = -1; // unused
+ jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
+ jcp.bcast_loop_bcast_substep = -1; // unused
+
+ jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
+ jcp.load_loop_iter_step = jcp.ic_block;
+
+ load_blocking = 96; // assumes the kernel is jcp.ur x 3
+ load_blocking_max = 144;
+ bcast_blocking = 128; // affects load balancing across threads
+ bcast_blocking_max = 196;
+ reduce_blocking = 64; // affects L1$ utilization
+ } else if (jcp.prop_kind == backward_weights) {
+ jcp.reduce_dim = jcp.os;
+ jcp.reduce_block = 1;
+
+ jcp.load_dim = jcp.oc;
+ jcp.load_block = jcp.oc_block;
+
+ jcp.bcast_dim = jcp.ic;
+ jcp.bcast_block = jcp.ic_block;
+
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
+ jcp.reduce_loop_load_step
+ = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
+
+ jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
+ jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
+ jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
+ jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);
+
+ jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
+ jcp.load_loop_iter_step = jcp.oc_block;
+
+ /* --- */
+
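+        /* choose blocking factors that divide the load/bcast dimensions
+         * exactly: keep dividing the number of blocks by 2 or 3 until it is
+         * small enough (the asserts below rely on this exact division) */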
+ load_blocking = div_up(jcp.load_dim, jcp.load_block);
+ while (true) {
+ if (load_blocking <= 32) break;
+ else if (load_blocking % 2 == 0) load_blocking /= 2;
+ else if (load_blocking % 3 == 0) load_blocking /= 3;
+ else break;
+ }
+ load_blocking *= jcp.load_block;
+ load_blocking_max = load_blocking;
+ assert(jcp.load_dim % load_blocking == 0);
+
+ bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
+ while (true) {
+ if (bcast_blocking <= 9) break;
+ else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
+ else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;
+ else break;
+ }
+ bcast_blocking *= jcp.bcast_block;
+ bcast_blocking_max = bcast_blocking;
+ assert(jcp.bcast_dim % bcast_blocking == 0);
+
+ reduce_blocking = 128; // affects L1$ utilization
+ } else
+ return status::unimplemented;
+
+ assert(load_blocking);
+ assert(load_blocking_max);
+ assert(bcast_blocking);
+ assert(bcast_blocking_max);
+ assert(reduce_blocking);
+
+ assert(jcp.bcast_block % jcp.ur == 0);
+ jcp.ur_tail = jcp.bcast_dim % jcp.ur;
+
+ jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
+ jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
+ jcp.nb_load_blocking = load_blocking / jcp.load_block;
+ jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
+ jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
+
+ jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+ jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
+ jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+
+ return status::success;
+}
+
+void jit_avx2_1x1_conv_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp) {
+ using namespace mkldnn::impl::memory_tracking::names;
+
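+    /* a zero-padded copy of the bias is only needed when the number of
+     * output channels was rounded up to the SIMD width in init_conf() */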
+ if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp
new file mode 100644
index 0000000000..bfdeb2b18d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_conv_kernel_f32.hpp
@@ -0,0 +1,110 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX2_1x1_CONV_KERNEL_F32_HPP
+#define JIT_AVX2_1x1_CONV_KERNEL_F32_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "cpu_memory.hpp"
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_avx2_1x1_conv_kernel_f32: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_1x1_conv_kernel_f32)
+
+ jit_avx2_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp,
+ const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
+ {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx2>(this,
+ jcp.eltwise);
+
+ this->generate();
+ jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode();
+ }
+
+ ~jit_avx2_1x1_conv_kernel_f32() {
+ delete eltwise_injector_;
+ }
+
+ static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_1x1_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr);
+
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp);
+
+ jit_1x1_conv_conf_t jcp;
+ const primitive_attr_t &attr_;
+ void (*jit_ker)(jit_1x1_conv_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ using ymm_t = const Xbyak::Ymm;
+
+ reg64_t reg_bcast_data = rax;
+ reg64_t reg_load_data = rsi;
+ reg64_t reg_output_data = rbx;
+ reg64_t aux_reg_bcast_data = rdx;
+ reg64_t aux1_reg_bcast_data = abi_not_param1;
+ reg64_t aux_reg_load_data = abi_param1;
+ reg64_t aux_reg_output_data = rbp;
+ reg64_t reg_load_loop_work = r9;
+ reg64_t reg_bcast_loop_work = r10;
+ reg64_t reg_reduce_loop_work = r11;
+ reg64_t load_loop_iter = r13;
+ reg64_t bcast_loop_iter = r14;
+ reg64_t reduce_loop_iter = r15;
+ reg64_t imm_addr64 = reduce_loop_iter;
+ reg64_t reg_reduce_pos_flag = r8;
+ reg64_t reg_output_stride = r12;
+ reg64_t reg_bias_data = r12;
+ reg64_t reg_diff_bias_data = bcast_loop_iter;
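+    /* note: imm_addr64 aliases reduce_loop_iter, reg_bias_data shares r12
+     * with reg_output_stride, and reg_diff_bias_data aliases bcast_loop_iter */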
+
+ int reg_diff_bias_data_stack_offt = 0;
+ int stack_space_needed = 8;
+
+ ymm_t vreg_bcast = ymm_t(15);
+ ymm_t vtmp = ymm_t(14);
+
+ jit_uni_eltwise_injector_f32<avx2> *eltwise_injector_;
+
+ void generate_bcast_loop(int load_loop_blk);
+ void generate_reduce_loop(int load_loop_blk, int ur);
+ void generate_diff_bias_loop(int load_loop_blk);
+
+ void generate();
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp
new file mode 100644
index 0000000000..f116ac9056
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.cpp
@@ -0,0 +1,545 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_generator.hpp"
+
+#include "jit_avx2_1x1_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+#define data_blk_off(f, n, c, h, w) \
+ ((ndims == 3) \
+ ? (f).blk_off(n, c, w) \
+ : (f).blk_off(n, c, h, w))
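+/* pick the rank-3 (no height) or rank-4 blocked offset depending on ndims */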
+
+/* convolution forward */
+
+void jit_avx2_1x1_convolution_fwd_t::execute_forward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+ auto rtus_space = scratchpad(ctx).get<data_t>(key_conv_rtus_space);
+
+ const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
+ const int ndims = dst_d.ndims();
+
+ const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[ndims - 3];
+ const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][ndims - 3];
+
+ auto step = [](int default_step, int remaining, int tail_step) {
+ assert(default_step <= tail_step);
+ return remaining < tail_step ? remaining : default_step;
+ };
+
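+    /* each thread walks a contiguous range of (mb, group, spatial-block)
+     * work items; for every item it loops over output-channel blocks and
+     * accumulates over input-channel blocks. When the stride is folded away
+     * (rtus_.reduce_src_), the strided source is first gathered into a dense
+     * per-thread buffer that the kernel then reads. */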
+ auto ker = [&](const int ithr, const int nthr) {
+ // TODO (Roma): remove this restriction
+ assert(jcp.stride_w == 1 && jcp.stride_h == 1);
+
+ auto p = jit_1x1_conv_call_s();
+ auto rp = rtus_driver_t<avx2>::call_params_t();
+
+ const int nb_oc = jcp.nb_load;
+ const int nb_ic = jcp.nb_reduce;
+ const int nb_ic_blocking = jcp.nb_reduce_blocking;
+ const int os_block = jcp.bcast_block;
+
+ int start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ int iwork = start;
+ while (iwork < end) {
+ int n{0}, g{0}, osb{0};
+ nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
+ jcp.nb_bcast);
+
+ int bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
+ jcp.nb_bcast_blocking_max);
+ bcast_step = nstl::min(bcast_step, end - iwork);
+
+ const int os = osb * os_block;
+ const int oh = os / jcp.ow;
+ const int ow = os % jcp.ow;
+
+ const int ih = nstl::max(oh * stride_h - pad_t, 0);
+ const int iw = nstl::max(ow * stride_w - pad_l, 0);
+ rp.iw_start = iw;
+
+ p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
+ rp.os = p.bcast_dim;
+
+ int ocb = 0;
+ while (ocb < jcp.nb_load) {
+ const int load_step = step(jcp.nb_load_blocking,
+ jcp.nb_load - ocb, jcp.nb_load_blocking_max);
+
+ const int _ocb = g * nb_oc + ocb;
+ p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
+ load_step * jcp.oc_block);
+ const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
+
+ p.output_data = &dst[dst_off];
+
+ p.bias_data = &bias[_ocb * jcp.oc_block];
+
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ p.first_last_flag = 0
+ | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
+ | (icb + nb_ic_blocking >= nb_ic
+ ? FLAG_REDUCE_LAST : 0);
+
+ p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
+ nb_ic_blocking * jcp.ic_block);
+ rp.icb = p.reduce_dim / jcp.reduce_block;
+
+ p.load_data = &weights[pd()->with_groups()
+ ? weights_d.blk_off(g, ocb, icb)
+ : weights_d.blk_off(ocb, icb)];
+
+ const int _icb = g * nb_ic + icb;
+ if (pd()->rtus_.reduce_src_) {
+ rp.ws = rtus_space
+ + ithr * pd()->rtus_.space_per_thread_
+ + _icb * jcp.is * jcp.ic_block;
+
+ if (ocb == 0) {
+ rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
+ rtus_driver_->ker_(&rp);
+ }
+
+ p.bcast_data = rp.ws;
+ } else
+ p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw);
+
+ kernel_->jit_ker(&p);
+ }
+
+ ocb += load_step;
+ }
+
+ iwork += bcast_step;
+ }
+ };
+
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad(ctx).get<data_t>(key_conv_padded_bias);
+ utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bias = padded_bias;
+ }
+
+ parallel(0, ker);
+
+ if (pd()->wants_zero_pad_dst())
+ ctx.memory(MKLDNN_ARG_DST)->zero_pad();
+}
+
+/* convolution backward wrt data */
+
+void jit_avx2_1x1_convolution_bwd_data_t::execute_backward_data(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+
+ const auto &jcp = kernel_->jcp;
+ auto rtus_space = scratchpad(ctx).get<data_t>(key_conv_rtus_space);
+
+ // TODO (Roma): remove this restriction
+ assert(jcp.stride_w == 1 && jcp.stride_h == 1);
+ const int ndims = diff_dst_d.ndims();
+
+ const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[ndims - 3];
+ const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][ndims - 3];
+
+ const int nb_ic = jcp.nb_load;
+ const int nb_oc = jcp.nb_reduce;
+ const int os_block = jcp.bcast_block;
+ const int nb_oc_blocking = jcp.nb_reduce_blocking;
+
+ const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
+
+ auto step = [](int default_step, int remaining, int tail_step) {
+ assert(default_step <= tail_step);
+ return remaining < tail_step ? remaining : default_step;
+ };
+
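+    /* backward by data keeps one input-channel (load) block resident and
+     * streams all (mb, group, spatial-block) work for it; when rtus is in
+     * use, the dense per-thread result is scattered back to the strided
+     * diff_src once all output-channel blocks have been accumulated */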
+ auto ker = [&](const int ithr, const int nthr) {
+ auto p = jit_1x1_conv_call_s();
+ auto rp = rtus_driver_t<avx2>::call_params_t();
+
+ int start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ int load_step = 0;
+ for (int icb = 0; icb < jcp.nb_load; icb += load_step) {
+ load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
+ jcp.nb_load_blocking_max);
+
+ p.load_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
+ load_step * jcp.ic_block);
+ rp.icb = p.load_dim / jcp.ic_block;
+
+ int bcast_step;
+ for (int iwork = start; iwork < end; iwork += bcast_step) {
+ int n{0}, g{0}, osb{0};
+ nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
+ jcp.nb_bcast);
+
+ bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
+ jcp.nb_bcast_blocking_max);
+ bcast_step = nstl::min(bcast_step, end - iwork);
+
+ const int os = osb * os_block;
+ p.bcast_dim = this_block_size(os, jcp.os,
+ bcast_step * os_block);
+ rp.os = p.bcast_dim;
+
+ const int oh = os / jcp.ow;
+ const int ow = os % jcp.ow;
+ const int ih = nstl::max(oh * stride_h - pad_t, 0);
+ const int iw = nstl::max(ow * stride_w - pad_l, 0);
+ rp.iw_start = iw;
+
+ const int _icb = g * nb_ic + icb;
+ rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
+ if (pd()->rtus_.reduce_src_) {
+ rp.ws = rtus_space
+ + ithr * pd()->rtus_.space_per_thread_;
+ p.output_data = rp.ws;
+ } else
+ p.output_data = rp.src;
+
+ for (int ocb = 0; ocb < jcp.nb_reduce;
+ ocb += jcp.nb_reduce_blocking) {
+ const int _ocb = g * nb_oc + ocb;
+ size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh,
+ ow);
+ p.bcast_data = &diff_dst[diff_dst_off];
+
+ p.load_data = &weights[pd()->with_groups()
+ ? weights_d.blk_off(g, ocb, icb)
+ : weights_d.blk_off(ocb, icb)];
+
+ p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
+
+ p.reduce_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
+ nb_oc_blocking * jcp.oc_block);
+
+ kernel_->jit_ker(&p);
+ }
+
+ if (pd()->rtus_.reduce_src_)
+ rtus_driver_->ker_(&rp);
+ }
+ }
+ };
+
+ parallel(0, ker);
+}
+
+/* convolution backward wrt weights */
+
+jit_avx2_1x1_convolution_bwd_weights_t::jit_avx2_1x1_convolution_bwd_weights_t(
+ const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , kernel_(nullptr)
+ , rtus_driver_(nullptr)
+{
+ kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr());
+ reducer_weights_ =
+ new cpu_reducer_2d_t<data_type::f32>(pd()->reducer_wei_conf_);
+ reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
+ init_rtus_driver<avx2>(this);
+}
+
+void jit_avx2_1x1_convolution_bwd_weights_t::execute_backward_weights(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
+ const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
+
+ const auto &jcp = kernel_->jcp;
+ auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
+
+ data_t *diff_bias = pd()->wants_padded_bias()
+ ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
+
+ auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+ prefix_reducer_bia);
+ auto rb = this->reducer_bias_;
+ rb->init(reducer_bia_scratchpad);
+
+ auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
+ prefix_reducer_wei);
+ auto rw = this->reducer_weights_;
+ rw->init(reducer_wei_scratchpad);
+
+ const int ndims = diff_dst_d.ndims();
+ // TODO (Roma): remove this restriction
+ assert(jcp.stride_w == 1 && jcp.stride_h == 1);
+
+ const int nb_ic = jcp.nb_bcast;
+ const int nb_ic_blocking = jcp.nb_bcast_blocking;
+ const int bcast_work = div_up(nb_ic, nb_ic_blocking);
+
+ const int nb_oc = jcp.nb_load;
+ const int nb_oc_blocking = jcp.nb_load_blocking;
+ const int load_work = div_up(nb_oc, nb_oc_blocking);
+
+ const int sp_dim = jcp.reduce_dim;
+ const int mb_sp_work = jcp.mb * sp_dim;
+
+ const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[ndims - 3];
+ const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][ndims - 3];
+
+ auto step = [](int default_step, int remaining, int tail_step) {
+ assert(default_step <= tail_step);
+ return remaining < tail_step ? remaining : default_step;
+ };
+
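+    /* accumulates one (oc blocks x ic blocks) tile of the weight gradient
+     * over the spatial range [sp_start, sp_end); FLAG_REDUCE_FIRST is set
+     * only for the first chunk of the first image, so later chunks
+     * accumulate on top of the partial result in store_to */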
+ auto oc_ic_sp_loop = [=](int sp_start, int sp_end, bool first_image,
+ data_t *store_to, size_t store_to_ld, const data_t *diff_dst,
+ const data_t *src, int ithr) {
+ auto p = jit_1x1_conv_call_s();
+ auto rp = rtus_driver_t<avx2>::call_params_t();
+
+ p.output_stride = store_to_ld * sizeof(float);
+ const int sp_step_def = jcp.nb_reduce_blocking * jcp.reduce_block;
+
+ int oc_b_step = 0;
+ for (int oc_b = 0; oc_b < nb_oc_blocking; oc_b += oc_b_step) {
+ oc_b_step = step(12, nb_oc_blocking - oc_b, 18);
+ p.load_dim = oc_b_step * jcp.oc_block;
+
+ int ic_b_step = 0;
+ for (int ic_b = 0; ic_b < nb_ic_blocking; ic_b += ic_b_step) {
+ ic_b_step = step(12, nb_ic_blocking - ic_b, 18);
+ p.bcast_dim = ic_b_step * jcp.ic_block;
+ rp.icb = p.bcast_dim / jcp.ic_block;
+
+ p.output_data = store_to + oc_b * store_to_ld
+ + ic_b * jcp.ic_block * jcp.oc_block;
+
+ /* spatial reduction */
+ int sp_step = 0;
+ for (int sp = sp_start; sp < sp_end; sp += sp_step) {
+ sp_step = step(sp_step_def, sp_end - sp, 192);
+ p.reduce_dim = sp_step;
+ rp.os = p.reduce_dim;
+
+ p.first_last_flag = sp == sp_start && first_image
+ ? FLAG_REDUCE_FIRST : 0;
+
+ p.load_data = diff_dst
+ + (oc_b * jcp.reduce_dim + sp) * jcp.oc_block;
+
+ if (pd()->rtus_.reduce_src_) {
+ const int oh = sp / jcp.ow;
+ const int ow = sp % jcp.ow;
+
+ const int ih = nstl::max(oh * stride_h - pad_t, 0);
+ const int iw = nstl::max(ow * stride_w - pad_l, 0);
+ rp.iw_start = iw;
+
+ rp.ws = rtus_space
+ + ithr * pd()->rtus_.space_per_thread_
+ + (ic_b * jcp.is + sp) * jcp.ic_block;
+ if (ndims == 3)
+ rp.src = src
+ + iw * src_d.blocking_desc().strides[2];
+ else
+ rp.src = src
+ + ih * src_d.blocking_desc().strides[2]
+ + iw * src_d.blocking_desc().strides[3];
+
+ if (oc_b == 0)
+ rtus_driver_->ker_(&rp);
+
+ p.bcast_data = rp.ws;
+ } else
+ p.bcast_data = src
+ + (ic_b * jcp.reduce_dim + sp) * jcp.ic_block;
+
+ kernel_->jit_ker(&p);
+ }
+ }
+ }
+ };
+
+ auto ker = [&](const int ithr, const int nthr) {
+ assert(nthr == rw->balancer().nthr_);
+
+ const int w_njobs = rw->balancer().ithr_njobs(ithr);
+ if (w_njobs == 0) return;
+
+ /* setup: independent work (oc, ic) */
+ const int w_job_start = rw->balancer().ithr_job_off(ithr);
+ int g{0}, load_i{0}, bcast_i{0};
+ nd_iterator_init(w_job_start, g, jcp.ngroups, load_i, load_work,
+ bcast_i, bcast_work);
+
+ /* setup: reduction work (mb, sp) */
+ int mb_sp_start{0}, mb_sp_end{0};
+ balance211(mb_sp_work, rw->balancer().nthr_per_group_,
+ rw->balancer().id_in_group(ithr), mb_sp_start, mb_sp_end);
+ int img_start{0}, sp_start{0};
+ nd_iterator_init(mb_sp_start, img_start, jcp.mb, sp_start, sp_dim);
+
+ /* independent work */
+ for (int iwork = 0; iwork < w_njobs; ++iwork) {
+ const int oc_b = nb_oc_blocking * load_i;
+ const int ic_b = nb_ic_blocking * bcast_i;
+
+ const int _ic_b = g * nb_ic + ic_b;
+ const int _oc_b = g * nb_oc + oc_b;
+
+ data_t *store_to;
+ size_t store_to_ld;
+
+ if (rw->balancer().nthr_per_group_ == 1) {
+ const size_t off = pd()->with_groups()
+ ? diff_weights_d.blk_off(g, oc_b, ic_b)
+ : diff_weights_d.blk_off(oc_b, ic_b);
+ store_to = &diff_weights[off];
+ store_to_ld = jcp.ic * jcp.oc_block;
+ } else {
+ const size_t off = iwork * rw->balancer().job_size_;
+ store_to =
+ rw->get_local_ptr(ithr, reducer_wei_scratchpad) + off;
+ store_to_ld = nb_ic_blocking * jcp.ic_block * jcp.oc_block;
+ }
+
+ /* reduction work */
+ int img = img_start;
+ int sp = sp_start;
+ int sp_step = 0;
+ for (int mb_sp = mb_sp_start; mb_sp < mb_sp_end; mb_sp += sp_step)
+ {
+ sp_step = nstl::min(sp_dim - sp, mb_sp_end - mb_sp);
+
+ const bool first_image = img == img_start;
+ oc_ic_sp_loop(sp, sp + sp_step, first_image, store_to,
+ store_to_ld, &diff_dst[diff_dst_d.blk_off(img, _oc_b)],
+ &src[src_d.blk_off(img, _ic_b)], ithr);
+
+ sp = 0;
+ img += 1;
+ }
+
+ nd_iterator_step(g, jcp.ngroups, load_i, load_work, bcast_i,
+ bcast_work);
+ }
+ rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
+ };
+
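+    /* the bias gradient is a per-output-channel sum of diff_dst over the
+     * spatial dimensions; the literal 8 below is the SIMD width, i.e. one
+     * output-channel block */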
+ auto ker_bias = [&](int ithr, int nthr) {
+ assert(nthr == rb->balancer().nthr_);
+
+ const int b_job_start = rb->balancer().ithr_job_off(ithr);
+ const int b_njobs = rb->balancer().ithr_njobs(ithr);
+
+ if (b_njobs == 0) return;
+
+ /* reduction dimension */
+ int img_start{0}, img_end{0};
+ balance211(jcp.mb, rb->balancer().nthr_per_group_,
+ rb->balancer().id_in_group(ithr), img_start, img_end);
+
+ /* jobs */
+ int g_start{0}, ocb_start{0};
+ nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, nb_oc);
+
+ for (int img = img_start; img < img_end; ++img) {
+ int g = g_start, ocb = ocb_start;
+ for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
+ const size_t _oc = g * nb_oc + ocb;
+
+ const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
+ data_t *d_bias =
+ rb->get_local_ptr(ithr, diff_bias, reducer_bia_scratchpad)
+ + b_job_loc * rb->balancer().job_size_;
+
+ if (img == img_start)
+ for (int o = 0; o < 8; ++o) d_bias[o] = 0.;
+
+ for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) {
+ PRAGMA_OMP_SIMD()
+ for (int o = 0; o < 8; ++o)
+ d_bias[o] += d_dst[o];
+ d_dst += 8;
+ }
+
+ nd_iterator_step(g, jcp.ngroups, ocb, nb_oc);
+ }
+ }
+ rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
+ };
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ ker(ithr, nthr);
+ if (pd()->with_bias())
+ ker_bias(ithr, nthr);
+ });
+
+ /* TODO: put this in ker_bias */
+ if (pd()->wants_padded_bias()) {
+ assert(jcp.ngroups == 1);
+ for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
+ diff_bias_in[oc] = diff_bias[oc];
+ }
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp
new file mode 100644
index 0000000000..9762242173
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_1x1_convolution.hpp
@@ -0,0 +1,344 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX2_1x1_CONVOLUTION_HPP
+#define CPU_JIT_AVX2_1x1_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+#include "cpu_reducer.hpp"
+
+#include "jit_avx2_1x1_conv_kernel_f32.hpp"
+#include "jit_uni_1x1_conv_utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_avx2_1x1_convolution_fwd_t: public cpu_primitive_t {
+    // TODO: (Roma) Code duplication! Remove with templates (maybe...)!
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_(), rtus_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
+ jit_avx2_1x1_convolution_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ const convolution_desc_t *conv_d = desc();
+ const memory_desc_t *src_d = src_md();
+ rtus_prepare(this, conv_d, src_d, dst_md());
+
+ status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
+ *conv_d, *src_d, *weights_md(), *dst_md(), *attr());
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
+
+ rtus_prepare_space_info(this, scratchpad);
+
+ return status::success;
+ }
+
+ jit_1x1_conv_conf_t jcp_;
+ reduce_to_unit_stride_t rtus_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto wei_tag = with_groups()
+ ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o)
+ : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o);
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ template <cpu_isa_t isa, typename conv_t>
+ friend void init_rtus_driver(conv_t *self);
+
+ jit_avx2_1x1_convolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , kernel_(nullptr), rtus_driver_(nullptr)
+ {
+ kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr());
+ init_rtus_driver<avx2>(this);
+ }
+
+ ~jit_avx2_1x1_convolution_fwd_t() {
+ delete kernel_;
+ delete rtus_driver_;
+ }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx2_1x1_conv_kernel_f32 *kernel_;
+ rtus_driver_t<avx2> *rtus_driver_;
+};
+
+struct jit_avx2_1x1_convolution_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_data_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_(), rtus_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
+ jit_avx2_1x1_convolution_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::undef, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ const convolution_desc_t *conv_d = desc();
+ const memory_desc_t *diff_src_d = diff_src_md();
+ rtus_prepare(this, conv_d, diff_src_d, diff_dst_md());
+
+ status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
+ *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
+ *attr());
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
+
+ rtus_prepare_space_info(this, scratchpad);
+
+ return status::success;
+ }
+
+ jit_1x1_conv_conf_t jcp_;
+ reduce_to_unit_stride_t rtus_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto wei_tag = with_groups()
+ ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i)
+ : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i);
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ template <cpu_isa_t isa, typename conv_t>
+ friend void init_rtus_driver(conv_t *self);
+
+ jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , kernel_(nullptr)
+ , rtus_driver_(nullptr)
+ {
+ kernel_ = new jit_avx2_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr());
+ init_rtus_driver<avx2>(this);
+ }
+
+ ~jit_avx2_1x1_convolution_bwd_data_t() {
+ delete kernel_;
+ delete rtus_driver_;
+ }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx2_1x1_conv_kernel_f32 *kernel_;
+ rtus_driver_t<avx2> *rtus_driver_;
+};
+
+struct jit_avx2_1x1_convolution_bwd_weights_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_weights_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_(), rtus_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
+ jit_avx2_1x1_convolution_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_weights
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ const convolution_desc_t *conv_d = desc();
+ const memory_desc_t *src_d = src_md();
+ rtus_prepare(this, conv_d, src_d, diff_dst_md());
+
+ status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
+ *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(),
+ *attr());
+ if (status != status::success) return status;
+
+ init_balancers();
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
+
+ rtus_prepare_space_info(this, scratchpad);
+
+ auto reducer_bia_scratchpad = memory_tracking::registrar_t(
+ scratchpad, memory_tracking::names::prefix_reducer_bia);
+ reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
+
+ auto reducer_wei_scratchpad = memory_tracking::registrar_t(
+ scratchpad, memory_tracking::names::prefix_reducer_wei);
+ reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);
+
+ return status::success;
+ }
+
+ jit_1x1_conv_conf_t jcp_;
+ cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
+ cpu_reducer_2d_t<data_type::f32>::conf_t reducer_wei_conf_;
+ reduce_to_unit_stride_t rtus_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto wei_tag = with_groups()
+ ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o)
+ : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o);
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+
+ private:
+ void init_balancers() {
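+            /* one reduction "job" is an (nb_oc_blocking x nb_ic_blocking)
+             * tile of the weight gradient; the balancers decide how many
+             * threads cooperate on the minibatch / spatial reduction of each
+             * job, within the given scratch-buffer budget */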
+ const int ic_block = jcp_.bcast_block;
+ const int nb_ic = jcp_.nb_bcast;
+ const int nb_ic_blocking = jcp_.nb_bcast_blocking;
+ const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking);
+
+ const int oc_block = jcp_.load_block;
+ const int nb_oc = jcp_.nb_load;
+ const int nb_oc_blocking = jcp_.nb_load_blocking;
+ const int load_work = utils::div_up(nb_oc, nb_oc_blocking);
+
+ const int job_size
+ = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block;
+ const int njobs_x = bcast_work;
+ const int njobs_y = jcp_.ngroups * load_work;
+
+ const int max_threads = mkldnn_get_max_threads();
+ const size_t max_buffer_size = max_threads * job_size * 8;
+
+ if (with_bias()) {
+ reducer_bia_conf_.init(reduce_balancer_t(max_threads,
+ oc_block, jcp_.ngroups * jcp_.oc / oc_block,
+ jcp_.mb, max_buffer_size));
+ }
+
+ reducer_wei_conf_.init(
+ reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x,
+ jcp_.mb * jcp_.nb_reduce, max_buffer_size),
+ job_size / nb_oc_blocking, nb_oc_blocking, ic_block,
+ nb_ic * ic_block * oc_block, nb_oc);
+ }
+ };
+
+ template <cpu_isa_t isa, typename conv_t>
+ friend void init_rtus_driver(conv_t *self);
+
+ jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd);
+
+ ~jit_avx2_1x1_convolution_bwd_weights_t() {
+ delete kernel_;
+ delete rtus_driver_;
+ delete reducer_weights_;
+ delete reducer_bias_;
+ }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_weights(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_weights(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx2_1x1_conv_kernel_f32 *kernel_;
+ cpu_reducer_2d_t<data_type::f32> *reducer_weights_;
+ cpu_reducer_t<data_type::f32> *reducer_bias_;
+ rtus_driver_t<avx2> *rtus_driver_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp
new file mode 100644
index 0000000000..e24770a2da
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp
@@ -0,0 +1,1501 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+* Copyright 2018 YANDEX LLC
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "cpu_memory.hpp"
+
+#include "jit_avx2_conv_kernel_f32.hpp"
+
+#define GET_OFF(field) offsetof(jit_conv_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+using namespace Xbyak;
+
+void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
+ int pad_l, int pad_r, int oc_blocks)
+{
+ int iw = jcp.iw;
+ int ih = jcp.ih;
+ int id = jcp.id;
+ int kw = jcp.kw;
+ int kh = jcp.kh;
+ int kd = jcp.kd;
+ int nb_ic = jcp.nb_ic;
+ int stride_w = jcp.stride_w;
+ int dilate_w = jcp.dilate_w + 1;
+ int ic_blk = jcp.ic_block;
+ int oc_blk = jcp.oc_block;
+
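+    /* register layout: Ymm(ur_w*ii + jj) accumulates output block ii at
+     * output position jj, Ymm(oc_blocks*ur_w + jj) holds the broadcast
+     * input value, and ymm15 holds the current weight vector */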
+ for (int ki = 0; ki < kw; ki++) {
+ int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
+ int jj_end = ur_w
+ - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
+ for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ size_t inp_off;
+ if (one_of(jcp.src_tag, ncw, nchw, ncdhw))
+ inp_off = sizeof(float)*((size_t)ifm2*id*ih*iw
+ + (ki*dilate_w + jj*stride_w - pad_l));
+ else
+ inp_off = sizeof(float)*((ki*dilate_w + jj*stride_w
+ - pad_l)*ic_blk + ifm2);
+ vbroadcastss(Ymm(oc_blocks * ur_w + jj),
+ make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
+ }
+
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk
+ + ki * ic_blk * oc_blk + ifm2 * oc_blk;
+ vmovups(ymm15, ptr[aux_reg_kernel + sizeof(float) * ker_off]);
+ for (int jj = jj_start; jj < jj_end; jj++)
+ if (mayiuse(avx2))
+ vfmadd231ps(Ymm(ur_w * ii + jj),
+ Ymm(oc_blocks * ur_w + jj), ymm15);
+ else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
+ vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
+ vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
+ }
+ }
+ }
+ }
+}
+
+void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
+ int pad_l, int pad_r, char pad_tag,
+ int oc_blocks, char oc_blocks_tag)
+{
+ Label kw_loop;
+
+ int iw = jcp.iw;
+ int ih = jcp.ih;
+ int id = jcp.id;
+ int kw = jcp.kw;
+ int kh = jcp.kh;
+ int kd = jcp.kd;
+ int nb_ic = jcp.nb_ic;
+ int stride_w = jcp.stride_w;
+ int dilate_w = jcp.dilate_w + 1;
+ int ic_blk = jcp.ic_block;
+ int oc_blk = jcp.oc_block;
+
+ xor_(ki_iter, ki_iter);
+ L(kw_loop);
+ {
+ int jj_start = 0;
+ int jj_end = ur_w;
+ for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ size_t inp_off;
+ if (one_of(jcp.src_tag, ncw, nchw, ncdhw))
+ inp_off = sizeof(float)*((size_t)ifm2 * id * ih * iw
+ + (jj * stride_w - pad_l));
+ else
+ inp_off = sizeof(float)*((jj * stride_w - pad_l) * ic_blk
+ + ifm2);
+ vbroadcastss(Ymm(oc_blocks * ur_w + jj),
+ make_safe_addr(aux_reg_input, inp_off, reg_long_offt));
+ }
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ int aux_kernel_offset =
+ ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ifm2 * oc_blk;
+ vmovups(ymm15, ptr[aux_reg_kernel
+ + sizeof(float) * aux_kernel_offset]);
+ for (int jj = jj_start; jj < jj_end; jj++)
+ if (mayiuse(avx2))
+ vfmadd231ps(Ymm(ur_w * ii + jj),
+ Ymm(oc_blocks * ur_w + jj), ymm15);
+ else { // Intel AVX support
+ vmulps(ytmp, ymm15, Ymm(oc_blocks * ur_w + jj));
+ vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), ytmp);
+ }
+ }
+ }
+ add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
+ add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw, ncdhw)
+ ? dilate_w : ic_blk * dilate_w));
+
+ inc(ki_iter);
+ cmp(ki_iter, kw);
+ jl(kw_loop, T_NEAR);
+ }
+}
+
+void jit_avx2_conv_fwd_kernel_f32::width_blk_step(int ur_w,
+ int pad_l, int pad_r, char pad_tag,
+ int oc_blocks, char oc_blocks_tag)
+{
+ int iw = jcp.iw;
+ int kw = jcp.kw;
+ int ow = jcp.ow;
+ int oh = jcp.oh;
+ int od = jcp.od;
+ int dilate_h = jcp.dilate_h + 1;
+ int dilate_w = jcp.dilate_w + 1;
+ int ic_blk = jcp.ic_block;
+ int oc_blk = jcp.oc_block;
+ const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
+ ? 1 : ic_blk;
+ const int inp_off = one_of(jcp.src_tag, ncw, nchw, ncdhw)
+ ? dilate_w : ic_blk * dilate_w;
+
+ Label init_done, init_first;
+
+ if (!jcp.with_sum) {
+ test(reg_ci_flag, FLAG_IC_FIRST);
+ jne(init_first, T_NEAR);
+ }
+
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ size_t offt =
+ sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
+ vmovups(Ymm(ur_w * ii + jj),
+ make_safe_addr(reg_output, offt, reg_long_offt));
+ }
+ }
+
+ if (jcp.with_sum && jcp.with_bias) {
+ test(reg_ci_flag, FLAG_IC_FIRST);
+ je(init_done, T_NEAR);
+
+ for (int ii = 0; ii < oc_blocks; ii++)
+ for (int jj = 0; jj < ur_w; jj++)
+ vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
+ yword[reg_bias + sizeof(float) * ii * oc_blk]);
+ }
+
+ jmp(init_done);
+
+ L(init_first);
+ if (this->jcp.with_bias) {
+ for (int ii = 0; ii < oc_blocks; ii++)
+ for (int jj = 0; jj < ur_w; jj++)
+ vmovups(Ymm(ur_w * ii + jj),
+ yword[reg_bias + sizeof(float) * ii * oc_blk]);
+ } else {
+ for (int ii = 0; ii < oc_blocks; ii++)
+ for (int jj = 0; jj < ur_w; jj++)
+ uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj));
+ }
+
+ L(init_done);
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ mov(aux_reg_input, reg_input);
+ mov(aux_reg_kernel, reg_kernel);
+ }
+
+ Label skip_kh_loop, skip_kd_loop, kd_loop;
+ if (jcp.ndims == 5) {
+ push(reg_output);
+ push(oi_iter);
+
+ mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
+ mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
+ mov(aux_reg_inp_d, reg_input);
+
+ if ((jcp.dilate_d >= jcp.id)
+ || (jcp.kd - 1) * (jcp.dilate_d + 1) < jcp.f_pad) {
+ cmp(reg_ki, 0);
+ je(skip_kd_loop, T_NEAR);
+ }
+ L(kd_loop);
+ mov(kj, ptr[param1 + GET_OFF(kh_padding)]);
+ } else {
+ mov(kj, reg_kh);
+ }
+
+ if (jcp.ndims == 5) {
+ mov(aux_reg_input, aux_reg_inp_d);
+ mov(aux_reg_kernel, aux_reg_ker_d);
+ }
+
+ if ((jcp.dilate_h >= jcp.ih)
+ || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
+ cmp(kj, 0);
+ je(skip_kh_loop, T_NEAR);
+ }
+ Label kh_loop;
+ L(kh_loop);
+ {
+ if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
+ oh_step_nopad(ur_w, pad_l, pad_r, pad_tag, oc_blocks,
+ oc_blocks_tag);
+ sub(aux_reg_input, sizeof(float) * kw * inp_off);
+ add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
+ } else {
+ oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
+ add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
+ add(aux_reg_input, sizeof(float) * iw * dilate_h * inp_mult);
+ }
+
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_loop, T_NEAR);
+ }
+
+ L(skip_kh_loop);
+
+ if (jcp.ndims == 5) {
+ add(aux_reg_inp_d,
+ sizeof(float) * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mult);
+ add(aux_reg_ker_d, sizeof(float) * jcp.kw * jcp.kh * jcp.oc_block
+ * jcp.ic_block);
+
+ dec(reg_ki);
+ cmp(reg_ki, 0);
+ jg(kd_loop, T_NEAR);
+ L(skip_kd_loop);
+
+ pop(oi_iter);
+ pop(reg_output);
+ }
+
+ Label regular_store;
+
+ if (jcp.with_eltwise) {
+ test(reg_ci_flag, FLAG_IC_LAST);
+ je(regular_store, T_NEAR);
+
+ eltwise_injector_->compute_vector_range(0, oc_blocks * ur_w);
+
+ L(regular_store);
+ }
+
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ const size_t o_off
+ = sizeof(float) * ((size_t)ii * od * oh * ow + jj) * oc_blk;
+ Ymm reg_out = Ymm(ur_w * ii + jj);
+ vmovups(make_safe_addr(reg_output, o_off, reg_long_offt), reg_out);
+ }
+ }
+}
+
+inline void jit_avx2_conv_fwd_kernel_f32::solve_common(
+ int oc_blocks, char oc_blocks_tag)
+{
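+    /* the output row is split into an optional block that sees left padding,
+     * n_oi "middle" blocks with no padding, an optional block that sees
+     * right padding, and a tail of ur_w_tail pixels */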
+ int ur_w = jcp.ur_w;
+ int ur_w_tail = jcp.ur_w_tail;
+ int n_oi = jcp.ow / ur_w;
+ int iw = jcp.iw;
+ int kw = jcp.kw;
+ int ic_blk = jcp.ic_block;
+ int oc_blk = jcp.oc_block;
+ int dilate_w = jcp.dilate_w + 1;
+ int str_w = jcp.stride_w;
+ const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_blk;
+
+ int l_pad = jcp.l_pad;
+ int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
+ - (iw + l_pad - 1));
+ int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
+ - (iw + l_pad - 1);
+ if (r_pad1 > 0) n_oi--;
+
+ if (l_pad > 0) {
+ n_oi--;
+ if (n_oi < 0 && r_pad1 > 0)
+ width_blk_step(ur_w, l_pad, r_pad1,
+ 'l', oc_blocks, oc_blocks_tag); // "lrpad"
+ else
+ width_blk_step(ur_w, l_pad, 0,
+ 'l', oc_blocks, oc_blocks_tag); // "lpad"
+ add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
+ add(reg_output, sizeof(float) * ur_w * oc_blk);
+ }
+
+ Label ow_loop;
+ xor_(oi_iter, oi_iter);
+
+ if (n_oi > 0) {
+ L(ow_loop);
+
+ width_blk_step(ur_w, 0, 0,
+ 'm', oc_blocks, oc_blocks_tag); // "middle"
+ add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
+ add(reg_output, sizeof(float) * ur_w * oc_blk);
+
+ inc(oi_iter);
+ cmp(oi_iter, n_oi);
+ jl(ow_loop, T_NEAR);
+ }
+
+    if (r_pad1 > 0 && n_oi >= 0) {
+ width_blk_step(ur_w, 0, r_pad1,
+ 'r', oc_blocks, oc_blocks_tag); // "rpad"
+ add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
+ add(reg_output, sizeof(float) * ur_w * oc_blk);
+ }
+
+ if (ur_w_tail != 0)
+ width_blk_step(ur_w_tail, 0, r_pad,
+ 't', oc_blocks, oc_blocks_tag); // "tail"
+}
+
+void jit_avx2_conv_fwd_kernel_f32::generate()
+{
+ this->preamble();
+
+ mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
+ mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
+ mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
+ if (jcp.with_bias)
+ mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
+ mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
+ mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
+ mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);
+
+ int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
+ Label tail, exit;
+
+ if (jcp.nb_oc > jcp.nb_oc_blocking) {
+ cmp(reg_oc_blocks, jcp.nb_oc_blocking);
+ jne(nb_oc_tail ? tail : exit, T_NEAR);
+
+ solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
+ jmp(exit, T_NEAR);
+
+ if (nb_oc_tail) {
+ L(tail);
+ cmp(reg_oc_blocks, nb_oc_tail);
+ jne(exit, T_NEAR);
+ solve_common(nb_oc_tail, '0' + nb_oc_tail);
+ }
+
+ L(exit);
+ } else if (jcp.nb_oc == jcp.nb_oc_blocking) {
+ solve_common(jcp.nb_oc_blocking, '0' + jcp.nb_oc_blocking);
+ } else {
+ solve_common(nb_oc_tail, '0' + nb_oc_tail);
+ }
+
+ this->postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+}
+
+bool jit_avx2_conv_fwd_kernel_f32::post_ops_ok(
+ jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+ switch (p.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
+ case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
+ default: return false;
+ }
+
+    return false; // unreachable: every case is handled by the switch above
+}
+
+status_t jit_avx2_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr)
+{
+ if (!mayiuse(avx)) return status::unimplemented;
+
+ jcp.prop_kind = cd.prop_kind;
+
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ int ndims = src_d.ndims();
+ jcp.ndims = ndims;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+
+ jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
+ jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
+ jcp.iw = src_d.dims()[ndims-1];
+ jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
+    jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
+ jcp.ow = dst_d.dims()[ndims-1];
+ jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
+ jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
+ jcp.kw = weights_d.dims()[with_groups + ndims-1];
+
+ jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
+ jcp.l_pad = cd.padding[0][ndims-3];
+ jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
+    jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
+ jcp.stride_w = cd.strides[ndims-3];
+
+ jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
+ jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
+ jcp.dilate_w = cd.dilates[ndims-3];
+
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+ - (jcp.ih + jcp.t_pad - 1);
+
+ if (ndims == 3) {
+ jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c);
+ jcp.wei_tag = weights_d.matches_one_of_tag(
+ Owi8o, gOwi8o, OIw8i8o, gOIw8i8o);
+ jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c);
+ } else if (ndims == 4) {
+ jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c);
+ jcp.wei_tag = weights_d.matches_one_of_tag(
+ Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o);
+ jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c);
+ } else if (ndims == 5) {
+ jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c);
+ jcp.wei_tag = weights_d.matches_one_of_tag(
+ Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o);
+ jcp.dst_tag = dst_d.matches_one_of_tag(nCdhw8c);
+ }
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ jcp.with_sum = p.find(primitive_kind::sum) != -1;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise) {
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+ if (!mayiuse(avx2) && jcp.eltwise.alg != alg_kind::eltwise_relu)
+ return status::unimplemented;
+ }
+
+ const int simd_w = 8;
+ const bool flat = jcp.ic < simd_w;
+ const bool mimo = !flat;
+
+
+ * convolution sizes with '(input_channel / ngroups) < simd' */
+ jcp.nonblk_group_off =
+ one_of(jcp.src_tag, ncw, nchw, ncdhw) && jcp.ngroups > 1 ? jcp.ic : 1;
+
+ bool ok_to_pad_channels = true
+ && jcp.ngroups == 1;
+
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ if (mimo)
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+ }
+
+ bool args_ok = true
+ && IMPLICATION(flat, true
+ && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc)
+ && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o,
+ gOdhwi8o))
+ && IMPLICATION(mimo, true
+ && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c)
+ && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o,
+ OIdhw8i8o, gOIdhw8i8o))
+ && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c);
+ if (!args_ok) return status::unimplemented;
+
+ jcp.ur_h = 1; /* no code-unrolling by h so far */
+ jcp.ur_w = 3;
+
+ jcp.oc_block = simd_w;
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+
+ jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */
+
+ // Intel AVX and Intel AVX2 kernels need 2 and 1 temporary YMMs, respectively
+ // Thus, we can only assign 14 or 15 YMMs for data storage
+ const int num_avail_regs = mayiuse(avx2) ? 15 : 14;
+ if (!mayiuse(avx2)) {
+ if ((jcp.nb_oc_blocking + 1) * jcp.ur_w > num_avail_regs) {
+            // the current register assignment needs more YMMs than are
+            // available; shrink either ur_w or nb_oc_blocking, preserving
+            // ur_w >= l_pad
+ if (jcp.ur_w > jcp.l_pad && jcp.ur_w > 1)
+ jcp.ur_w -= 1;
+ else
+ for (int b = 3; b > 1; b--)
+ if (jcp.nb_oc % b == 0) {
+ jcp.nb_oc_blocking = b;
+ break;
+ }
+ }
+ }
+
+ if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
+ jcp.ur_w_tail = jcp.ow % jcp.ur_w;
+
+ args_ok = true
+ && jcp.oc % simd_w == 0
+ && jcp.l_pad <= jcp.ur_w
+ && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
+ || (jcp.stride_w == 1 && jcp.stride_h == 1))
+ && IMPLICATION(mimo, jcp.ic % simd_w == 0);
+ if (!args_ok) return status::unimplemented;
+
+ int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+
+ if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
+ /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
+ jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
+ nstl::min(jcp.ow, num_avail_regs / 2));
+ jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
+ jcp.ur_w_tail = jcp.ow % jcp.ur_w;
+ /* check again ... */
+ r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+ if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
+ return status::unimplemented;
+ }
+ assert(jcp.nb_oc_blocking > 0);
+ assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);
+
+ jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+
+ if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
+ jcp.nb_ic_blocking = 12;
+ jcp.nb_ic_blocking_max = 16;
+ } else {
+ jcp.nb_ic_blocking = 1;
+ jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
+ }
+
+ return status::success;
+}
+
+void jit_avx2_conv_fwd_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+}
+
+void jit_avx2_conv_bwd_data_kernel_f32::compute_loop(int ur_w, int l_overflow,
+ int r_overflow)
+{
+ int kw = jcp.kw;
+ int kh = jcp.kh;
+ int kd = jcp.kd;
+ int iw = jcp.iw;
+ int ih = jcp.ih;
+ int id = jcp.id;
+ int ow = jcp.ow;
+
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ int nb_ic_block = jcp.nb_ic_blocking;
+ int stride_w = jcp.stride_w;
+ int stride_h = jcp.stride_h;
+
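+    /* diff_src accumulators live in Ymm(ur_w*ii + jj); for each kernel
+     * position only every stride_w-th output element (the ones that actually
+     * contribute to this input block) is broadcast and multiplied in */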
+ Label kd_loop, skip_kd_loop;
+ Label oc_loop, skip_oc_loop;
+
+ for (int ii = 0; ii < nb_ic_block; ii++)
+ for (int jj = 0; jj < ur_w; jj++) {
+ uni_vpxor(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
+ Ymm(ur_w * ii + jj));
+ }
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ cmp(reg_channel_work, 0);
+ jle(skip_oc_loop, T_NEAR);
+ xor_(reg_channel, reg_channel);
+
+ mov(aux_reg_ddst_oc_loop, reg_ddst);
+ mov(aux_reg_kernel_oc_loop, reg_kernel);
+
+ L(oc_loop);
+ mov(aux_reg_ddst, aux_reg_ddst_oc_loop);
+ mov(aux_reg_kernel, aux_reg_kernel_oc_loop);
+ }
+
+ if (jcp.ndims == 5) {
+ assert(jcp.nb_oc_blocking == 1);
+ push(oi_iter);
+
+ mov(reg_ki, ptr[this->param1 + GET_OFF(kd_padding)]);
+ mov(aux_reg_dst_d, reg_ddst);
+ mov(aux_reg_ker_d, ptr[this->param1 + GET_OFF(filt)]);
+
+ L(kd_loop);
+ mov(kj, ptr[this->param1 + GET_OFF(kh_padding)]);
+ } else {
+ mov(kj, reg_kh);
+ }
+
+ if (jcp.ndims == 5) {
+ mov(aux_reg_ddst, aux_reg_dst_d);
+ mov(aux_reg_kernel, aux_reg_ker_d);
+ }
+
+ Label kh_loop, skip_kh_loop;
+ cmp(kj, 0);
+ jle(skip_kh_loop, T_NEAR);
+ L(kh_loop); {
+ for (int ki = 0; ki < kw; ki++) {
+ int jj_start = get_iw_start(ki, l_overflow); // 0;
+ int jj_end = get_iw_end(ur_w, ki, r_overflow); // ur_w;
+ for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++) {
+
+ for (int jj = jj_start ; jj < jj_end; jj += stride_w) {
+ int aux_output_offset
+ = (jj + jcp.l_pad - ki) / stride_w * jcp.oc_block + ofm2;
+ vbroadcastss(Ymm(nb_ic_block * ur_w + jj / stride_w),
+ ptr[aux_reg_ddst
+ + sizeof(float) * aux_output_offset]);
+ }
+
+ for (int ii = 0; ii < nb_ic_block; ii++) {
+ int aux_kernel_offset
+ = ii * kd * kh * kw * jcp.ic_block * jcp.oc_block
+ + ki * jcp.ic_block * jcp.oc_block
+ + ofm2 * jcp.ic_block;
+ vmovups(ymm15,
+ ptr[aux_reg_kernel
+ + sizeof(float) * aux_kernel_offset]);
+ for (int jj = jj_start; jj < jj_end; jj += stride_w)
+ vfmadd231ps(Ymm(ur_w * ii + jj),
+ Ymm(nb_ic_block * ur_w + jj / stride_w), ymm15);
+ }
+ }
+ }
+ add(aux_reg_kernel, sizeof(float) * stride_h * kw * oc_block
+ * ic_block);
+ sub(aux_reg_ddst, sizeof(float) * ow * oc_block);
+
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_loop, T_NEAR);
+ }
+ L(skip_kh_loop);
+
+ if (jcp.ndims == 5) {
+ sub(aux_reg_dst_d,
+ sizeof(float) * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
+ add(aux_reg_ker_d,
+ sizeof(float) * jcp.kw * jcp.kh * oc_block * ic_block);
+
+ dec(reg_ki);
+ cmp(reg_ki, 0);
+ jg(kd_loop, T_NEAR);
+ L(skip_kd_loop);
+
+ pop(oi_iter);
+ }
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ int ddst_oc_shift = sizeof(float) * jcp.od * jcp.oh * jcp.ow
+ * jcp.oc_block;
+ int kernel_oc_shift = sizeof(float) * jcp.kd * jcp.kh * jcp.kw
+ * jcp.ic * jcp.oc_block;
+
+ add(aux_reg_ddst_oc_loop, ddst_oc_shift);
+ add(aux_reg_kernel_oc_loop, kernel_oc_shift);
+
+ inc(reg_channel);
+ cmp(reg_channel, reg_channel_work);
+ jl(oc_loop, T_NEAR);
+
+ L(skip_oc_loop);
+ mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
+ }
+
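+    /* Unless this is the first output-channel chunk, accumulate into the
+     * partial diff_src values already stored in memory. */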
+ Label no_update_label;
+ cmp(reg_channel, 0);
+ je(no_update_label, T_NEAR);
+ for (int ii = 0; ii < nb_ic_block; ii++) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ size_t offt =
+ sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
+ vmovups(Ymm(15),
+ make_safe_addr(reg_dsrc, offt, reg_long_offt));
+ vaddps(Ymm(ur_w * ii + jj), Ymm(ur_w * ii + jj),
+ Ymm(15));
+
+ }
+ }
+ L(no_update_label);
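+    /* Store the accumulated diff_src block. */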
+
+ for (int ii = 0; ii < nb_ic_block; ii++)
+ for (int jj = 0; jj < ur_w; jj++) {
+ size_t offt =
+ sizeof(float) * ((size_t)ii * id * ih * iw + jj) * ic_block;
+ vmovups(make_safe_addr(reg_dsrc, offt, reg_long_offt),
+ Ymm(ur_w * ii + jj));
+ }
+}
+
+void jit_avx2_conv_bwd_data_kernel_f32::generate() {
+ preamble();
+
+ mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
+ mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
+ mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
+ mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
+ mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
+ mov(reg_channel_work, ptr[param1 + GET_OFF(ch_blocks)]);
+
+ int ddst_shift = sizeof(float) * (jcp.ur_w / jcp.stride_w) * jcp.ic_block;
+ int dsrc_shift = sizeof(float) * jcp.ur_w * jcp.oc_block;
+
+ int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
+ int r_overflow = nstl::max(0, (jcp.kw - 1
+ - nstl::max(0, jcp.r_pad)) / jcp.stride_w);
+ int r_overflow1 = nstl::max(0, (jcp.kw - 1
+ - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
+
+ int n_oi = jcp.iw / jcp.ur_w;
+ if (r_overflow1 > 0)
+ n_oi--;
+
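+    /* Dispatch on row structure: a single block covering the whole width,
+     * edge blocks only, or a left-overflow block followed by a steady-state
+     * loop and the right/tail blocks. */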
+ if (jcp.ur_w == jcp.iw) {
+ compute_loop(jcp.ur_w, l_overflow, r_overflow);
+ } else if (n_oi == 0) {
+ compute_loop(jcp.ur_w, l_overflow, r_overflow1);
+ add(reg_dsrc, dsrc_shift);
+ add(reg_ddst, ddst_shift);
+ if (jcp.ur_w_tail != 0)
+ compute_loop(jcp.ur_w_tail, 0, r_overflow);
+ } else {
+ xor_(oi_iter, oi_iter);
+ if (l_overflow > 0) {
+ compute_loop(jcp.ur_w, l_overflow, 0);
+ add(reg_dsrc, dsrc_shift);
+ add(reg_ddst, ddst_shift);
+ inc(oi_iter);
+ }
+
+ if ((l_overflow <= 0 && n_oi > 0) || (l_overflow > 0 && n_oi > 1)) {
+ Label ow_loop;
+ L(ow_loop); {
+ compute_loop(jcp.ur_w, 0, 0);
+ add(reg_dsrc, dsrc_shift);
+ add(reg_ddst, ddst_shift);
+ inc(oi_iter);
+ cmp(oi_iter, n_oi); jl(ow_loop, T_NEAR);
+ }
+ }
+
+        if (r_overflow1 > 0) {
+ compute_loop(jcp.ur_w, 0, r_overflow1);
+ add(reg_dsrc, dsrc_shift);
+ add(reg_ddst, ddst_shift);
+ }
+
+ if (jcp.ur_w_tail != 0)
+ compute_loop(jcp.ur_w_tail, 0, r_overflow);
+ }
+
+ this->postamble();
+}
+
+status_t jit_avx2_conv_bwd_data_kernel_f32::init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d)
+{
+ if (!mayiuse(avx2)) return status::unimplemented;
+
+ const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
+
+ int ndims = diff_src_d.ndims();
+ jcp.ndims = ndims;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = diff_src_d.dims()[0];
+
+ jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
+
+ jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
+ jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
+ jcp.iw = diff_src_d.dims()[ndims-1];
+ jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
+ jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
+ jcp.ow = diff_dst_d.dims()[ndims-1];
+
+ jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
+ jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
+ jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+
+ jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
+ jcp.l_pad = cd.padding[0][ndims-3];
+
+ jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
+ jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
+ jcp.stride_w = cd.strides[ndims-3];
+
+ jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
+ jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
+ jcp.dilate_w = cd.dilates[ndims-3];
+
+ const int simd_w = 8;
+
+    /* derived sizes */
+ jcp.idp = jcp.id + 2 * jcp.f_pad;
+ jcp.ihp = jcp.ih + 2 * jcp.t_pad;
+ jcp.iwp = jcp.iw + 2 * jcp.l_pad;
+ jcp.ohp = jcp.oh; /* do we really need */
+ jcp.owp = jcp.ow; /* padded output ??? */
+
+ bool ok_to_pad_channels = true
+ && jcp.ngroups == 1;
+
+ /* gemm-based convolution performs better in these cases */
+ if (jcp.ic < simd_w && jcp.kw > 3 && jcp.stride_w > 1)
+ return status::unimplemented;
+
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+ }
+
+ jcp.ic_block = (jcp.ic % simd_w) ? 1 : simd_w;
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+
+ jcp.oc_block = simd_w;
+ if (jcp.oc % jcp.oc_block) return status::unimplemented;
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+
+ jcp.ur_h = 1; /* no code-unrolling by h so far */
+ jcp.nb_ic_blocking = 1;
+ jcp.nb_oc_blocking = 1;
+ jcp.ur_w = 1;
+
+    if (one_of(ndims, 3, 4) && jcp.ow < 40)
+ jcp.nb_oc_blocking = jcp.ow < 15 ? 4 : 2;
+
+ if (ndims == 3) {
+ jcp.src_tag = diff_src_d.matches_one_of_tag(nCw8c);
+ jcp.wei_tag = weights_d.matches_one_of_tag(OIw8i8o, gOIw8o8i);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c);
+ } else if (ndims == 4) {
+ jcp.src_tag = diff_src_d.matches_one_of_tag(nChw8c);
+ jcp.wei_tag = weights_d.matches_one_of_tag(OIhw8o8i, gOIhw8o8i);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c);
+ } else if (ndims == 5) {
+ jcp.src_tag = diff_src_d.matches_one_of_tag(nCdhw8c);
+ jcp.wei_tag = weights_d.matches_one_of_tag(OIdhw8o8i, gOIdhw8o8i);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c);
+ }
+
+ bool args_ok = true
+ && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c)
+ && one_of(jcp.wei_tag, gOIw8o8i, OIw8i8o, gOIhw8o8i, OIhw8o8i,
+ gOIdhw8o8i, OIdhw8o8i)
+ && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c)
+ && jcp.stride_w == jcp.stride_h
+ && jcp.stride_d == 1
+ && jcp.dilate_d == 0
+ && jcp.dilate_h == 0
+ && jcp.dilate_w == 0
+ && jcp.ic % simd_w == 0
+ && jcp.oc % simd_w == 0
+ && jcp.od == (jcp.idp - jcp.kd) / jcp.stride_d + 1
+ && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
+ && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
+ if (!args_ok) return status::unimplemented;
+ jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad;
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad;
+ int l_overflow = nstl::max(0, (jcp.kw - 1 - jcp.l_pad) / jcp.stride_w);
+
+    const int max_regs = 15; /* Maximum number of registers available for
+ result accumulation and delta dst data.
+ One additional register is reserved for weights
+ data. */
+
+ /* Find the best blocking with maximum number of fma instructions
+ per ur_w * nb_ic_blocking compute loops. Number of required registers
+ is num_regs = ur_w * nb_ic_blocking + ur_w / stride_w <= max_regs.
+ ur_w must be divisible by stride_w */
+    if (jcp.stride_w + 1 > max_regs) /* Minimal possible register
+ distribution exceeds max_regs */
+ return status::unimplemented;
+
+ int best_nfmas = 0;
+ for (int b = 1; b <= 4; b++)
+ {
+ if (jcp.nb_ic % b != 0)
+ continue;
+
+ for (int u = jcp.stride_w;
+ u * b + u / jcp.stride_w <= max_regs && u < jcp.iw + jcp.stride_w;
+ u += jcp.stride_w)
+ {
+ int ur_w = nstl::min(u, jcp.iw);
+ /* maximum 1 step with l_overflow so far */
+ if (l_overflow * jcp.stride_w > ur_w && ur_w != jcp.iw)
+ continue;
+ int nfmas = utils::div_up(ur_w, jcp.stride_w) * b;
+ if (nfmas > best_nfmas
+ || (nfmas == best_nfmas && jcp.ur_w < ur_w)) {
+ jcp.ur_w = ur_w;
+ jcp.nb_ic_blocking = b;
+ best_nfmas = nfmas;
+ }
+ }
+ }
+ if (best_nfmas == 0) /* can't find appropriate blocking */
+ return status::unimplemented;
+
+ jcp.ur_w_tail = jcp.iw % jcp.ur_w;
+
+ int r_overflow_no_tail = nstl::max(0, (jcp.kw - 1 - jcp.ur_w_tail
+ - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
+ /* maximum 1 ur_w block with r_overflow so far */
+ if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
+ return status::unimplemented;
+
+ if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
+ return status::unimplemented;
+
+ return status::success;
+}
+
+void jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ UNUSED(scratchpad);
+ UNUSED(jcp);
+}
+
+void jit_avx2_conv_bwd_weights_kernel_f32::generate() {
+ this->preamble();
+
+ mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
+ mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
+ mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
+ compute_oh_loop_common();
+ this->postamble();
+}
+
+status_t jit_avx2_conv_bwd_weights_kernel_f32::init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &diff_weights_d,
+ const memory_desc_wrapper &diff_dst_d) {
+ if (!mayiuse(avx2)) return status::unimplemented;
+
+ const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
+ int ndims = src_d.ndims();
+ jcp.ndims = ndims;
+
+ jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+
+ jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+
+ jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
+ jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
+ jcp.iw = src_d.dims()[ndims-1];
+ jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
+ jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
+ jcp.ow = diff_dst_d.dims()[ndims-1];
+
+ jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
+ jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
+ jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];
+
+ jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
+ jcp.l_pad = cd.padding[0][ndims-3];
+
+ jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
+ jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
+ jcp.stride_w = cd.strides[ndims-3];
+
+ jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
+ jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
+ jcp.dilate_w = cd.dilates[ndims-3];
+
+ if (ndims == 3) {
+ jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c);
+ jcp.wei_tag = diff_weights_d.matches_one_of_tag(
+ Owi8o, gOwi8o, OIw8i8o, gOIw8i8o);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCw8c);
+ } else if (ndims == 4) {
+ jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c);
+ jcp.wei_tag = diff_weights_d.matches_one_of_tag(
+ Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(nChw8c);
+ } else if (ndims == 5) {
+ jcp.src_tag = src_d.matches_one_of_tag(ncdhw, ndhwc, nCdhw8c);
+ jcp.wei_tag = diff_weights_d.matches_one_of_tag(
+ Odhwi8o, gOdhwi8o, OIdhw8i8o, gOIdhw8i8o);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(nCdhw8c);
+ }
+ jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
+
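+    /* "flat" is the plain 3-channel (non-blocked) source layout; "mimo" covers
+     * blocked layouts where ic is handled in SIMD-wide blocks. */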
+ const bool flat = jcp.ic == 3;
+ const bool mimo = !flat;
+
+ const int simd_w = 8;
+
+ jcp.b_pad = nstl::max(
+ 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
+ jcp.r_pad = nstl::max(
+ 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
+
+ int back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d + jcp.kd - jcp.id
+ - jcp.f_pad);
+ if (ndims == 5)
+ if (jcp.f_pad != 0 || back_pad != 0)
+ return status::unimplemented;
+
+ const int max_h_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1);
+ const int max_w_pad = ((jcp.kw - 1) * (jcp.dilate_w + 1) + 1);
+ const bool boundaries_ok = true
+ && jcp.t_pad < max_h_pad && jcp.b_pad < max_h_pad
+ && jcp.l_pad < max_w_pad && jcp.r_pad < max_w_pad;
+ if (!boundaries_ok)
+ return status::unimplemented;
+
+ bool ok_to_pad_channels = true
+ && jcp.ngroups == 1;
+
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ if (mimo)
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+ }
+
+ bool args_ok = true
+ && IMPLICATION(flat, true
+ && one_of(jcp.src_tag, ncw, nwc, nchw, nhwc, ncdhw, ndhwc)
+ && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o, Odhwi8o,
+ gOdhwi8o))
+ && IMPLICATION(mimo, true
+ && one_of(jcp.src_tag, nCw8c, nChw8c, nCdhw8c)
+ && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o,
+ OIdhw8i8o, gOIdhw8i8o))
+ && one_of(jcp.dst_tag, nCw8c, nChw8c, nCdhw8c)
+ && IMPLICATION(mimo, jcp.ic % simd_w == 0)
+ && jcp.oc % simd_w == 0
+ && jcp.kw < 14
+ && jcp.kh <= jcp.t_pad + jcp.ih /* [bwd_w:r1] */
+ && jcp.kh <= jcp.ih /* [bwd_w:r2] */
+ && jcp.kd <= jcp.f_pad + jcp.id
+ && jcp.kd <= jcp.id
+ && jcp.t_pad < jcp.kh /* XXX: must fix the kernel! */
+ && jcp.dilate_d == 0
+ && jcp.dilate_h == 0
+ && jcp.dilate_w == 0;
+ if (!args_ok) return status::unimplemented;
+
+ jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+
+ jcp.oc_block = simd_w;
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+ jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
+
+ return status::success;
+}
+
+void jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+}
+
+inline void jit_avx2_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
+{
+ Label kd_comeback_loop;
+ mov(kj, jcp.kd); //FIXME (Anton): this works only if f_pad = back_pad = 0
+ L(kd_comeback_loop); {
+ const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
+ ? 1 : jcp.ic_block;
+ sub(aux_reg_input, sizeof(float) * jcp.iw * jcp.ih * inp_mult);
+ sub(aux_reg_kernel, sizeof(float) * jcp.kw * jcp.kh * jcp.ic_block
+ * jcp.oc_block);
+ dec(kj);
+ cmp(kj, 0);
+ jg(kd_comeback_loop, T_NEAR);
+ }
+}
+
+inline void jit_avx2_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
+{
+ mov(kj, reg_kh);
+ Label kh_comeback_loop;
+ L(kh_comeback_loop); {
+ const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
+ ? 1 : jcp.ic_block;
+ sub(reg_input, sizeof(float) * jcp.iw * inp_mult);
+ sub(reg_kernel, sizeof(float) * jcp.kw * jcp.ic_block * jcp.oc_block);
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_comeback_loop, T_NEAR);
+ }
+}
+
+inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_ic_block_step(
+ int ur_w, int pad_l, int pad_r, int ic_block_step, int input_offset,
+ int kernel_offset, int output_offset)
+{
+ const int kw = jcp.kw;
+ const int ic_block = jcp.ic_block;
+ const int oc_block = jcp.oc_block;
+ for (int i_kw = 0; i_kw < kw; i_kw++)
+ for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
+ size_t off
+ = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
+ + kernel_offset;
+ vmovups(Ymm(i_kw * ic_block_step + i_ic), yword[reg_kernel + off]);
+ }
+
+ for (int i_ur = 0; i_ur < ur_w; i_ur++) {
+ vmovups(Ymm(kw * ic_block_step + 0),
+ yword[reg_output
+ + sizeof(float) * i_ur * oc_block + output_offset]);
+
+ for (int i_kw = 0; i_kw < kw; i_kw++) {
+ int i_iw = i_ur * jcp.stride_w + i_kw;
+ if (i_iw - pad_l < 0
+ || i_iw > (ur_w - 1) * jcp.stride_w + kw - 1 - pad_r)
+ continue;
+ for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
+ size_t i_off = (size_t)input_offset + sizeof(float)*(
+ one_of(jcp.src_tag, ncw, nchw, ncdhw)
+ ? (i_iw - pad_l) + i_ic
+ * ((size_t)jcp.id * jcp.ih * jcp.iw)
+ : (i_iw - pad_l) * ic_block + i_ic);
+ vbroadcastss(Ymm(kw * ic_block_step + 1),
+ make_safe_addr(reg_input, i_off, reg_long_offt));
+ vfmadd231ps(Ymm(i_kw * ic_block_step + i_ic),
+ Ymm(kw * ic_block_step + 0),
+ Ymm(kw * ic_block_step + 1));
+ }
+ }
+ }
+
+ for (int i_kw = 0; i_kw < kw; i_kw++)
+ for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
+ size_t off
+ = sizeof(float) * (i_kw * ic_block + i_ic) * jcp.oc_block
+ + kernel_offset;
+ vmovups(yword[reg_kernel + off],
+ Ymm(i_kw * ic_block_step + i_ic));
+ }
+}
+
+inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_disp()
+{
+ int ic_block_step;
+ if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
+ ic_block_step = jcp.kw >= 5 ? 1 : jcp.ic_block;
+ } else {
+ ic_block_step = jcp.kw > 7 ? 1
+ : jcp.kw > 3 ? 2
+ : jcp.kw > 1 ? 4 : 8;
+ }
+
+ const int max_ur_w = jcp.ow > 56 ? 14 : 28;
+
+ if (jcp.ow <= max_ur_w)
+ compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
+ else
+ compute_oh_step_common(ic_block_step, max_ur_w);
+
+ if (jcp.ndims == 5) {
+ od_step_comeback_pointers();
+ mov(reg_input, aux_reg_input);
+ mov(reg_kernel, aux_reg_kernel);
+ } else {
+ oh_step_comeback_pointers();
+ }
+}
+
+inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_unroll_ow(
+ int ic_block_step, int max_ur_w)
+{
+ UNUSED(max_ur_w);
+
+ const int ic_block = jcp.ic_block;
+ const int oc_block = jcp.oc_block;
+ int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
+ Label kd_loop;
+
+ const int r_pad
+ = nstl::max(0,
+ (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
+
+ if (jcp.ndims == 5) {
+ mov(aux_reg_input, reg_input);
+ mov(aux_reg_kernel, reg_kernel);
+ mov(ki, jcp.kd);
+ L(kd_loop);
+ mov(reg_input, aux_reg_input);
+ mov(reg_kernel, aux_reg_kernel);
+ }
+
+ mov(kj, reg_kh);
+ Label kh_loop;
+ L(kh_loop); {
+ xor_(b_ic, b_ic);
+ Label ic_block_loop;
+ L(ic_block_loop); {
+ compute_ic_block_step(jcp.ow, jcp.l_pad, r_pad, ic_block_step, 0,
+ 0, 0);
+ size_t inp_icblk_stride = sizeof(float) * ic_block_step
+ * (one_of(jcp.src_tag, ncw, nchw, ncdhw)
+ ? jcp.id*jcp.ih*jcp.iw : 1);
+ safe_add(reg_input, inp_icblk_stride, reg_long_offt);
+ add(reg_kernel, sizeof(float) * ic_block_step * oc_block);
+ add(b_ic, ic_block_step);
+ cmp(b_ic, ic_block);
+ jl(ic_block_loop, T_NEAR);
+ }
+        if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
+ size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
+ safe_sub(reg_input, offt, reg_long_offt);
+ add(reg_input, sizeof(float) * jcp.iw);
+ } else {
+ add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
+ }
+ add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_loop, T_NEAR);
+ }
+
+ if (jcp.ndims == 5) {
+ add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
+ add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
+ * oc_block);
+ dec(ki);
+ cmp(ki, 0);
+ jg(kd_loop, T_NEAR);
+ }
+
+}
+
+inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_step_common(
+ int ic_block_step, int max_ur_w)
+{
+ const int ic_block = jcp.ic_block;
+ const int oc_block = jcp.oc_block;
+ const int stride_w = jcp.stride_w;
+ int inp_mul = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : jcp.ic_block;
+ Label kd_loop;
+
+ const int r_pad = jcp.r_pad;
+
+ int ur_w = nstl::min(jcp.ow, max_ur_w);
+ int ur_w_trips = jcp.ow / ur_w;
+ int ur_w_tail = jcp.ow % ur_w;
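+    /* If the right padding is not fully handled by the tail, fold one extra
+     * block into the tail (or halve ur_w) so that padded columns stay in the
+     * tail step. */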
+ if ((ur_w_tail == 0 && r_pad != 0) || r_pad >= ur_w_tail) {
+ if (ur_w_trips > 1) {
+ ur_w_tail += ur_w;
+ ur_w_trips--;
+ } else {
+ ur_w_tail += (ur_w - ur_w / 2);
+ ur_w = ur_w / 2;
+ }
+ }
+ const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw) ? 1 : ic_block;
+
+ int input_comeback = (ur_w_trips * ur_w * stride_w - jcp.l_pad) * inp_mult;
+ int output_comeback = ur_w_trips * ur_w * oc_block;
+
+ if (jcp.ndims == 5) {
+ mov(aux_reg_input, reg_input);
+ mov(aux_reg_kernel, reg_kernel);
+ mov(ki, jcp.kd);
+ L(kd_loop);
+ mov(reg_input, aux_reg_input);
+ mov(reg_kernel, aux_reg_kernel);
+ }
+
+ mov(kj, reg_kh);
+ Label kh_loop;
+ L(kh_loop); {
+ xor_(b_ic, b_ic);
+ Label ic_block_loop;
+ L(ic_block_loop); {
+ if (jcp.l_pad != 0) {
+ ur_w_trips--;
+ compute_ic_block_step(ur_w,
+ jcp.l_pad, 0, ic_block_step, 0, 0, 0);
+ add(reg_input, sizeof(float)
+ * (ur_w * stride_w - jcp.l_pad) * inp_mult);
+ add(reg_output, sizeof(float) * ur_w * oc_block);
+ }
+
+ if (ur_w_trips > 0) {
+ xor_(reg_ur_w_trips, reg_ur_w_trips);
+ Label ow_block_loop;
+ L(ow_block_loop); {
+ compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
+ add(reg_input, sizeof(float) * ur_w * stride_w * inp_mult);
+ add(reg_output, sizeof(float) * ur_w * oc_block);
+
+ inc(reg_ur_w_trips);
+ cmp(reg_ur_w_trips, ur_w_trips);
+ jl(ow_block_loop, T_NEAR);
+ }
+ }
+
+ if (ur_w_tail > 0)
+ compute_ic_block_step(ur_w_tail,
+ 0, r_pad, ic_block_step, 0, 0, 0);
+
+ sub(reg_input, sizeof(float) * input_comeback);
+ sub(reg_output, sizeof(float) * output_comeback);
+
+ size_t inp_icblk_stride = sizeof(float) * ic_block_step
+ * (one_of(jcp.src_tag, ncw, nchw, ncdhw)
+ ? jcp.id*jcp.ih*jcp.iw : 1);
+ safe_add(reg_input, inp_icblk_stride, reg_long_offt);
+ add(reg_kernel, sizeof(float) * ic_block_step * oc_block);
+
+ add(b_ic, ic_block_step);
+ cmp(b_ic, jcp.ic_block);
+ jl(ic_block_loop, T_NEAR);
+ }
+ if (one_of(jcp.src_tag, ncw, nchw, ncdhw)) {
+ size_t offt = sizeof(float) * jcp.id * jcp.ih * jcp.iw * ic_block;
+ safe_sub(reg_input, offt, reg_long_offt);
+ add(reg_input, sizeof(float) * jcp.iw);
+ } else {
+ add(reg_input, sizeof(float) * (jcp.iw - 1) * ic_block);
+ }
+ add(reg_kernel, sizeof(float) * (jcp.kw - 1) * ic_block * oc_block);
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_loop, T_NEAR);
+ }
+
+ if (jcp.ndims == 5) {
+ add(aux_reg_input, sizeof(float) * jcp.ih * jcp.iw * inp_mul);
+ add(aux_reg_kernel, sizeof(float) * jcp.kh * jcp.kw * ic_block
+ * oc_block);
+ dec(ki);
+ cmp(ki, 0);
+ jg(kd_loop, T_NEAR);
+ }
+
+}
+
+inline void jit_avx2_conv_bwd_weights_kernel_f32::compute_oh_loop_common()
+{
+ const int icoc_block = jcp.ic_block * jcp.oc_block;
+ const int t_pad = jcp.t_pad;
+ const int stride_h = jcp.stride_h;
+ const int inp_mult = one_of(jcp.src_tag, ncw, nchw, ncdhw)
+ ? 1 : jcp.ic_block;
+ int b_pad = jcp.b_pad;
+
+ Label oh_tpad_loop, oh_loop, oh_loop_end;
+
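+    /* Three phases: output rows overlapping the top padding (partial kernel),
+     * rows with the full kernel, and rows overlapping the bottom padding. */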
+ mov(reg_kh, jcp.kh);
+ xor_(reg_ih_count, reg_ih_count);
+ xor_(reg_oj, reg_oj);
+ if (t_pad > 0) {
+ assert(jcp.kh <= t_pad + jcp.ih); /* [bwd_w:r1] */
+ mov(reg_kh, jcp.kh <= t_pad + jcp.ih ? jcp.kh - t_pad : jcp.ih);
+ add(reg_kernel, sizeof(float) * t_pad * jcp.kw * icoc_block);
+
+ L(oh_tpad_loop); {
+ compute_oh_step_disp();
+ add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
+ sub(reg_kernel, sizeof(float) * stride_h * jcp.kw * icoc_block);
+
+ inc(reg_oj);
+ add(reg_ih_count, stride_h);
+ add(reg_kh, stride_h);
+
+            /* the overlap between input and kernel may never reach the full
+             * kernel size; that case is not supported yet (hence the constant
+             * used here) */
+ const int final_inp_ker_overlap = jcp.kh; /* [bwd_w:r2] */
+ cmp(reg_kh, final_inp_ker_overlap);
+ jl(oh_tpad_loop, T_NEAR);
+ }
+
+ if (t_pad % stride_h != 0) {
+ int inp_corr = stride_h - t_pad % stride_h;
+ add(reg_kernel, sizeof(float) * inp_corr * jcp.kw * icoc_block);
+ add(reg_input, sizeof(float) * inp_corr * jcp.iw * inp_mult);
+ }
+ }
+ cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
+ jge(oh_loop_end, T_NEAR);
+ cmp(reg_oj, jcp.oh);
+ jge(oh_loop, T_NEAR);
+
+ mov(reg_kh, jcp.kh);
+ L(oh_loop); {
+ compute_oh_step_disp();
+ add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
+ add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
+
+ inc(reg_oj);
+ add(reg_ih_count, stride_h);
+
+ cmp(reg_ih_count, jcp.ih + t_pad - jcp.kh + 1);
+ jge(oh_loop_end, T_NEAR);
+
+ cmp(reg_oj, jcp.oh);
+ jl(oh_loop, T_NEAR);
+ }
+ L(oh_loop_end);
+ if (b_pad > 0) {
+ Label oh_bpad_loop, oh_bpad_loop_end;
+ cmp(reg_oj, jcp.oh);
+ jge(oh_bpad_loop_end, T_NEAR);
+
+ mov(reg_kh, jcp.ih + t_pad);
+ sub(reg_kh, reg_ih_count);
+ L(oh_bpad_loop); {
+ compute_oh_step_disp();
+ add(reg_input, sizeof(float) * stride_h * jcp.iw * inp_mult);
+ add(reg_output, sizeof(float) * jcp.ow * jcp.oc_block);
+
+ sub(reg_kh, stride_h);
+ cmp(reg_kh, 0);
+ jle(oh_bpad_loop_end, T_NEAR);
+
+ inc(reg_oj);
+ cmp(reg_oj, jcp.oh);
+ jl(oh_bpad_loop, T_NEAR);
+ }
+ L(oh_bpad_loop_end);
+ }
+}
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp
new file mode 100644
index 0000000000..412c50c9ee
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.hpp
@@ -0,0 +1,225 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX2_CONV_KERNEL_F32_HPP
+#define JIT_AVX2_CONV_KERNEL_F32_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "cpu_memory.hpp"
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_avx2_conv_fwd_kernel_f32: public jit_generator {
+ jit_avx2_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
+ const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
+ {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx2>(this,
+ jcp.eltwise);
+
+ this->generate();
+ jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
+ }
+
+ ~jit_avx2_conv_fwd_kernel_f32() {
+ delete eltwise_injector_;
+ }
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_fwd_kernel_f32)
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ jit_conv_conf_t jcp;
+ const primitive_attr_t &attr_;
+ void (*jit_ker)(jit_conv_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ reg64_t reg_input = rax;
+ reg64_t aux_reg_input = r8;
+ reg64_t reg_kernel = rdx;
+ reg64_t aux_reg_kernel = r9;
+ reg64_t reg_output = rsi;
+ reg64_t reg_bias = rbx;
+
+ reg64_t aux_reg_inp_d = r11;
+ reg64_t aux_reg_ker_d = abi_not_param1;
+
+ reg64_t reg_ki = rsi;
+ reg64_t kj = r10;
+ reg64_t oi_iter = r11;
+ reg64_t ki_iter = r12;
+ reg64_t reg_kh = abi_not_param1;
+ reg64_t reg_oc_blocks = r14;
+ reg64_t imm_addr64 = r15;
+ reg64_t reg_long_offt = r15;
+ Xbyak::Reg32 reg_ci_flag = r13d;
+
+ Xbyak::Ymm ytmp = Xbyak::Ymm(14);
+
+ jit_uni_eltwise_injector_f32<avx2> *eltwise_injector_;
+
+ inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r,
+ int oc_blocks);
+ inline void oh_step_nopad(int ur_w, int pad_l, int pad_r,
+ char pad_label, int oc_blocks, char oc_blocks_label);
+ inline void width_blk_step(int ur_w, int pad_l, int pad_r,
+ char pad_label, int oc_blocks, char oc_blocks_label);
+ inline void solve_common(int oc_blocks, char oc_blocks_label);
+
+ void generate();
+};
+
+struct jit_avx2_conv_bwd_data_kernel_f32: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_data_kernel_f32)
+
+ jit_avx2_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
+ {
+ this->generate();
+ jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
+ }
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ jit_conv_conf_t jcp;
+ void (*jit_ker)(jit_conv_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+
+ reg64_t reg_ddst = rax;
+ reg64_t aux_reg_ddst = r8;
+ reg64_t reg_kernel = rdx;
+ reg64_t aux_reg_kernel = r10;
+ reg64_t reg_dsrc = rsi;
+ reg64_t aux_reg_ddst_oc_loop = rbx; // used in ndims < 5 case only
+ reg64_t aux_reg_kernel_oc_loop = abi_not_param1; /* used in ndims < 5
+ case only */
+
+ reg64_t aux_reg_dst_d = r12; // used in ndims == 5 case only
+ reg64_t aux_reg_ker_d = r14; // used in ndims == 5 case only
+
+ reg64_t reg_ki = abi_not_param1; // used in ndims == 5 case only
+ reg64_t kj = r11;
+ reg64_t oi_iter = r12;
+ reg64_t reg_kh = r14;
+ reg64_t reg_channel = r13; // used in ndims < 5 case only
+ reg64_t reg_channel_work = r9; // used in ndims < 5 case only
+ reg64_t reg_long_offt = r15;
+
+ inline void compute_loop(int ur_w, int l_overflow, int r_overflow);
+
+ void generate();
+
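+    /* get_iw_start / get_iw_end return the first and one-past-last diff_src
+       column within the current ur_w block that kernel column ki updates,
+       given the stride and the left/right overflow. */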
+ inline int get_iw_start(int ki, int l_overflow)
+ {
+ int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
+ + l_overflow * jcp.stride_w
+ - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
+ while (res < 0)
+ res += jcp.stride_w;
+
+ return res;
+ }
+
+ inline int get_iw_end(int ur_w, int ki, int r_overflow)
+ {
+ if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
+ ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
+ int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
+ + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
+ while (res < 0)
+ res += jcp.stride_w;
+
+ return ur_w - res;
+ }
+};
+
+struct jit_avx2_conv_bwd_weights_kernel_f32: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx2_conv_bwd_weights_kernel_f32)
+
+ jit_avx2_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
+ {
+ this->generate();
+ jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
+ }
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &diff_weights_d,
+ const memory_desc_wrapper &diff_dst_d);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ jit_conv_conf_t jcp;
+ void (*jit_ker)(jit_conv_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ reg64_t reg_input = rax;
+ reg64_t reg_kernel = rdx;
+ reg64_t reg_output = rsi;
+ reg64_t b_ic = abi_not_param1;
+ reg64_t kj = r8;
+ reg64_t reg_kh = r9;
+ reg64_t reg_ur_w_trips = r10;
+ reg64_t reg_tmp = r11;
+ reg64_t reg_oj = r15;
+ reg64_t reg_ih_count = rbx;
+ reg64_t aux_reg_input = r12;
+ reg64_t aux_reg_kernel = r13;
+ reg64_t ki = r14;
+ reg64_t reg_long_offt = r11;
+
+ inline void od_step_comeback_pointers();
+ inline void oh_step_comeback_pointers();
+ inline void compute_ic_block_step(int ur_w, int pad_l, int pad_r,
+ int ic_block_step, int input_offset, int kernel_offset,
+ int output_offset);
+ inline void compute_oh_step_disp();
+ inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w);
+ inline void compute_oh_step_common(int ic_block_step, int max_ur_w);
+ inline void compute_oh_loop_common();
+
+ void generate();
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp
new file mode 100644
index 0000000000..13f61e84fe
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.cpp
@@ -0,0 +1,410 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_avx2_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+#define src_blk_off(f, n, c, d, h, w) \
+ (pd()->ndims() == 3) \
+ ? (f).blk_off(n, c, w) \
+ : (pd()->ndims() == 4) \
+ ? (f).blk_off(n, c, h, w) \
+ : (f).blk_off(n, c, d, h, w)
+
+#define wht_blk_off_(f, g, ...) \
+ pd()->with_groups() ? (f).blk_off(g, __VA_ARGS__) : (f).blk_off(__VA_ARGS__)
+#define wht_blk_off(f, g, oc, ic, kd, kh, kw) \
+ (pd()->ndims() == 3) \
+ ? wht_blk_off_(f, g, oc, ic, kw) \
+ : (pd()->ndims() == 4) \
+ ? wht_blk_off_(f, g, oc, ic, kh, kw) \
+ : wht_blk_off_(f, g, oc, ic, kd, kh, kw)
+
+void jit_avx2_convolution_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const auto &jcp = kernel_->jcp;
+
+ int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
+ const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.od
+ * jcp.oh;
+
+ auto ker = [&](const int ithr, const int nthr) {
+ size_t start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
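+        /* Walk input channels in chunks of nb_ic_blocking; for icb > 0 the
+         * kernel accumulates partial sums on top of the dst values written by
+         * the previous chunk. */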
+ int icbb = 0;
+ while (icbb < jcp.nb_ic) {
+ int icb_step = jcp.nb_ic_blocking;
+ int icb_step_rem = jcp.nb_ic - icbb;
+ if (icb_step_rem < jcp.nb_ic_blocking_max)
+ icb_step = icb_step_rem;
+
+ size_t n{0}, g{0}, ocbb{0}, oh{0}, od{0};
+ nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
+ od, jcp.od, oh, jcp.oh);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ int ocb = ocbb * jcp.nb_oc_blocking;
+ int ocb_num = jcp.nb_oc_blocking;
+
+ for (int icb = icbb; icb < icbb + icb_step; ++icb) {
+ auto par_conv = jit_conv_call_s();
+
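+                    /* Clip the kernel against the top/bottom and front/back
+                     * padding for this output row and depth. */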
+ const int ij = oh * jcp.stride_h;
+ const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
+ const int i_b_overflow = nstl::max(jcp.ih, ij
+ + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;
+
+ const int dj = od * jcp.stride_d;
+ const int d_t_overflow = nstl::max(0, jcp.f_pad - dj);
+ const int d_b_overflow = nstl::max(jcp.id, dj
+ + (jcp.kd-1) * (jcp.dilate_d+1) - jcp.f_pad+1) - jcp.id;
+
+ const size_t _oc = g * jcp.nb_oc + ocb;
+ const size_t _ic = g * jcp.nb_ic * jcp.nonblk_group_off + icb;
+
+ const int ih = nstl::max(ij - jcp.t_pad
+ + div_up(i_t_overflow,
+ (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);
+
+ const int id = nstl::max(dj - jcp.f_pad
+ + div_up(d_t_overflow,
+ (jcp.dilate_d+1)) * (jcp.dilate_d + 1), 0);
+
+ par_conv.src = &src[src_blk_off(src_d, n,
+ jcp.ic == 3 ? 0 : _ic, id, ih, 0)];
+
+ par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, od, oh, 0)];
+
+ const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
+ const int wd = div_up(d_t_overflow, (jcp.dilate_d + 1));
+ par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
+ jcp.ic == 3 ? 0 : icb, wd, wh, 0)];
+
+ if (icb == 0) {
+ if (bias)
+ par_conv.bias =
+ &bias[bias_d.blk_off(_oc * jcp.oc_block)];
+ par_conv.flags |= FLAG_IC_FIRST;
+ }
+
+ if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) {
+ par_conv.flags |= FLAG_IC_LAST;
+ }
+
+ par_conv.oc_blocks =
+ nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
+
+ par_conv.kw_padding = 0;
+ const int kh_padding = jcp.kh
+ - div_up(i_t_overflow, (jcp.dilate_h + 1))
+ - div_up(i_b_overflow, (jcp.dilate_h + 1));
+ par_conv.kh_padding = nstl::max(0, kh_padding);
+
+ const int kd_padding = jcp.kd
+ - div_up(d_t_overflow, (jcp.dilate_d + 1))
+ - div_up(d_b_overflow, (jcp.dilate_d + 1));
+ par_conv.kd_padding = nstl::max(0, kd_padding);
+
+ kernel_->jit_ker(&par_conv);
+ }
+ nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
+ od, jcp.od, oh, jcp.oh);
+ }
+ icbb += icb_step;
+ }
+ };
+
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad(ctx).get<data_t>(key_conv_padded_bias);
+ utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bias = padded_bias;
+ }
+
+ parallel(0, ker);
+
+ if (pd()->wants_zero_pad_dst())
+ ctx.memory(MKLDNN_ARG_DST)->zero_pad();
+}
+
+void jit_avx2_convolution_bwd_data_t::execute_backward_data(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+
+ int icb_work = jcp.nb_ic / jcp.nb_ic_blocking;
+ int ih_block_size = jcp.ih;
+ int num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
+ size_t work_amount = jcp.mb * jcp.ngroups * icb_work * num_ih_blocks;
+ if (work_amount < (size_t)2 * mkldnn_get_max_threads()) {
+ ih_block_size = 1;
+ num_ih_blocks = utils::div_up(jcp.ih, ih_block_size);
+ work_amount *= num_ih_blocks;
+ }
+
+ auto ker = [&](const int ithr, const int nthr) {
+ size_t start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ size_t n{0}, g{0}, icbb{0}, ihb{0};
+ nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, icbb, icb_work,
+ ihb, num_ih_blocks);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ for (int oc = 0; oc < jcp.nb_oc; oc += jcp.nb_oc_blocking)
+ for (int id = 0; id < jcp.id; ++id) {
+ auto par_conv = jit_conv_call_s();
+
+ const int idp = jcp.id + 2 * jcp.f_pad;
+ const int d_t_overflow = nstl::max(0,
+ jcp.kd - 1 - id - jcp.f_pad);
+ const int back_pad = idp - jcp.id - jcp.f_pad;
+ const int d_b_overflow = nstl::max(0,
+ jcp.kd - 1 - (jcp.id - 1 - id) - back_pad);
+ const int od = id + jcp.f_pad - d_b_overflow;
+
+ int ih_start = ihb * ih_block_size;
+ int ih_end = nstl::min(jcp.ih, ih_start + ih_block_size);
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+
+ const int i_t_overflow = nstl::max(0, (jcp.kh - 1
+ - ih - jcp.t_pad) / jcp.stride_h);
+ const int i_b_overflow = nstl::max(0, (jcp.kh - jcp.ih
+ + ih - jcp.b_pad) / jcp.stride_h);
+ int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
+ + jcp.b_pad - ih) % jcp.stride_h);
+ int overflow_kh_lo = (ih + jcp.t_pad) % jcp.stride_h;
+
+ par_conv.kd_padding = jcp.kd - d_t_overflow - d_b_overflow;
+ par_conv.kh_padding = (overflow_kh_hi - overflow_kh_lo)
+ / jcp.stride_h + 1 - i_t_overflow - i_b_overflow;
+ par_conv.kw_padding = 0;
+
+ const int k_lo = overflow_kh_lo
+ + i_b_overflow * jcp.stride_h;
+ const int oh = (ih + jcp.t_pad - k_lo) / jcp.stride_h;
+
+ par_conv.src = &diff_src[src_blk_off(diff_src_d, n,
+ /*jcp.ic == 3 ? 0 :*/
+ g * jcp.nb_ic + jcp.nb_ic_blocking * icbb, id, ih, 0)];
+ par_conv.dst = &diff_dst[src_blk_off(diff_dst_d,
+ n, g * jcp.nb_oc + oc, od, oh, 0)];
+ par_conv.filt = &weights[wht_blk_off(weights_d, g, oc,
+ jcp.ic == 3 ? 0 : jcp.nb_ic_blocking * icbb,
+ d_b_overflow, k_lo, 0)];
+
+ par_conv.src_prf = nullptr;
+ par_conv.dst_prf = nullptr;
+ par_conv.filt_prf = nullptr;
+ par_conv.channel = oc;
+ par_conv.ch_blocks = nstl::min(jcp.nb_oc - oc,
+ jcp.nb_oc_blocking);
+
+ kernel_->jit_ker(&par_conv);
+ }
+ }
+ nd_iterator_step(n, jcp.mb, g, jcp.ngroups, icbb, icb_work, ihb,
+ num_ih_blocks);
+ }
+ };
+
+ parallel(0, ker);
+}
+
+void jit_avx2_convolution_bwd_weights_t::execute_backward_weights(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ data_t *diff_bias = pd()->wants_padded_bias()
+ ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+
+ auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+ prefix_reducer_bia);
+ auto rb = this->reducer_bias_;
+ rb->init(reducer_bia_scratchpad);
+
+ auto reducer_wei_scratchpad = memory_tracking::grantor_t(scratchpad,
+ prefix_reducer_wei);
+ auto rw = this->reducer_weights_;
+ rw->init(reducer_wei_scratchpad);
+
+ auto ker = [&](int ithr, int nthr) {
+ assert(nthr == rw->balancer().nthr_);
+
+ const int w_job_start = rw->balancer().ithr_job_off(ithr);
+ const int w_njobs = rw->balancer().ithr_njobs(ithr);
+
+ if (w_njobs == 0) return;
+
+ /* reduction dimension */
+ int img_od_start{0}, img_od_end{0}, img{0}, od_s{0};
+ balance211(jcp.mb * jcp.od, rw->balancer().nthr_per_group_,
+ rw->balancer().id_in_group(ithr), img_od_start, img_od_end);
+
+ int img_start = img_od_start, img_end = img_od_end;
+ nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
+ const int img_first = img;
+
+ /* jobs */
+ int g_start{0}, ocb_start{0}, icb_start{0};
+ nd_iterator_init(w_job_start, g_start, jcp.ngroups, ocb_start,
+ jcp.nb_oc, icb_start, jcp.nb_ic);
+
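+        /* Each thread reduces its share of (mb, od) into a local copy of the
+         * weight diffs; rw->reduce() combines the copies afterwards. */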
+ while (img_start < img_end) {
+ int g = g_start, ocb = ocb_start, icb = icb_start;
+
+ const int work_rem = img_end - img_start;
+ const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
+ const int id_s = od_s * jcp.stride_d;
+ const int idp = jcp.id + jcp.f_pad + jcp.back_pad;
+
+ if (id_s < idp - jcp.back_pad - jcp.kd + 1)
+ for (int w_job_loc = 0; w_job_loc < w_njobs; ++w_job_loc) {
+ const size_t _oc = g * jcp.nb_oc + ocb;
+ const size_t _ic = g * jcp.nb_ic + icb;
+
+ /* TODO: put dw <-- 0 in kernel */
+ if (img == img_first)
+ array_set(rw->get_local_ptr(ithr, diff_weights,
+ reducer_wei_scratchpad) +
+ w_job_loc * rw->balancer().job_size_, 0,
+ rw->balancer().job_size_);
+
+ for (int od = od_s; od < od_e; ++od) {
+ const int id = od * jcp.stride_d;
+ if (id >= jcp.id - jcp.back_pad - jcp.kd + 1) break;
+
+ auto par_conv = jit_conv_call_s();
+ par_conv.src = &src[src_blk_off(src_d, img, _ic, id, 0, 0)];
+ par_conv.dst =
+ &diff_dst[src_blk_off(diff_dst_d, img, _oc, od, 0, 0)];
+ par_conv.filt = rw->get_local_ptr(ithr, diff_weights,
+ reducer_wei_scratchpad) +
+ w_job_loc * rw->balancer().job_size_;
+
+ kernel_->jit_ker(&par_conv);
+ }
+ nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc, icb,
+ jcp.nb_ic);
+ }
+ nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
+ }
+ rw->reduce(ithr, diff_weights, reducer_wei_scratchpad);
+ };
+
+ auto ker_bias = [&](int ithr, int nthr) {
+ assert(nthr == rb->balancer().nthr_);
+
+ const int b_job_start = rb->balancer().ithr_job_off(ithr);
+ const int b_njobs = rb->balancer().ithr_njobs(ithr);
+
+ if (b_njobs == 0) return;
+
+ /* reduction dimension */
+ int img_start{0}, img_end{0};
+ balance211(jcp.mb, rb->balancer().nthr_per_group_,
+ rb->balancer().id_in_group(ithr), img_start, img_end);
+
+ /* jobs */
+ int g_start{0}, ocb_start{0};
+ nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start,
+ jcp.nb_oc);
+
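+        /* Accumulate diff_dst over the spatial dims into a thread-local 8-wide
+         * bias buffer; rb->reduce() combines the per-thread copies. */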
+ for (int img = img_start; img < img_end; ++img) {
+ int g = g_start, ocb = ocb_start;
+ for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
+ const size_t _oc = g * jcp.nb_oc + ocb;
+
+ const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
+ data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
+ reducer_bia_scratchpad) +
+ b_job_loc * rb->balancer().job_size_;
+
+ if (img == img_start)
+ for (int o = 0; o < 8; ++o)
+ d_bias[o] = 0.;
+
+ for (int dhw = 0; dhw < jcp.od * jcp.oh * jcp.ow; ++dhw) {
+ PRAGMA_OMP_SIMD()
+ for (int o = 0; o < 8; ++o)
+ d_bias[o] += d_dst[o];
+ d_dst += 8;
+ }
+
+ nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
+ }
+ }
+ rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
+ };
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ ker(ithr, nthr);
+ if (pd()->with_bias())
+ ker_bias(ithr, nthr);
+ });
+
+ /* TODO: put this in ker_bias */
+ if (pd()->wants_padded_bias()) {
+ assert(jcp.ngroups == 1);
+ for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
+ diff_bias_in[oc] = diff_bias[oc];
+ }
+}
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp
new file mode 100644
index 0000000000..bb65bce79c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx2_convolution.hpp
@@ -0,0 +1,302 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX2_CONVOLUTION_HPP
+#define CPU_JIT_AVX2_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_reducer.hpp"
+
+#include "jit_avx2_conv_kernel_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_avx2_convolution_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
+ jit_avx2_convolution_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx2_conv_fwd_kernel_f32::init_conf(jcp_,
+ *desc(), src_md(), weights_md(), dst_md(), *attr());
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx2_conv_fwd_kernel_f32::init_scratchpad(scratchpad, jcp_);
+
+ return status::success;
+ }
+
+ jit_conv_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ const bool flat = IC() < 8;
+ auto src_tag = flat
+ ? utils::pick(ndims() - 3, ncw, nchw, ncdhw)
+ : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto dst_tag =
+ utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto wei_tag = with_groups()
+ ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
+ gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
+ : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
+ OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
+
+ return set_default_formats_common(src_tag, wei_tag, dst_tag);
+ }
+ };
+
+ jit_avx2_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
+ { kernel_ = new jit_avx2_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); }
+ ~jit_avx2_convolution_fwd_t() { delete kernel_; }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx2_conv_fwd_kernel_f32 *kernel_;
+};
+
+struct jit_avx2_convolution_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_data_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_()
+ {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
+ jit_avx2_convolution_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::undef, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx2_conv_bwd_data_kernel_f32::init_conf(
+ jcp_, *desc(), *diff_src_md(), *weights_md(),
+ *diff_dst_md());
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx2_conv_bwd_data_kernel_f32::init_scratchpad(scratchpad,
+ jcp_);
+
+ return status::success;
+ }
+
+ jit_conv_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto wei_tag = with_groups()
+ ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i)
+ : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i);
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ jit_avx2_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd)
+ { kernel_ = new jit_avx2_conv_bwd_data_kernel_f32(pd()->jcp_); }
+ ~jit_avx2_convolution_bwd_data_t() { delete kernel_; }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx2_conv_bwd_data_kernel_f32 *kernel_;
+};
+
+struct jit_avx2_convolution_bwd_weights_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_weights_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", avx2, ""),
+ jit_avx2_convolution_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_weights
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx2_conv_bwd_weights_kernel_f32::init_conf(
+ jcp_, *desc(), *src_md(), *diff_weights_md(),
+ *diff_dst_md());
+ if (status != status::success) return status;
+
+ init_balancers();
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx2_conv_bwd_weights_kernel_f32::init_scratchpad(scratchpad,
+ jcp_);
+
+ auto reducer_bia_scratchpad = memory_tracking::registrar_t(
+ scratchpad, memory_tracking::names::prefix_reducer_bia);
+ reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
+
+ auto reducer_wei_scratchpad = memory_tracking::registrar_t(
+ scratchpad, memory_tracking::names::prefix_reducer_wei);
+ reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);
+
+ return status::success;
+ }
+
+ jit_conv_conf_t jcp_;
+ cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
+ cpu_reducer_t<data_type::f32>::conf_t reducer_wei_conf_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ const bool flat = IC() == 3;
+
+ auto src_tag = flat
+ ? utils::pick(ndims() - 3, ncw, nchw, ncdhw)
+ : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto dst_tag =
+ utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto wei_tag = with_groups()
+ ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
+ gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
+ : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
+ OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
+
+ return set_default_formats_common(src_tag, wei_tag, dst_tag);
+ }
+
+ private:
+ void init_balancers() {
+ const int max_threads = mkldnn_get_max_threads();
+ const size_t max_buffer_size = 1<<21; /* just a heuristic */
+
+            if (with_bias()) {
+ reducer_bia_conf_.init(reduce_balancer_t(max_threads,
+ jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
+ max_buffer_size));
+ }
+
+ reducer_wei_conf_.init(reduce_balancer_t(max_threads,
+ jcp_.kd * jcp_.kh * jcp_.kw
+ * jcp_.ic_block * jcp_.oc_block,
+ jcp_.ngroups * jcp_.nb_ic * jcp_.nb_oc,
+ jcp_.mb * jcp_.od, max_buffer_size));
+ }
+ };
+
+ jit_avx2_convolution_bwd_weights_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , kernel_(nullptr)
+ , reducer_weights_(nullptr)
+ , reducer_bias_(nullptr)
+ {
+ kernel_ = new jit_avx2_conv_bwd_weights_kernel_f32(pd()->jcp_);
+ reducer_bias_ =
+ new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
+ reducer_weights_ =
+ new cpu_reducer_t<data_type::f32>(pd()->reducer_wei_conf_);
+ }
+
+ ~jit_avx2_convolution_bwd_weights_t() {
+ delete kernel_;
+ delete reducer_weights_;
+ delete reducer_bias_;
+ }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_weights(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_weights(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx2_conv_bwd_weights_kernel_f32 *kernel_;
+ cpu_reducer_t<data_type::f32> *reducer_weights_, *reducer_bias_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp
new file mode 100644
index 0000000000..635b83b2bf
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.cpp
@@ -0,0 +1,1255 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <float.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_memory.hpp"
+#include "cpu_barrier.hpp"
+
+#include "jit_uni_1x1_conv_utils.hpp"
+#include "jit_avx512_common_1x1_conv_kernel.hpp"
+
+#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::utils;
+
+using namespace Xbyak;
+
+void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk)
+{
+ mov(aux1_reg_bcast_data, reg_bcast_data);
+ mov(aux_reg_bcast_data, reg_bcast_data);
+
+ mov(aux_reg_output_data, reg_output_data);
+ mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt));
+
+ if (jcp.ver == ver_4fma)
+ {
+ Label bcast_loop;
+ Label bcast_loop_wraparound;
+ Label bcast_loop_out;
+ Label bcast_loop_ur_full;
+
+ cmp(bcast_loop_iter, jcp.ur);
+ jle(bcast_loop_wraparound, T_NEAR);
+
+ L(bcast_loop); {
+ assert(jcp.bcast_block % jcp.ur == 0);
+ int num_substeps = jcp.bcast_block / jcp.ur;
+ assert(num_substeps > 0 && num_substeps < 10);
+ for (int i = 0; i < num_substeps; i++) {
+ reduce_loop(load_loop_blk, jcp.ur, i, false);
+ if (i < num_substeps - 1) {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
+ add(aux_reg_output_data, jcp.bcast_loop_output_substep);
+ }
+ else {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
+ - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
+ add(aux_reg_output_data, jcp.bcast_loop_output_step
+ - (num_substeps - 1) * jcp.bcast_loop_output_substep);
+ }
+ }
+ sub(bcast_loop_iter, jcp.bcast_block);
+ cmp(bcast_loop_iter, jcp.bcast_block);
+ jg(bcast_loop, T_NEAR);
+ }
+
+ L(bcast_loop_wraparound);
+ if (jcp.ur_tail) {
+ je(bcast_loop_ur_full, T_NEAR);
+ reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
+ jmp(bcast_loop_out, T_NEAR);
+ }
+ L(bcast_loop_ur_full);
+ reduce_loop(load_loop_blk, jcp.ur, 0, true);
+ L(bcast_loop_out);
+ }
+ else
+ {
+ Label bcast_loop;
+ Label bcast_loop_tail;
+
+ cmp(bcast_loop_iter, jcp.ur);
+ jl(bcast_loop_tail, T_NEAR);
+
+ L(bcast_loop); {
+ assert(jcp.bcast_block % jcp.ur == 0);
+ int num_substeps = jcp.bcast_block / jcp.ur;
+ assert(num_substeps > 0 && num_substeps < 10);
+ for (int i = 0; i < num_substeps; i++) {
+ reduce_loop(load_loop_blk, jcp.ur, i, false);
+ if (i < num_substeps - 1) {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
+ add(aux_reg_output_data, jcp.bcast_loop_output_substep);
+ }
+ else {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
+ - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
+ add(aux_reg_output_data, jcp.bcast_loop_output_step
+ - (num_substeps - 1) * jcp.bcast_loop_output_substep);
+ }
+ }
+ sub(bcast_loop_iter, jcp.bcast_block);
+ cmp(bcast_loop_iter, jcp.bcast_block);
+ jge(bcast_loop, T_NEAR);
+ }
+
+ L(bcast_loop_tail);
+ if (jcp.ur_tail) {
+ Label bcast_loop_tail_out;
+ cmp(bcast_loop_iter, 0);
+ jz(bcast_loop_tail_out, T_NEAR);
+ reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
+ L(bcast_loop_tail_out);
+ }
+ }
+}
+
+void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk,
+ int ur, int substep, bool wraparound)
+{
+ auto vreg_load = [=](int i_load, int i_fma) {
+ return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
+ + jcp.fma_step * i_load + i_fma);
+ };
+
+ auto vreg_accum = [=](int i_load, int i_ur) {
+ return Zmm(i_ur * load_loop_blk + i_load);
+ };
+
+ auto bias_ptr = [=](int i_load) {
+ return EVEX_compress_addr(reg_bias_data,
+ jcp.typesize_out * jcp.oc_block * i_load);
+ };
+
+ auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
+ assert(i_ur < jcp.ur);
+ assert(i_reduce <= jcp.reduce_loop_unroll);
+ int offt;
+ if (one_of(jcp.prop_kind, forward_training, forward_inference,
+ backward_data)) {
+ assert(jcp.reduce_loop_unroll == jcp.reduce_block);
+ offt = (i_reduce == jcp.reduce_loop_unroll)
+ ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll
+ : i_ur * jcp.reduce_loop_unroll + i_reduce;
+ } else {
+ if (jcp.transpose_src) {
+ const int reduce_group = i_reduce / 4;
+ const int reduce_shift = i_reduce % 4;
+ offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift;
+ }
+ else
+ offt = i_reduce * jcp.ic_block + i_ur;
+ }
+ return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
+ bcast);
+ };
+
+ auto load_ptr = [=](int i_reduce, int i_load) {
+ int offt;
+ int u0 = i_reduce % jcp.reduce_loop_unroll;
+ int u1 = i_reduce / jcp.reduce_loop_unroll;
+ offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;
+ return EVEX_compress_addr(aux_reg_load_data,
+ u1 * jcp.reduce_loop_load_step
+ + jcp.typesize_in * offt);
+ };
+
+ auto output_ptr = [=](int i_load, int i_ur) {
+ if (one_of(jcp.prop_kind, forward_training, forward_inference,
+ backward_data))
+ return EVEX_compress_addr(aux_reg_output_data,
+ (i_load * jcp.bcast_dim + i_ur) * jcp.load_block
+ * jcp.typesize_out);
+ else
+ return ptr[aux_reg_output_data +
+ (i_load
+ ? reg_output_stride * i_load
+ : 0) // TODO: Xbyak should allow 0 scale
+ + jcp.typesize_out * jcp.load_block * i_ur];
+ };
+
+ auto init = [=]() {
+ Label init_done;
+ Label init_zero;
+
+ if (jcp.with_sum) {
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ mic_prefetcht1(output_ptr(i_load, i_ur));
+ }
+ }
+ }
+
+ if (jcp.with_bias
+ && one_of(jcp.prop_kind, forward_training, forward_inference)) {
+ test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
+ jz(init_zero, T_NEAR);
+
+ for (int i_load = 0; i_load < load_loop_blk; i_load++)
+ for (int i_ur = 0; i_ur < ur; ++i_ur)
+ vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load));
+ jmp(init_done, T_NEAR);
+ }
+
+ L(init_zero);
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load)
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ auto r = vreg_accum(i_load, i_ur);
+ vpxord(r, r, r);
+ }
+ L(init_done);
+ };
+
+ auto store = [=]() {
+ Label store_noadd;
+ if (!jcp.with_sum) {
+ test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
+ jnz(store_noadd, T_NEAR);
+ }
+
+ for (int i_ur = 0; i_ur < ur; ++i_ur)
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ auto r = vreg_accum(i_load, i_ur);
+ vaddps(r, r, output_ptr(i_load, i_ur));
+ }
+
+ L(store_noadd);
+ if (jcp.with_eltwise) {
+ Label store_noeltwise;
+ test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
+ jz(store_noeltwise, T_NEAR);
+
+ eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
+
+ L(store_noeltwise);
+ }
+
+ auto store_output = [=](bool output_is_aligned) {
+ for (int i_ur = 0; i_ur < ur; ++i_ur)
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load)
+ if (output_is_aligned && jcp.use_vmovntps)
+ vmovntps(output_ptr(i_load, i_ur),
+ vreg_accum(i_load, i_ur));
+ else
+ vmovups(output_ptr(i_load, i_ur),
+ vreg_accum(i_load, i_ur));
+ };
+
+ Label unaligned_store, end_store;
+ test(aux_reg_output_data, cpu_isa_traits<avx512_common>::vlen - 1);
+ jnz(unaligned_store, T_NEAR);
+ store_output(true);
+ jmp(end_store, T_NEAR);
+ L(unaligned_store); {
+ store_output(false);
+ }
+ L(end_store);
+ };
+
+ auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load,
+ bool last_block, bool wraparound, int reduce_step)
+ {
+ bool pf_ker_l1 = true;
+ bool pf_ker_l2 = wraparound;
+ int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk;
+ int i_op = (i_reduce / reduce_step) * ur * load_loop_blk +
+ i_ur * load_loop_blk + i_load;
+
+ int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0;
+ int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0;
+ int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur;
+
+ int pf_inp_ops = n_ops / 2; // # of operations during which to pf input
+ int pf_inp_trigger;
+ if (jcp.prop_kind == backward_weights)
+ pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block);
+ else
+ pf_inp_trigger = nstl::max(1, pf_inp_ops / ur);
+
+ int n_other_pf =
+ load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1);
+ int n_other_pf_ops = n_ops - pf_inp_ops;
+ int other_pf_trigger
+ = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0;
+
+ if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) {
+ // input prefetches have the highest priority b/c the
+ // first iteration of the kernel block touches all the
+ // cache lines
+ int i_pf = i_op / pf_inp_trigger;
+ auto pf_reg = wraparound && last_block
+ ? reg_bcast_data
+ : (last_block ? aux1_reg_bcast_data
+ : aux_reg_bcast_data);
+ int offt = i_pf;
+ if (jcp.prop_kind == backward_weights) {
+ offt += wraparound && last_block
+ ? 0
+ : (last_block ? jcp.is : jcp.reduce_block);
+ offt *= jcp.bcast_block;
+ } else {
+ offt += wraparound && last_block
+ ? 0
+ : (last_block ? jcp.ur : jcp.bcast_dim);
+ offt *= jcp.reduce_block;
+ }
+ mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
+ } else if (i_op >= pf_inp_ops && n_other_pf) {
+ // remaining prefetches are spread among the rest of the
+ // operations; prefetches for output take priority
+ // TODO: spread L2 prefetches among L1 prefetches
+ i_op -= pf_inp_ops;
+ if (i_op % other_pf_trigger == 0) {
+ int i_pf = i_op / (load_loop_blk * other_pf_trigger);
+ if (i_pf < n_pf_ker_l2) {
+ int offt = (i_pf + (i_load + 1) * jcp.reduce_dim)
+ * jcp.load_block;
+ mic_prefetcht1(ptr[aux_reg_load_data
+ + offt * jcp.typesize_in]);
+ } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) {
+ i_pf -= n_pf_ker_l2;
+ auto pf_reg = last_block ? reg_load_data
+ : aux_reg_load_data;
+ int offt = (i_pf + i_load * jcp.reduce_dim
+ + (last_block
+ ? (wraparound ? jcp.reduce_dim : 0)
+ : jcp.reduce_block))
+ * jcp.load_block;
+ mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
+ } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) {
+ i_pf -= n_pf_ker_l1 + n_pf_ker_l2;
+ int offt = i_pf * jcp.load_block;
+ mic_prefetcht0(ptr[aux_reg_output_data
+ + offt * jcp.typesize_out]);
+ }
+ }
+ }
+ };
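+    // Rough sketch of the interleaving above (illustrative numbers only):
+    // with reduce_loop_unroll = 16, reduce_step = 1, ur = 4 and
+    // load_loop_blk = 2, n_ops = 16 * 4 * 2 = 128 fma positions per block.
+    // The first half (pf_inp_ops = 64) issues an input prefetch every
+    // pf_inp_trigger-th op (max(1, 64 / 4) = 16 on the forward path), while
+    // the second half spreads the weight (L1/L2) and output prefetches at
+    // other_pf_trigger intervals.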
+
+ auto fma_block = [=](bool last_block) {
+ assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
+
+ int reduce_step = jcp.fma_step;
+
+ for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll;
+ i_reduce += reduce_step) {
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+                // If transposed input data is used and the spatial size is
+                // not divisible by the transpose step (4), then for the last
+                // reduce step we should load only the needed load registers
+                // and clear the remaining ones.
+ if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block
+ && i_reduce == jcp.reduce_loop_unroll - reduce_step) {
+ Label load_all;
+ Label load_finish;
+ test(reg_reduce_pos_flag, FLAG_SP_LAST);
+ jz(load_all, T_NEAR);
+
+ const int n_loads = jcp.is % jcp.fma_step;
+ for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
+ if (i_fma < n_loads)
+ vmovups(vreg_load(i_load, i_fma),
+ load_ptr(i_reduce + i_fma, i_load));
+ else
+ vpxord(vreg_load(i_load, i_fma),
+ vreg_load(i_load, i_fma),
+ vreg_load(i_load, i_fma));
+ }
+ jmp(load_finish);
+
+ L(load_all);
+ for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
+ vmovups(vreg_load(i_load, i_fma),
+ load_ptr(i_reduce + i_fma, i_load));
+ }
+ L(load_finish);
+ } else {
+ for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
+ vmovups(vreg_load(i_load, i_fma),
+ load_ptr(i_reduce + i_fma, i_load));
+ }
+ }
+ }
+
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ if (jcp.ver == ver_avx512_core && jcp.expl_bcast
+ && load_loop_blk > 1)
+ vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ if (jcp.ver == ver_4fma)
+ v4fmaddps(vreg_accum(i_load, i_ur),
+ vreg_load(i_load, 0),
+ bcast_ptr(i_reduce, i_ur, false));
+ else if (jcp.ver == ver_avx512_core && jcp.expl_bcast
+ && load_loop_blk > 1)
+ vfmadd231ps(vreg_accum(i_load, i_ur),
+ vreg_load(i_load, 0), vreg_bcast);
+ else
+ vfmadd231ps(vreg_accum(i_load, i_ur),
+ vreg_load(i_load, 0),
+ bcast_ptr(i_reduce, i_ur, true));
+ prefetch_callback(ur, i_reduce, i_ur, i_load,
+ last_block, wraparound, reduce_step);
+ }
+ }
+ }
+ };
+ Label reduce_loop;
+ Label reduce_loop_tail;
+
+ mov(aux_reg_load_data, reg_load_data);
+
+ mov(aux_reg_bcast_data, aux1_reg_bcast_data);
+ init();
+
+ mov(reduce_loop_iter, reg_reduce_loop_work);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jle(reduce_loop_tail, T_NEAR);
+
+ L(reduce_loop); {
+ fma_block(false);
+ add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
+ add(aux_reg_load_data, jcp.reduce_loop_load_step);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jg(reduce_loop, T_NEAR);
+ }
+
+ L(reduce_loop_tail);
+ fma_block(true);
+
+ store();
+}
+
+void jit_avx512_common_1x1_conv_kernel::generate()
+{
+ preamble();
+
+ mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
+ mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
+ mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
+
+ sub(rsp, stack_space_needed);
+
+ if (jcp.with_bias)
+ mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
+
+ mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
+ mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
+ mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
+ mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
+ mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
+ if (one_of(jcp.prop_kind, forward_training, forward_inference))
+ mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha));
+ if (jcp.prop_kind == backward_weights)
+ mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
+
+ auto load_loop_body = [=](int load_loop_blk) {
+ bcast_loop(load_loop_blk);
+ add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
+ switch (jcp.prop_kind) {
+ case forward_training:
+ case forward_inference:
+ add(reg_bias_data,
+ load_loop_blk * jcp.load_block * jcp.typesize_out);
+ add(reg_output_data,
+ load_loop_blk * jcp.bcast_dim * jcp.load_block *
+ jcp.typesize_out);
+ break;
+ case backward_data:
+ add(reg_output_data,
+ load_loop_blk * jcp.bcast_dim * jcp.load_block *
+ jcp.typesize_out);
+ break;
+ case backward_weights:
+ for (int i_load = 0; i_load < load_loop_blk; i_load++)
+ add(reg_output_data, reg_output_stride);
+ break;
+ default:
+ assert(!"invalid prop_kind");
+ }
+ sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+ };
+
+ const int simd_w = 16;
+
+ Label load_loop_blk[7];
+
+ static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 };
+ static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
+ static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 };
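+    // Roughly: entry i of a table is the largest jcp.ur for which
+    // (num_ur_cases - i) load (oc) blocks are unrolled together, so a smaller
+    // ur leaves room for more accumulators (ur * load_loop_blk zmm registers);
+    // the jump table emitted below dispatches on reg_load_loop_work.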
+
+ const int size_ur_cases_fma
+ = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
+ sizeof(ur_cases_fma_expl_bcast) :
+ sizeof(ur_cases_fma_embd_bcast);
+ const int size_ur_cases_4fma = sizeof(ur_cases_4fma);
+
+ const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
+ ur_cases_fma_expl_bcast :
+ ur_cases_fma_embd_bcast;
+ const int *ur_cases = jcp.ver == ver_4fma ? ur_cases_4fma : ur_cases_fma;
+ const int num_ur_cases =
+ (jcp.ver == ver_4fma ? size_ur_cases_4fma : size_ur_cases_fma)
+ / sizeof(*ur_cases);
+
+ for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
+ int label_idx = num_ur_cases - ur_idx - 1;
+ if (jcp.ur <= ur_cases[ur_idx]) {
+ cmp(reg_load_loop_work, simd_w * (label_idx + 1));
+ jle(load_loop_blk[label_idx], T_NEAR);
+ }
+ }
+
+ for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
+ if (jcp.ur <= ur_cases[ur_idx]) {
+ int label_idx = num_ur_cases - ur_idx - 1;
+ L(load_loop_blk[label_idx]);
+ {
+ if (label_idx == 0) {
+ cmp(reg_load_loop_work, 0);
+ je(load_loop_blk[num_ur_cases], T_NEAR);
+ }
+ load_loop_body(label_idx + 1);
+ if (label_idx - 1 > 0) {
+ cmp(reg_load_loop_work, 2 * label_idx * simd_w);
+ je(load_loop_blk[label_idx - 1], T_NEAR);
+ }
+ cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
+ jge(load_loop_blk[label_idx]);
+ }
+ for (int idx = label_idx - 1; idx > 0; --idx) {
+ cmp(reg_load_loop_work, simd_w * (idx + 1));
+ je(load_loop_blk[idx], T_NEAR);
+ }
+ if (ur_idx < num_ur_cases - 2) {
+ cmp(reg_load_loop_work, simd_w);
+ jle(load_loop_blk[0], T_NEAR);
+ }
+ }
+ }
+ L(load_loop_blk[num_ur_cases]);
+
+ add(rsp, stack_space_needed);
+
+ postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+}
+
+bool jit_avx512_common_1x1_conv_kernel::post_ops_ok(
+ jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+ switch (p.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
+ case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
+ default: return false;
+ }
+
+ return false;
+}
+
+status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr, int nthreads, bool reduce_src) {
+ if (!mayiuse(avx512_common)) return status::unimplemented;
+
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+ const int ndims = src_d.ndims();
+
+ jcp.prop_kind = cd.prop_kind;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+
+ jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+
+ bool ok_to_pad_channels = true
+ && jcp.ngroups == 1
+ && src_d.data_type() == data_type::f32;
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+ }
+
+ jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
+ jcp.iw = src_d.dims()[ndims - 1];
+ jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[ndims - 1];
+
+ jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
+ jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][ndims - 3];
+
+ jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
+ jcp.stride_w = cd.strides[ndims - 3];
+
+ jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format_kind,
+ format_kind::undef, cd.diff_bias_desc.format_kind)
+ != format_kind::undef;
+
+ jcp.os = jcp.oh * jcp.ow;
+ jcp.is = jcp.ih * jcp.iw;
+ jcp.tr_is = rnd_up(jcp.is, 4);
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ jcp.with_sum = p.find(primitive_kind::sum) != -1;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise) {
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+ if (dst_d.data_type() == data_type::s32) return status::unimplemented;
+ }
+
+ auto dat_tag = pick(ndims - 3, nCw16c, nChw16c);
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+
+ bool args_ok = true
+ && jcp.ngroups == 1
+ && jcp.src_tag == dat_tag
+ && jcp.dst_tag == dat_tag;
+ if (!args_ok) return status::unimplemented;
+
+ args_ok = true
+ && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
+ && jcp.t_pad == 0 && jcp.l_pad == 0
+ && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
+ && jcp.kh == 1 && jcp.kw == 1;
+ if (!args_ok) return status::unimplemented;
+
+ jcp.ic_block = jcp.oc_block = simd_w;
+ jcp.transpose_src = false;
+
+ if (everyone_is(data_type::f32, src_d.data_type(),
+ weights_d.data_type(), dst_d.data_type()))
+ {
+ const int is_bwd_d = jcp.prop_kind == backward_data;
+ format_tag_t wei_tag = with_groups
+ ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
+ gOIhw16i16o, gIOhw16o16i)
+ : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i,
+ OIhw16i16o, IOhw16o16i);
+
+ jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+ if (jcp.wei_tag != wei_tag)
+ return status::unimplemented;
+
+ if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) &&
+ ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4
+ == 0) {
+ jcp.ver = ver_4fma;
+ jcp.fma_step = 4;
+ } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops)
+ && !reduce_src
+            /* Heuristic condition on the relation of src size to oc. Otherwise
+               the src transposition overhead exceeds the benefit from 4fma
+             */
+ && ((jcp.is * jcp.ic) / jcp.oc <= 2048)
+ && mkldnn_thr_syncable()
+ )
+ {
+ jcp.transpose_src = true;
+ jcp.ver = ver_4fma;
+ jcp.fma_step = 4;
+ } else {
+ jcp.ver = (mayiuse(avx512_core)) ? ver_avx512_core : ver_fma;
+ jcp.fma_step = 1;
+ }
+ jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
+ jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);
+ } else {
+ return status::unimplemented;
+ }
+
+ /* once all the formats are set, check the padding consistency */
+ args_ok = true
+ && jcp.ic <= src_d.padded_dims()[1]
+ && jcp.oc <= dst_d.padded_dims()[1]
+ && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
+ && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
+ if (!args_ok) return status::unimplemented;
+
+ const int SMALL_SPATIAL = 10;
+ const int BIG_SPATIAL = 28;
+ const int BIG_REDUCE_DIM = 1024;
+ const int BIG_LOAD_DIM = 256;
+
+ int load_blocking{ 0 };
+ int load_blocking_max{ 0 };
+ int bcast_blocking{ 0 };
+ int bcast_blocking_max{ 0 };
+ int reduce_blocking{ 0 };
+ int reduce_blocking_max{ 0 };
+
+ jcp.load_grp_count = 1;
+
+ const int L1_capacity = get_cache_size(1, true) / sizeof(float);
+ const int L2_size = get_cache_size(2, true) / sizeof(float);
+ const int L2_capacity = (L2_size * 3) / 4;
+
+ if (one_of(jcp.prop_kind, forward_training, forward_inference,
+ backward_data)) {
+ if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
+ jcp.reduce_dim = jcp.ic;
+ jcp.reduce_block = jcp.ic_block;
+
+ jcp.load_dim = jcp.oc;
+ jcp.load_block = jcp.oc_block;
+
+ jcp.bcast_dim = jcp.is;
+ } else {
+ jcp.reduce_dim = jcp.oc;
+ jcp.reduce_block = jcp.oc_block;
+
+ jcp.load_dim = jcp.ic;
+ jcp.load_block = jcp.ic_block;
+
+ jcp.bcast_dim = jcp.os;
+ }
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
+
+ jcp.reduce_loop_load_step
+ = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
+ jcp.load_loop_load_step
+ = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
+
+        // adjusting register blocking
+ int max_regs, min_regs, size_treshold, ur_step;
+ const int spatial
+ = (one_of(jcp.prop_kind, forward_training, forward_inference)) ?
+ jcp.oh :
+ jcp.ih;
+ if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) {
+ max_regs = 9;
+ min_regs = 6;
+ size_treshold = 14;
+ ur_step = 1;
+ jcp.expl_bcast = true;
+
+ if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM
+ && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) {
+ max_regs = 6;
+ min_regs = 5;
+ }
+ } else {
+ max_regs = jcp.ver == ver_4fma ? 28 : 30;
+ min_regs = 9;
+ size_treshold = jcp.ver == ver_4fma ? 28 : 14;
+ ur_step = jcp.ver == ver_4fma ? 4 : 1;
+ jcp.expl_bcast = false;
+ jcp.use_vmovntps = true;
+ }
+ jcp.ur = 1;
+ for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) {
+ if ((spatial >= size_treshold && spatial % ur_w == 0)
+ || (spatial < size_treshold && jcp.os % ur_w == 0)) {
+ jcp.ur = ur_w;
+ break;
+ }
+ }
+ if (jcp.ur == 1) {
+ jcp.ur = nstl::min(max_regs, jcp.os);
+ int os_tail = jcp.os % max_regs;
+ for (int i = max_regs; i >= min_regs; i -= ur_step) {
+ int i_tail = jcp.os % i;
+ if (i_tail > os_tail || i_tail == 0) {
+ jcp.ur = i;
+ os_tail = i_tail;
+ if (i_tail == 0)
+ break;
+ }
+ }
+ }
+
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
+
+ jcp.bcast_block = jcp.ur;
+
+ jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out;
+ jcp.bcast_loop_output_substep = -1; // unused
+ jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in;
+ jcp.bcast_loop_bcast_substep = -1; // unused
+
+ jcp.load_loop_iter_step = jcp.load_block;
+
+ if (jcp.prop_kind == backward_data)
+ jcp.loop_order = loop_lbr;
+ else
+ jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
+
+ int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+ int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+ int nb_load = div_up(jcp.load_dim, jcp.load_block);
+
+ if (jcp.ver == ver_avx512_core && jcp.expl_bcast) {
+ if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
+ && spatial < BIG_SPATIAL)
+ reduce_blocking = nstl::min(jcp.reduce_dim, 80);
+ else if (spatial > SMALL_SPATIAL)
+ reduce_blocking = nstl::min(jcp.reduce_dim, 512);
+ else
+ reduce_blocking = nstl::min(jcp.reduce_dim, 256);
+
+ if ((jcp.mb > 28 && spatial >= 28)
+ || (jcp.mb > 112 && spatial >= 17))
+ jcp.use_vmovntps = true;
+ else
+ jcp.use_vmovntps = false;
+ } else {
+
+ reduce_blocking = nb_reduce;
+ if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
+ reduce_blocking = 16;
+ else if (spatial > SMALL_SPATIAL
+ && jcp.reduce_dim >= BIG_REDUCE_DIM)
+ reduce_blocking = 8;
+ reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
+ reduce_blocking *= jcp.reduce_block;
+ }
+
+        // Check input data cache aliasing.
+        // For other ISAs these constants may need to be updated.
+        // 64 * 1024 is chosen because of the 1MB, 16-way L2 cache.
+        // 7 is an empirical value: it is about half of 16,
+        // so we leave about half of the set for other data (weights, dst).
+ int way_size = (64 * 1024) / jcp.typesize_in;
+ int max_hits = 7;
+ if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) {
+ int nrb = reduce_blocking / simd_w;
+ int sp = jcp.bcast_dim;
+ int wl = way_size / simd_w;
+ for (int start_off = 0; start_off < jcp.ur; start_off++) {
+ for (int off = start_off, hits = 0; off < sp * nrb; off += wl) {
+ if (off % sp >= jcp.ur || ++hits < max_hits)
+ continue;
+ int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp);
+ reduce_blocking
+ = nstl::min(reduce_blocking, max_r_blocking);
+ break;
+ }
+ }
+ }
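+        // Roughly, for f32 (typesize_in = 4): way_size = 16384 elements, so the
+        // check above triggers once the input tile exceeds ~7 ways' worth of
+        // data; the loop then walks the tile at the way stride and shrinks
+        // reduce_blocking as soon as too many lines map to the same L2 set.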
+
+ if (reduce_blocking < jcp.reduce_dim) {
+ jcp.use_vmovntps = false;
+ if (jcp.prop_kind == backward_data)
+ jcp.loop_order = reduce_src ? loop_lbr : loop_rlb;
+ else
+ jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
+ }
+ load_blocking = jcp.load_dim;
+
+ int load_size = jcp.load_dim * jcp.reduce_dim;
+ int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;
+
+ if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads
+ && nb_load * nb_bcast > nthreads) {
+ // Some heuristic here
+ float calc_koef = 0.01, best_cost = FLT_MAX;
+ int n_lgc = nthreads;
+ float ratio = (float)load_size / (float)bcast_size;
+ int best_lgc = ratio > 1 ? n_lgc : 1;
+ auto calc_job_cost = [&](int lb, int tg, float mem_k) {
+ int bb_size = jcp.mb * div_up(nb_bcast, tg);
+ float calc_size = (float)(bb_size * jcp.ur)
+ * (lb * jcp.load_block) * jcp.reduce_dim;
+ float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
+ * jcp.reduce_dim;
+ return calc_koef * calc_size + mem_k * mem_size;
+ };
+ for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
+ lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1;
+ int min_lb = nb_load / lgc;
+ int max_lb = div_up(nb_load, lgc);
+ int min_tg = nthreads / lgc;
+ int max_tg = div_up(nthreads, lgc);
+ // Some heuristic here
+ float mem_koef = (max_tg == 1) ? 1.f : 1.3f;
+ float job_cost = 0.;
+ if (nthreads % lgc < nb_load % lgc) {
+ job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
+ } else {
+ auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
+ auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
+ job_cost = nstl::max(job_cost1, job_cost2);
+ }
+
+ if (job_cost < best_cost) {
+ best_lgc = lgc;
+ best_cost = job_cost;
+ }
+ }
+ jcp.load_grp_count = best_lgc;
+ load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
+ } else {
+ jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
+ jcp.load_grp_count = best_divider(
+ nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
+ }
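+        // The branch above (avx512_core, at most 28 threads, mb smaller than
+        // the thread count, and more load*bcast blocks than threads) searches
+        // over the number of load groups and keeps the lgc with the smallest
+        // calc_job_cost estimate, which mixes a compute term (calc_koef * work)
+        // with a memory term (mem_koef * data touched); the fallback derives
+        // load_grp_count from the thread count directly.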
+
+ if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64
+ && load_size >= L2_size) {
+ jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
+ } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads
+ && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
+ jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
+ load_blocking = jcp.load_block;
+ }
+
+ if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim
+ && jcp.oh * jcp.ow > 64
+ && IMPLICATION(reduce_src, jcp.load_dim < 1024)) {
+        /* Look for the best load-dimension blocking to get the best thread
+         * and data read/write efficiency, by finding the optimal 'load_chunk'
+         * value.
+         * Example:
+         *   for 72 threads and a convolution with mb = 1, ih = iw = 7, oc = 512,
+         *   the 'best' load_chunk value should be 1.
+         * TODO: remove the heuristic constants in the condition above
+         * TODO: check this blocking for other ISAs
+         */
+ float best_eff = -1.f;
+ int best_lgc = 1;
+
+ for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) {
+ int lgc = div_up(nb_load, load_chunk);
+ if (lgc > nthreads)
+ continue;
+ int thr_per_grp = div_up(nthreads, lgc);
+ int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp)
+ * jcp.bcast_block;
+ int load_per_thr = load_chunk * simd_w;
+ float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
+ float data_eff = (bcast_per_thr * load_per_thr)
+ / (data_norm * data_norm);
+ float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc)
+ / div_up(nthreads, lgc);
+ float thr_eff_in_grp = ((float)jcp.mb * nb_bcast)
+ / rnd_up(jcp.mb * nb_bcast, thr_per_grp);
+ float thr_eff = thr_eff_over_grp * thr_eff_in_grp;
+ float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
+ float overall_eff = data_eff + thr_eff + load_eff;
+ if (overall_eff > best_eff) {
+ best_eff = overall_eff;
+ best_lgc = lgc;
+ }
+ }
+ jcp.load_grp_count = best_lgc;
+ load_blocking
+ = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
+ }
+ bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
+ div_up(nthreads, jcp.load_grp_count))
+ * jcp.bcast_block;
+ bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
+ bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
+
+ int space_for_bcast
+ = (L2_capacity - /* kernel_size - */
+ 2 * jcp.load_block * reduce_blocking
+ - jcp.ur * reduce_blocking - 3 * 1024);
+ if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
+ space_for_bcast /= 2;
+
+ int bcast_in_cache
+ = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
+ bcast_blocking = nstl::min(
+ bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
+
+ load_blocking_max = load_blocking;
+ bcast_blocking_max = bcast_blocking * 3 / 2;
+ reduce_blocking_max = reduce_blocking;
+
+ } else if (jcp.prop_kind == backward_weights) {
+
+ jcp.use_vmovntps = false;
+ if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma)
+ jcp.use_vmovntps = true;
+
+ if (jcp.transpose_src)
+ jcp.reduce_dim = jcp.tr_is;
+ else
+ jcp.reduce_dim = jcp.is;
+
+ if (jcp.ver == ver_4fma) {
+            // reduce_block should be divisible by fma_step
+ jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4);
+ } else {
+ jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true);
+ if (jcp.reduce_dim % jcp.reduce_block != 0)
+ jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false);
+ if (jcp.reduce_block > 256) {
+ jcp.reduce_block = 1;
+ }
+
+ }
+
+ jcp.load_dim = jcp.oc;
+ jcp.load_block = jcp.oc_block;
+
+ jcp.bcast_dim = jcp.ic;
+ jcp.bcast_block = jcp.ic_block;
+
+ if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) {
+            // if reduce_block is big, the generated JIT code may become large
+            // for small values of ur, because reduce_loop_unroll = reduce_block
+ jcp.ur = jcp.bcast_block / 2;
+ jcp.expl_bcast = true;
+ } else {
+ jcp.ur = jcp.bcast_block;
+ jcp.expl_bcast = false;
+ }
+
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in;
+ jcp.reduce_loop_load_step
+ = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in;
+
+ jcp.bcast_loop_output_step =
+ jcp.oc_block * jcp.ic_block * jcp.typesize_out;
+ jcp.bcast_loop_output_substep =
+ jcp.oc_block * jcp.ur * jcp.typesize_out;
+ jcp.bcast_loop_bcast_step =
+ jcp.ic_block * jcp.reduce_dim * jcp.typesize_in;
+ jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in;
+
+ jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in;
+ jcp.load_loop_iter_step = jcp.oc_block;
+
+ /* --- */
+ balance(jcp, nthreads);
+
+ load_blocking = div_up(jcp.load_dim, jcp.load_block);
+ load_blocking = best_divider(load_blocking, 16, load_blocking, false);
+ load_blocking *= jcp.load_block;
+
+ load_blocking_max = load_blocking;
+ assert(jcp.load_dim % load_blocking == 0);
+
+ int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
+ int min_bcast_blocking = 5;
+
+ bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
+ bcast_blocking = best_divider(
+ bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
+ bcast_blocking *= jcp.bcast_block;
+ bcast_blocking_max = bcast_blocking;
+ assert(jcp.bcast_dim % bcast_blocking == 0);
+
+ // for reduction balance
+ if (jcp.ver == ver_avx512_core) {
+ int max_reduce_blocking
+ = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim);
+ int min_reduce_blocking = nstl::min(
+ L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
+ reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking,
+ max_reduce_blocking, true);
+ reduce_blocking
+ = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block),
+ jcp.reduce_block);
+ } else {
+ int max_reduce_blocking = L2_capacity
+ / ((bcast_blocking + load_blocking) * jcp.reduce_block);
+ max_reduce_blocking = nstl::min(max_reduce_blocking,
+ (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block);
+
+ int num_jobs = div_up(jcp.load_dim, load_blocking)
+ * div_up(jcp.bcast_dim, bcast_blocking);
+ int threads_per_job = nstl::max(1, nthreads / num_jobs);
+ reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block);
+ reduce_blocking = div_up(reduce_blocking, threads_per_job);
+
+ reduce_blocking = best_divider(reduce_blocking,
+ max_reduce_blocking - 2, max_reduce_blocking, true);
+ reduce_blocking *= jcp.reduce_block;
+ }
+
+ reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
+ } else
+ return status::unimplemented;
+
+ assert(load_blocking);
+ assert(load_blocking_max);
+ assert(bcast_blocking);
+ assert(bcast_blocking_max);
+ assert(reduce_blocking);
+ assert(reduce_blocking_max);
+ assert(load_blocking % jcp.load_block == 0);
+ assert(reduce_blocking % jcp.reduce_block == 0);
+ assert(load_blocking_max % jcp.load_block == 0);
+ assert(reduce_blocking_max % jcp.reduce_block == 0);
+ if (jcp.ver == ver_4fma) {
+ assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
+ assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
+ }
+
+ assert(jcp.bcast_block % jcp.ur == 0);
+ assert(jcp.reduce_dim % jcp.reduce_block == 0);
+
+ jcp.ur_tail = jcp.bcast_dim % jcp.ur;
+
+ jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
+ jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
+ jcp.nb_load_blocking = load_blocking / jcp.load_block;
+ jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
+ jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
+ jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
+
+ jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+ jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
+ jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+
+ return status::success;
+}
+
+void jit_avx512_common_1x1_conv_kernel::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp) {
+ using namespace mkldnn::impl::memory_tracking::names;
+
+ if (jcp.prop_kind != backward_data && jcp.with_bias
+ && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
+
+ if (jcp.prop_kind == backward_weights) {
+ const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic;
+ scratchpad.book(key_conv_wei_reduction,
+ jcp.typesize_out * wei_size * (jcp.nthr_mb - 1));
+ }
+
+ if (jcp.transpose_src) {
+ const size_t tr_src_size =
+ (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is;
+ scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size);
+ scratchpad.book(key_conv_tr_src_bctx,
+ sizeof(simple_barrier::ctx_t) * jcp.nthr);
+ }
+}
+
+void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp,
+ int nthreads)
+{
+ // initialize jcp reduction threading properties
+ jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
+ if (nthreads < jcp.ngroups) {
+ /* simplification... fortunately it doesn't hurt much */
+ return;
+ }
+ const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+ const int nb_load = div_up(jcp.load_dim, jcp.load_block);
+ const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+
+ jcp.nthr_g = jcp.ngroups;
+ const int nthr = nthreads / jcp.nthr_g;
+
+ auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
+        /* Calculate the per-thread memory cost (read/write); the high-level
+         * optimizer tries to minimize memory consumption. A few notes:
+         * (n1) unclear why, but this essentially helps the first convolution...
+         * (n2) assuming the reduction over the minibatch is always there:
+         *   - instead of 8 the coefficient should be 5 (a write ~= 2 reads):
+         *     kernel: 1 write to the temporal workspace,
+         *     reduction: 1 read from the workspace and 1 write to diff_wei;
+         *   - but experiments showed 8 works better than 5 or 6... */
+ int bcast_koeff = 1;
+ int load_koeff = 1;
+ int output_koeff = 12;
+ if (jcp.transpose_src) {
+ bcast_koeff = 5;
+ load_koeff = 1;
+ output_koeff = 8;
+ }
+ return 0
+ + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
+ * div_up(jcp.ngroups, jcp.nthr_g)
+ * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block
+ / jcp.stride_h / jcp.stride_w /* (n1) */
+ + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
+ * div_up(jcp.ngroups, jcp.nthr_g)
+ * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block
+ + (size_t)output_koeff /* (n2) */
+ * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
+ * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block
+ * jcp.oc_block;
+ };
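+    // Step 1 below enumerates the nthr_mb * nthr_oc_b * nthr_ic_b splits that
+    // fit into nthr and keeps the split with the lowest estimated memory
+    // cost; when the threading runtime is not syncable the minibatch
+    // reduction stays single-threaded (nthr_mb == 1).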
+
+ int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
+ auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+
+ /* step 1: find the best thread distribution with lowest memory cost */
+ const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
+ for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
+ const int nthr_par = nthr / nthr_mb;
+ const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
+ for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
+ nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
+ auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+ if (mem_cost <= best_mem_cost) {
+ best_mem_cost = mem_cost;
+ jcp.nthr_mb = nthr_mb;
+ jcp.nthr_oc_b = nthr_oc_b;
+ jcp.nthr_ic_b = nthr_ic_b;
+ }
+ }
+
+ if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
+ }
+ if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
+ jcp.nthr_mb = nstl::min(jcp.mb, nthreads);
+
+ jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
+ assert(jcp.nthr <= nthreads);
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp
new file mode 100644
index 0000000000..d2ae017943
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_conv_kernel.hpp
@@ -0,0 +1,108 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP
+#define JIT_AVX512_COMMON_1x1_CONV_KERNEL_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_avx512_common_1x1_conv_kernel : public jit_generator {
+ jit_avx512_common_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
+ const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
+ {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
+ this, jcp.eltwise);
+
+ this->generate();
+ jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
+ }
+
+ ~jit_avx512_common_1x1_conv_kernel() {
+ delete eltwise_injector_;
+ }
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_1x1_conv_kernel)
+
+ static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_1x1_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr,
+ int nthreads, bool reduce_src);
+
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp);
+
+ jit_1x1_conv_conf_t jcp;
+ const primitive_attr_t &attr_;
+ void (*jit_ker)(jit_1x1_conv_call_s *);
+
+ private:
+ using reg64_t = const Xbyak::Reg64;
+ using zmm_t = const Xbyak::Zmm;
+
+ reg64_t reg_bcast_data = r8;
+ reg64_t reg_load_data = r10;
+ reg64_t reg_output_data = r9;
+ reg64_t aux_reg_bcast_data = r14;
+ reg64_t aux1_reg_bcast_data = rbx;
+ reg64_t aux_reg_load_data = r15;
+ reg64_t imm_addr64 = aux_reg_load_data;
+ reg64_t aux_reg_output_data = abi_not_param1;
+ reg64_t reg_load_loop_work = rsi;
+ reg64_t reg_reduce_loop_work = r11;
+ reg64_t bcast_loop_iter = rdx;
+ reg64_t reduce_loop_iter = abi_param1;
+ reg64_t reg_reduce_pos_flag = rax;
+ reg64_t reg_output_stride = r13;
+ reg64_t reg_bias_data = r12;
+ reg64_t reg_relu_ns = r13;
+ reg64_t reg_bcast_loop_work = aux1_reg_bcast_data;
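+    // Note the deliberate aliasing above: imm_addr64 reuses
+    // aux_reg_load_data, reg_bcast_loop_work reuses aux1_reg_bcast_data, and
+    // reg_relu_ns shares r13 with reg_output_stride; the paired registers
+    // are presumably never live at the same time (reg_relu_ns is set only on
+    // forward passes, reg_output_stride only for backward_weights).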
+
+ Xbyak::Zmm vreg_bcast = Xbyak::Zmm(31);
+
+ jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
+
+ int bcast_loop_work_offt = 0;
+ int stack_space_needed = 16;
+
+ void bcast_loop(int load_loop_blk);
+ void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
+
+ void generate();
+ static void balance(jit_1x1_conv_conf_t &jcp, int nthreads);
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
new file mode 100644
index 0000000000..54d58c8a39
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.cpp
@@ -0,0 +1,816 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_generator.hpp"
+
+#include "jit_avx512_common_1x1_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+#define data_blk_off(f, n, c, h, w) \
+ ((ndims == 3) \
+ ? (f).blk_off(n, c, w) \
+ : (f).blk_off(n, c, h, w))
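+// data_blk_off picks the 3D (n, c, w) or 4D (n, c, h, w) blk_off overload
+// depending on ndims, so the same code path handles ncw- and nchw-shaped
+// tensors.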
+
+
+namespace {
+template <typename T, typename U>
+void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
+ T nx, T &nx_start, T &nx_end, T nx_divider)
+{
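+    // Splits nthr threads into up to nx_divider groups: nx (the load/oc
+    // dimension here) is divided across groups and ny (the bcast work) among
+    // the threads inside a group. Illustrative case: nthr = 8, nx_divider = 2
+    // gives two groups of four threads; each group takes half of nx and its
+    // four threads share the group's ny range via balance211.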
+ const int grp_count = nstl::min(nx_divider, nthr);
+ const int grp_size_big = nthr / grp_count + 1;
+ const int grp_size_small = nthr / grp_count;
+ const int n_grp_big = nthr % grp_count;
+ const int threads_in_big_groups = n_grp_big * grp_size_big;
+
+ const int ithr_bound_distance = ithr - threads_in_big_groups;
+ T grp, grp_ithr, grp_nthr;
+ if (ithr_bound_distance < 0) { // ithr in first groups
+ grp = ithr / grp_size_big;
+ grp_ithr = ithr % grp_size_big;
+ grp_nthr = grp_size_big;
+ } else { // ithr in last groups
+ grp = n_grp_big + ithr_bound_distance / grp_size_small;
+ grp_ithr = ithr_bound_distance % grp_size_small;
+ grp_nthr = grp_size_small;
+ }
+
+ balance211(nx, grp_count, grp, nx_start, nx_end);
+ balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
+}
+}
+/* convolution forward */
+
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
+execute_forward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ const auto &jcp = kernel_->jcp;
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad.template get<dst_data_t>(
+ key_conv_padded_bias);
+ utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bias = padded_bias;
+ }
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
+ });
+
+ if (pd()->wants_zero_pad_dst())
+ ctx.memory(MKLDNN_ARG_DST)->zero_pad();
+}
+
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void jit_avx512_common_1x1_convolution_fwd_t<src_type, wei_type, dst_type>::
+execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
+ const wei_data_t *weights, const dst_data_t *bias, dst_data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+ auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
+
+ const int ndims = src_d.ndims();
+ const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[ndims - 3];
+ const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][ndims - 3];
+
+ const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
+
+ auto step = [](int default_step, int remaining, int tail_step) {
+ assert(default_step <= tail_step);
+ return remaining < tail_step ? remaining : default_step;
+ };
+
+ auto p = jit_1x1_conv_call_s();
+
+ auto rp = rtus_driver_t<avx512_common>::call_params_t();
+
+ const int nb_oc = jcp.nb_load;
+ const int nb_ic = jcp.nb_reduce;
+ const int nb_ic_blocking = jcp.nb_reduce_blocking;
+ const int os_block = jcp.bcast_block;
+
+ int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
+ balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
+ jcp.nb_load, ocb_start, ocb_end, jcp.load_grp_count);
+
+ auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
+ int &oh, int &ow, int &ih, int &iw)
+ {
+ int osb{0};
+ nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
+ jcp.nb_bcast);
+ bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
+ jcp.nb_bcast_blocking_max);
+ bcast_step = nstl::min(bcast_step, bcast_end - iwork);
+
+ const int os = osb * os_block;
+ oh = os / jcp.ow;
+ ow = os % jcp.ow;
+
+ ih = nstl::max(oh * stride_h - pad_t, 0);
+ iw = nstl::max(ow * stride_w - pad_l, 0);
+ rp.iw_start = iw;
+
+ p.bcast_dim = this_block_size(os, jcp.os,
+ bcast_step * os_block);
+ rp.os = p.bcast_dim;
+ };
+
+ auto init_load = [&](int ocb, int &load_step)
+ {
+ load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
+ jcp.nb_load_blocking_max);
+ p.load_dim = this_block_size(ocb * jcp.oc_block,
+ ocb_end * jcp.oc_block, load_step * jcp.oc_block);
+ };
+
+ auto init_reduce = [&](int icb)
+ {
+ const int nb_ic_blocking_step =
+ nstl::min(icb + nb_ic_blocking, nb_ic) - icb;
+ p.first_last_flag = 0
+ | (icb == 0 ? FLAG_REDUCE_FIRST : 0)
+ | (icb + nb_ic_blocking_step >= nb_ic
+ ? FLAG_REDUCE_LAST : 0);
+
+ p.reduce_dim = this_block_size(icb * jcp.ic_block,
+ jcp.ic, nb_ic_blocking_step * jcp.ic_block);
+ rp.icb = p.reduce_dim / jcp.reduce_block;
+ };
+
+ auto inner_ker = [&](int ocb, int icb, int n, int g, int oh, int ow,
+ int ih, int iw)
+ {
+
+ const int _ocb = g * nb_oc + ocb;
+ const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
+
+ p.output_data = &dst[dst_off];
+ p.bias_data = &bias[_ocb * jcp.oc_block];
+ p.load_data = &weights[pd()->with_groups()
+ ? weights_d.blk_off(g, ocb, icb)
+ : weights_d.blk_off(ocb, icb)];
+
+ const int _icb = g * nb_ic + icb;
+ if (pd()->rtus_.reduce_src_) {
+ rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
+ + _icb * jcp.is * jcp.ic_block;
+ if (ocb == ocb_start) {
+ rp.src = src + data_blk_off(src_d, n, _icb, ih, iw);
+ rtus_driver_->ker_(&rp);
+ }
+ p.bcast_data = rp.ws;
+ } else
+ p.bcast_data = src + data_blk_off(src_d, n, _icb, ih, iw);
+
+ kernel_->jit_ker(&p);
+ };
+
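+    // loop_order names the nesting from outermost to innermost:
+    // r = reduce (ic) blocks, l = load (oc) blocks, b = bcast (mb/spatial)
+    // work; e.g. loop_rlb below iterates reduce blocks outermost and the
+    // bcast work innermost.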
+ if (jcp.loop_order == loop_rlb) {
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ init_reduce(icb);
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
+ iwork += bcast_step;
+ }
+ ocb += load_step;
+ }
+ }
+ } else if (jcp.loop_order == loop_lbr) {
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ init_reduce(icb);
+ inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
+ }
+ iwork += bcast_step;
+ }
+ ocb += load_step;
+ }
+ } else if (jcp.loop_order == loop_rbl) {
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ init_reduce(icb);
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
+ ocb += load_step;
+ }
+ iwork += bcast_step;
+ }
+ }
+ } else if (jcp.loop_order == loop_blr) {
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ init_reduce(icb);
+ inner_ker(ocb, icb, n, g, oh, ow, ih, iw);
+ }
+ ocb += load_step;
+ }
+ iwork += bcast_step;
+ }
+ } else {
+ assert(!"unsupported loop order");
+ }
+}
+
+
+template struct jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
+/* convolution backward w.r.t. data */
+
+template <data_type_t diff_dst_type, data_type_t wei_type,
+ data_type_t diff_src_type>
+void jit_avx512_common_1x1_convolution_bwd_data_t<diff_dst_type, wei_type,
+ diff_src_type>::execute_backward_data(const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+
+ const auto &jcp = kernel_->jcp;
+ auto rtus_space = scratchpad(ctx).template get<diff_src_data_t>(
+ key_conv_rtus_space);
+
+ const int ndims = diff_src_d.ndims();
+
+ // TODO (Roma): remove this restriction
+ assert(jcp.stride_w == 1 && jcp.stride_h == 1);
+
+ const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[ndims - 3];
+ const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][ndims - 3];
+
+ const int nb_ic = jcp.nb_load;
+ const int nb_oc = jcp.nb_reduce;
+ const int os_block = jcp.bcast_block;
+ const int nb_oc_blocking = jcp.nb_reduce_blocking;
+
+ const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
+
+ auto step = [](int default_step, int remaining, int tail_step) {
+ assert(default_step <= tail_step);
+ return remaining < tail_step ? remaining : default_step;
+ };
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ auto p = jit_1x1_conv_call_s();
+ auto rp = rtus_driver_t<avx512_common>::call_params_t();
+
+ int bcast_start{0}, bcast_end{0}, icb_start{0}, icb_end{0};
+ balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
+ jcp.nb_load, icb_start, icb_end, jcp.load_grp_count);
+
+ bool reduce_outer = (jcp.loop_order == loop_rbl
+ || jcp.loop_order == loop_rlb);
+ int nboc_outer = reduce_outer ? nb_oc : 1;
+ int ocb_outer_step = reduce_outer ? nb_oc_blocking : 1;
+
+ int nboc_inner = reduce_outer ? 1 : nb_oc;
+ int ocb_inner_step = reduce_outer ? 1 : nb_oc_blocking;
+
+ for (int ocb_outer = 0; ocb_outer < nboc_outer;
+ ocb_outer += ocb_outer_step) {
+ size_t cur_ocb_outer =
+ nstl::min(ocb_outer + ocb_outer_step, nboc_outer) - ocb_outer;
+
+ int load_step = 0;
+ for (int icb = icb_start; icb < icb_end; icb += load_step) {
+ load_step = step(jcp.nb_load_blocking, jcp.nb_load - icb,
+ jcp.nb_load_blocking_max);
+
+ p.load_dim = this_block_size(icb * jcp.ic_block,
+ icb_end * jcp.ic_block, load_step * jcp.ic_block);
+ rp.icb = p.load_dim / jcp.ic_block;
+
+ int bcast_step;
+ for (int iwork = bcast_start; iwork < bcast_end;
+ iwork += bcast_step)
+ {
+ int n{0}, g{0}, osb{0};
+ nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
+ jcp.nb_bcast);
+
+ bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
+ jcp.nb_bcast_blocking_max);
+ bcast_step = nstl::min(bcast_step, bcast_end - iwork);
+
+ const int os = osb * os_block;
+ p.bcast_dim = this_block_size(os, jcp.os,
+ bcast_step * os_block);
+ rp.os = p.bcast_dim;
+
+ const int oh = os / jcp.ow;
+ const int ow = os % jcp.ow;
+ const int ih = nstl::max(oh * stride_h - pad_t, 0);
+ const int iw = nstl::max(ow * stride_w - pad_l, 0);
+ rp.iw_start = iw;
+
+ const int _icb = g * nb_ic + icb;
+ rp.src = diff_src + data_blk_off(diff_src_d, n, _icb, ih, iw);
+ if (pd()->rtus_.reduce_src_) {
+ rp.ws = rtus_space
+ + ithr * pd()->rtus_.space_per_thread_;
+ p.output_data = rp.ws;
+ } else
+ p.output_data = rp.src;
+
+ for (int ocb_inner = 0; ocb_inner < nboc_inner;
+ ocb_inner += ocb_inner_step) {
+ int cur_ocb_inner =
+ nstl::min(ocb_inner + ocb_inner_step, nboc_inner) -
+ ocb_inner;
+
+ int ocb = reduce_outer ? ocb_outer : ocb_inner;
+ int nb_oc_blocking_step = reduce_outer
+ ? cur_ocb_outer : cur_ocb_inner;
+ const int _ocb = g * nb_oc + ocb;
+ size_t diff_dst_off = data_blk_off(diff_dst_d, n, _ocb, oh, ow);
+ p.bcast_data = &diff_dst[diff_dst_off];
+
+ p.load_data = &weights[pd()->with_groups()
+ ? weights_d.blk_off(g, ocb, icb)
+ : weights_d.blk_off(ocb, icb)];
+
+ p.first_last_flag = ocb == 0 ? FLAG_REDUCE_FIRST : 0;
+
+ p.reduce_dim = this_block_size(ocb * jcp.oc_block,
+ jcp.oc, nb_oc_blocking_step * jcp.oc_block);
+
+ kernel_->jit_ker(&p);
+ }
+ if (pd()->rtus_.reduce_src_)
+ rtus_driver_->ker_(&rp);
+ }
+ }
+ }
+ });
+}
+
+template struct jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
+
+/* convolution backward w.r.t. weights */
+
+#define wht_blk_off(d, g, ...) \
+ (pd()->with_groups() \
+ ? (d).blk_off((g), __VA_ARGS__) \
+ : (d).blk_off(__VA_ARGS__))
+
+jit_avx512_common_1x1_convolution_bwd_weights_t::
+ jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
+ , trans_kernel_(nullptr), rtus_driver_(nullptr)
+{
+ kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
+ acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
+ reducer_bias_ = new cpu_reducer_t<data_type::f32>(pd()->reducer_bia_conf_);
+ init_rtus_driver<avx512_common>(this);
+
+ const auto &jcp = kernel_->jcp;
+
+ if (jcp.transpose_src) {
+ auto tp = jit_transpose4x16_src_t();
+ tp.src_pf0_distance = 4;
+ tp.tr_src_pf0_distance = 0;
+ tp.src_pf1 = true;
+ tp.tr_src_pf1 = false;
+ trans_kernel_ = new jit_transpose4x16_src(&jcp, &tp);
+ }
+}
+
+void jit_avx512_common_1x1_convolution_bwd_weights_t::execute_backward_weights(
+ const exec_ctx_t &ctx) const
+{
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto diff_bias_in = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+
+ const auto scratchpad = this->scratchpad(ctx);
+
+ auto rtus_space = scratchpad.get<data_t>(key_conv_rtus_space);
+ data_t *diff_bias = pd()->wants_padded_bias()
+ ? scratchpad.get<data_t>(key_conv_padded_bias) : diff_bias_in;
+ auto wei_reduction = scratchpad.get<data_t>(key_conv_wei_reduction);
+
+ /* prepare src transposition barriers */
+ auto tr_src = scratchpad.get<data_t>(key_conv_tr_src);
+ auto tr_src_bctx = scratchpad.get<simple_barrier::ctx_t>(
+ key_conv_tr_src_bctx);
+ if (jcp.transpose_src) {
+ for (int i = 0; i < jcp.nthr; ++i)
+ simple_barrier::ctx_init(&tr_src_bctx[i]);
+ }
+
+ const int ndims = src_d.ndims();
+ const int wei_size = jcp.ngroups * jcp.oc * jcp.ic;
+
+ simple_barrier::ctx_t reduction_barrier;
+ simple_barrier::ctx_init(&reduction_barrier);
+
+ const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+ prefix_reducer_bia);
+ auto rb = this->reducer_bias_;
+ rb->init(reducer_bia_scratchpad);
+
+ // TODO (Roma): remove this restriction
+ assert(jcp.stride_w == 1 && jcp.stride_h == 1);
+
+ const int nb_ic = jcp.nb_bcast;
+ const int nb_ic_blocking = jcp.nb_bcast_blocking;
+
+ const int nb_oc = jcp.nb_load;
+ const int nb_oc_blocking = jcp.nb_load_blocking;
+
+ const int sp_nb = jcp.nb_reduce;
+ const int mb_sp_work = jcp.mb * sp_nb;
+
+ const int stride_h = (ndims == 3) ? 1 : pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[ndims - 3];
+ const int pad_t = (ndims == 3) ? 0 : pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][ndims - 3];
+
+ auto step = [](int default_step, int remaining, int tail_step) {
+ assert(default_step <= tail_step);
+ return remaining < tail_step ? remaining : default_step;
+ };
+
+ // TODO: use memory descriptor with the same fmt as src
+ // (or use a macro :))
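+    // Transposed-src scratch layout: one region per minibatch image, holding
+    // ngroups * nb_ic channel blocks of jcp.tr_is * jcp.ic_block elements each.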
+ auto tr_src_off = [&](int img, int icb, int is) {
+ const size_t tr_chn_size = jcp.tr_is * jcp.ic_block;
+ const size_t tr_img_size = tr_chn_size * nb_ic * jcp.ngroups;
+ return img * tr_img_size + icb * tr_chn_size + is * jcp.ic_block;
+ };
+
+ auto uker_trans = [&](int ithr_mb, int img, int sp_b_start, int sp_size,
+ int g_start, int g_work, int ic_b_start, int ic_b_work,
+ int ithr, int nthr, int first_ic_b)
+ {
+ const int work_amount = g_work * ic_b_work;
+
+ int start{ 0 }, end{ 0 };
+ balance211(work_amount, nthr, ithr, start, end);
+
+ int g{ 0 }, ic_b{ 0 };
+ nd_iterator_init(start, g, g_work, ic_b, ic_b_work);
+ g += g_start;
+ const int ic_b_tr = g * nb_ic + first_ic_b + ic_b;
+ ic_b += ic_b_start;
+
+ const int _ic = g * nb_ic + ic_b;
+
+ const int is = sp_b_start * jcp.reduce_block;
+ const int ih = is / jcp.iw;
+ const int iw = is % jcp.iw;
+
+ const int src1_off = data_blk_off(src_d, img, _ic, ih, iw);
+ data_t *src1 = (data_t *)&src[src1_off];
+ data_t *tr_src1 = &tr_src[tr_src_off(ithr_mb, ic_b_tr, is)];
+
+ assert(jcp.ic_block == 16);
+ const int src_stride = jcp.is * jcp.ic_block;
+ const int tr_src_stride = jcp.tr_is * jcp.ic_block;
+
+ const int my_work = end - start;
+ for (int iwork = 0; iwork < my_work; iwork++) {
+ auto par_trans = jit_src_transpose_s();
+ assert(sp_size % 4 == 0 || sp_size % 4 == jcp.is % 4);
+ par_trans.size = sp_size;
+ par_trans.src = src1;
+ par_trans.tr_src = tr_src1;
+ par_trans.src_prf = src1 + 64 * 16;
+ par_trans.tr_src_prf = tr_src1 + 80 * 16;
+ trans_kernel_->jit_ker(&par_trans);
+
+ src1 += src_stride;
+ tr_src1 += tr_src_stride;
+ }
+ };
+
+ auto ker = [&](const int ithr, const int nthr) {
+ assert(nthr == jcp.nthr);
+ assert(IMPLICATION(!mkldnn_thr_syncable(), jcp.nthr_mb == 1));
+
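+        // Decompose the flat thread id into a 4D grid (mb, group, oc block,
+        // ic block); the ic-block coordinate varies fastest.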
+ const int ithr_ic_b = ithr % jcp.nthr_ic_b;
+ const int ithr_oc_b = ithr / jcp.nthr_ic_b % jcp.nthr_oc_b;
+ const int ithr_g = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b % jcp.nthr_g;
+ const int ithr_mb = ithr / jcp.nthr_ic_b / jcp.nthr_oc_b /
+ jcp.nthr_g;
+
+ const int ithr_but_oc
+ = (ithr_mb * jcp.nthr_g + ithr_g) * jcp.nthr_ic_b + ithr_ic_b;
+
+ /* reduction dimension */
+ int mb_sp_b_start{ 0 }, mb_sp_b_end{ 0 };
+ if (jcp.transpose_src && jcp.nthr_mb < jcp.mb / 2) {
+ // it's preferable to parallelize by mb if possible
+ int img_start{ 0 }, img_end{ 0 };
+ balance211(jcp.mb, jcp.nthr_mb, ithr_mb, img_start, img_end);
+ mb_sp_b_start = img_start * sp_nb;
+ mb_sp_b_end = img_end * sp_nb;
+ }
+ else {
+ balance211(mb_sp_work, jcp.nthr_mb, ithr_mb, mb_sp_b_start,
+ mb_sp_b_end);
+ }
+
+ /* independent dimensions */
+ int g_start{ 0 }, oc_b_start{ 0 }, ic_b_start{ 0 };
+ int g_end{ 0 }, oc_b_end{ 0 }, ic_b_end{ 0 };
+
+ balance211(jcp.ngroups, jcp.nthr_g, ithr_g, g_start, g_end);
+ balance211(jcp.nb_load, jcp.nthr_oc_b, ithr_oc_b, oc_b_start,
+ oc_b_end);
+ balance211(jcp.nb_bcast, jcp.nthr_ic_b, ithr_ic_b, ic_b_start,
+ ic_b_end);
+
+ const int g_work = g_end - g_start;
+ const int oc_b_work = oc_b_end - oc_b_start;
+ const int ic_b_work = ic_b_end - ic_b_start;
+
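+        // The first thread along mb accumulates directly into diff_weights;
+        // the other mb threads use private buffers that are reduced below.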
+ data_t *diff_wei = ithr_mb == 0
+ ? diff_weights : wei_reduction + (ithr_mb - 1) * wei_size;
+
+ int sp_b_step = 0;
+ for (int mb_sp_b = mb_sp_b_start; mb_sp_b < mb_sp_b_end;
+ mb_sp_b += sp_b_step) {
+ int img{ 0 }, sp_b{ 0 };
+ nd_iterator_init(mb_sp_b, img, jcp.mb, sp_b, sp_nb);
+ sp_b_step = step(jcp.nb_reduce_blocking,
+ nstl::min(sp_nb - sp_b, mb_sp_b_end - mb_sp_b),
+ jcp.nb_reduce_blocking_max);
+
+ for (int g = g_start; g < g_end; ++g) {
+ int load_step = 0;
+ int bcast_step = 0;
+ for (int ic_b = ic_b_start; ic_b < ic_b_end;
+ ic_b += bcast_step) {
+ bcast_step = step(nb_ic_blocking, ic_b_end - ic_b,
+ jcp.nb_bcast_blocking_max);
+ if (jcp.transpose_src) {
+ if (jcp.nthr_oc_b > 1)
+ simple_barrier::barrier(
+ &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
+ const int sp_size
+ = nstl::min(sp_b_step * jcp.reduce_block,
+ jcp.is - sp_b * jcp.reduce_block);
+ uker_trans(ithr_mb, img, sp_b, sp_size, g, 1, ic_b,
+ bcast_step, ithr_oc_b, jcp.nthr_oc_b, ic_b_start);
+ if (jcp.nthr_oc_b > 1)
+ simple_barrier::barrier(
+ &tr_src_bctx[ithr_but_oc], jcp.nthr_oc_b);
+ }
+
+ for (int oc_b = oc_b_start; oc_b < oc_b_end;
+ oc_b += load_step) {
+ load_step = step(nb_oc_blocking, oc_b_end - oc_b,
+ jcp.nb_load_blocking_max);
+ const int _ic_b = g * nb_ic + ic_b;
+ const int _ic_b_tr = g * nb_ic + ic_b_start;
+ const int _oc_b = g * nb_oc + oc_b;
+
+ data_t *store_to;
+
+ const size_t off
+ = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
+ store_to = diff_wei + off;
+
+ const data_t *diff_src = jcp.transpose_src ?
+ &tr_src[tr_src_off(ithr_mb, _ic_b_tr, 0)] :
+ &src[src_d.blk_off(img, _ic_b)];
+
+ int sp_b_end = sp_b + sp_b_step;
+ const data_t *pdiff_dst
+ = &diff_dst[diff_dst_d.blk_off(img, _oc_b)];
+ const data_t *local_src = diff_src;
+
+ auto p = jit_1x1_conv_call_s();
+ auto rp = rtus_driver_t<avx512_common>::call_params_t();
+
+ p.output_stride
+ = jcp.ic * jcp.oc_block * jcp.typesize_out;
+
+ p.load_dim = load_step * jcp.oc_block;
+
+ p.bcast_dim = bcast_step * jcp.ic_block;
+ rp.icb = bcast_step;
+ p.output_data = store_to;
+
+ p.reduce_dim = sp_b_step * jcp.reduce_block;
+ rp.os = p.reduce_dim;
+
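+                        // Flag the first reduction step and the last spatial
+                        // block of this thread's range for the kernel
+                        // (FLAG_REDUCE_FIRST / FLAG_SP_LAST).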
+ p.first_last_flag = 0
+ | (mb_sp_b == mb_sp_b_start ? FLAG_REDUCE_FIRST : 0)
+ | (sp_b_end == sp_nb ? FLAG_SP_LAST : 0);
+
+ int sp = sp_b * jcp.reduce_block;
+ p.load_data = pdiff_dst + sp * jcp.oc_block;
+
+ if (pd()->rtus_.reduce_src_) {
+ const int oh = sp / jcp.ow;
+ const int ow = sp % jcp.ow;
+
+ const int ih = nstl::max(oh * stride_h - pad_t, 0);
+ const int iw = nstl::max(ow * stride_w - pad_l, 0);
+ rp.iw_start = iw;
+
+ rp.ws = rtus_space
+ + ithr * pd()->rtus_.space_per_thread_
+ + sp * jcp.ic_block;
+
+ if (ndims == 3)
+ rp.src = local_src + iw
+ * src_d.blocking_desc().strides[2];
+ else
+ rp.src = local_src + ih
+ * src_d.blocking_desc().strides[2]
+ + iw * src_d.blocking_desc().strides[3];
+ rtus_driver_->ker_(&rp);
+
+ p.bcast_data = rp.ws;
+ } else
+ p.bcast_data = local_src + sp * jcp.ic_block;
+
+ kernel_->jit_ker(&p);
+ }
+ }
+ }
+ }
+
+ /* diff_weights[:] += sum(wei_reduction[thr_mb][:]) */
+ if (jcp.nthr_mb > 1) {
+ simple_barrier::barrier(&reduction_barrier, jcp.nthr);
+ const int work = g_work * oc_b_work * ic_b_work;
+ int start{ 0 }, end{ 0 };
+ balance211(work, jcp.nthr_mb, ithr_mb, start, end);
+ if (start == end)
+ return;
+
+ for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
+ int w = start;
+ int sub_g_start{ 0 }, sub_oc_b_start{ 0 },
+ sub_ic_b_start{ 0 };
+ nd_iterator_init(w, sub_g_start, g_work, sub_oc_b_start,
+ oc_b_work, sub_ic_b_start, ic_b_work);
+ while (w < end) {
+ const int g = g_start + sub_g_start;
+ const int oc_b = oc_b_start + sub_oc_b_start;
+ const int ic_b = ic_b_start + sub_ic_b_start;
+
+ const int acc_size
+ = nstl::min(end - w, ic_b_work - sub_ic_b_start)
+ * jcp.ic_block * jcp.oc_block;
+
+ const size_t off
+ = wht_blk_off(diff_weights_d, g, oc_b, ic_b);
+ data_t *d = diff_weights + off;
+ data_t *s = wei_reduction + (thr_mb - 1) * wei_size + off;
+
+ acc_ker_->accumulate(d, s, acc_size);
+
+ nd_iterator_jump(w, end, sub_g_start, g_work,
+ sub_oc_b_start, oc_b_work, sub_ic_b_start,
+ ic_b_work);
+ }
+ }
+ }
+ };
+
+ auto ker_bias = [&](int ithr, int nthr) {
+ assert(nthr == rb->balancer().nthr_);
+
+ const int b_job_start = rb->balancer().ithr_job_off(ithr);
+ const int b_njobs = rb->balancer().ithr_njobs(ithr);
+
+ if (b_njobs == 0)
+ return;
+
+ /* reduction dimension */
+ int img_start{ 0 }, img_end{ 0 };
+
+ balance211(jcp.mb, rb->balancer().nthr_per_group_,
+ rb->balancer().id_in_group(ithr), img_start, img_end);
+
+ /* jobs */
+ int g_start{ 0 }, ocb_start{ 0 };
+ nd_iterator_init(
+ b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_load);
+
+ for (int img = img_start; img < img_end; ++img) {
+ int g = g_start, ocb = ocb_start;
+ for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
+ const size_t _oc = g * jcp.nb_load + ocb;
+
+ const data_t *d_dst = &diff_dst[diff_dst_d.blk_off(img, _oc)];
+ data_t *d_bias = rb->get_local_ptr(ithr, diff_bias,
+ reducer_bia_scratchpad)
+ + b_job_loc * rb->balancer().job_size_;
+
+ if (img == img_start)
+ for (int o = 0; o < 16; ++o)
+ d_bias[o] = 0.;
+
+ for (int hw = 0; hw < jcp.oh * jcp.ow; ++hw) {
+ PRAGMA_OMP_SIMD()
+ for (int o = 0; o < 16; ++o)
+ d_bias[o] += d_dst[o];
+ d_dst += 16;
+ }
+
+ nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_load);
+ }
+ }
+ rb->reduce(ithr, diff_bias, reducer_bia_scratchpad);
+ };
+
+ parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+ ker(ithr, jcp.nthr);
+ if (pd()->with_bias())
+ ker_bias(ithr, jcp.nthr);
+ });
+
+ /* TODO: put this in ker_bias */
+ if (pd()->wants_padded_bias()) {
+ assert(jcp.ngroups == 1);
+ utils::array_copy(diff_bias_in, diff_bias, jcp.oc_without_padding);
+ }
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp
new file mode 100644
index 0000000000..2e9fda76d6
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_1x1_convolution.hpp
@@ -0,0 +1,344 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
+#define CPU_JIT_AVX512_COMMON_1x1_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+#include "cpu_reducer.hpp"
+
+#include "jit_avx512_common_1x1_conv_kernel.hpp"
+#include "jit_uni_1x1_conv_utils.hpp"
+#include "jit_transpose_src_utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t src_type,
+ impl::data_type_t wei_type = src_type,
+ impl::data_type_t dst_type = src_type>
+struct jit_avx512_common_1x1_convolution_fwd_t : public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_(), rtus_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
+ jit_avx512_common_1x1_convolution_fwd_t);
+
+ status_t init() {
+ using namespace utils;
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(src_type, wei_type, dst_type, dst_type,
+ data_type::undef)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
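+            // rtus ("reduce to unit stride"): for 1x1 convolutions with
+            // stride > 1 the relevant source elements are packed into a dense
+            // per-thread buffer so the kernel can treat the input as
+            // unit-stride (see jit_uni_1x1_conv_utils.hpp).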
+ const convolution_desc_t *conv_d = desc();
+ const memory_desc_t *src_d = src_md();
+ rtus_prepare(this, conv_d, src_d, dst_md());
+
+ status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
+ jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr(),
+ mkldnn_get_max_threads(), rtus_.reduce_src_);
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
+ jcp_);
+
+ rtus_prepare_space_info(this, scratchpad);
+
+ return status::success;
+ }
+
+ jit_1x1_conv_conf_t jcp_;
+ reduce_to_unit_stride_t rtus_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
+ auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
+ OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o);
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ template <cpu_isa_t isa, typename conv_t>
+ friend void init_rtus_driver(conv_t *self);
+
+ jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , kernel_(nullptr), rtus_driver_(nullptr)
+ {
+ kernel_ =
+ new jit_avx512_common_1x1_conv_kernel(pd()->jcp_, *pd()->attr());
+ init_rtus_driver<avx512_common>(this);
+ }
+
+ ~jit_avx512_common_1x1_convolution_fwd_t() {
+ delete kernel_;
+ delete rtus_driver_;
+ }
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<wei_type>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+ private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ void execute_forward_thr(const int ithr, const int nthr,
+ const src_data_t *src, const wei_data_t *weights,
+ const dst_data_t *bias, dst_data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx512_common_1x1_conv_kernel *kernel_;
+ rtus_driver_t<avx512_common> *rtus_driver_;
+};
+
+using jit_avx512_common_1x1_convolution_fwd_f32_t
+ = jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;
+
+template <impl::data_type_t diff_dst_type,
+ impl::data_type_t wei_type = diff_dst_type,
+ impl::data_type_t diff_src_type = diff_dst_type>
+struct jit_avx512_common_1x1_convolution_bwd_data_t : public cpu_primitive_t {
+ struct pd_t : public cpu_convolution_bwd_data_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_(), rtus_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
+ jit_avx512_common_1x1_convolution_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(diff_src_type, wei_type, data_type::undef,
+ diff_dst_type, data_type::undef)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ const convolution_desc_t *conv_d = desc();
+ const memory_desc_t *diff_src_d = diff_src_md();
+ rtus_prepare(this, conv_d, diff_src_d, diff_dst_md());
+
+ status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
+ jcp_, *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
+ *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_);
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
+ jcp_);
+
+ rtus_prepare_space_info(this, scratchpad);
+
+ return status::success;
+ }
+
+ // TODO (Roma): structs conf header cleanup
+ jit_1x1_conv_conf_t jcp_;
+ reduce_to_unit_stride_t rtus_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
+ auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
+ IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i);
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ template <cpu_isa_t isa, typename conv_t>
+ friend void init_rtus_driver(conv_t *self);
+
+ jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , kernel_(nullptr), rtus_driver_(nullptr)
+ {
+ kernel_ = new jit_avx512_common_1x1_conv_kernel(pd()->jcp_,
+ *pd()->attr());
+ init_rtus_driver<avx512_common>(this);
+ }
+
+ ~jit_avx512_common_1x1_convolution_bwd_data_t() {
+ delete kernel_;
+ delete rtus_driver_;
+ }
+
+ typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
+ typedef typename prec_traits<wei_type>::type wei_data_t;
+ typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+ private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx512_common_1x1_conv_kernel *kernel_;
+ rtus_driver_t<avx512_common> *rtus_driver_;
+};
+
+using jit_avx512_common_1x1_convolution_bwd_data_f32_t
+ = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;
+
+struct jit_avx512_common_1x1_convolution_bwd_weights_t : public cpu_primitive_t
+{
+ struct pd_t : public cpu_convolution_bwd_weights_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_(), rtus_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
+ jit_avx512_common_1x1_convolution_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_weights
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ const convolution_desc_t *conv_d = desc();
+ const memory_desc_t *src_d = src_md();
+ rtus_prepare(this, conv_d, src_d, diff_dst_md());
+
+ status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(
+ jcp_, *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(),
+ *attr(), mkldnn_get_max_threads(), rtus_.reduce_src_);
+ if (status != status::success) return status;
+
+ init_balancers();
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx512_common_1x1_conv_kernel::init_scratchpad(scratchpad,
+ jcp_);
+
+ auto reducer_bia_scratchpad = memory_tracking::registrar_t(
+ scratchpad, memory_tracking::names::prefix_reducer_bia);
+ reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
+
+ rtus_prepare_space_info(this, scratchpad);
+
+ return status::success;
+ }
+
+ // TODO (Roma): structs conf header cleanup
+ jit_1x1_conv_conf_t jcp_;
+ cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
+ reduce_to_unit_stride_t rtus_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
+ auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
+ OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o);
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+
+ private:
+ void init_balancers() {
+ const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
+ if (with_bias()) {
+ reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
+ jcp_.oc_block, jcp_.ngroups * jcp_.nb_load,
+ jcp_.mb, max_buffer_size));
+ }
+ }
+ };
+
+ template <cpu_isa_t isa, typename conv_t>
+ friend void init_rtus_driver(conv_t *self);
+
+ jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd);
+
+ ~jit_avx512_common_1x1_convolution_bwd_weights_t() {
+ delete kernel_;
+ delete acc_ker_;
+ delete reducer_bias_;
+ delete rtus_driver_;
+ delete trans_kernel_;
+ }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_weights(ctx);
+ return status::success;
+ }
+
+ private:
+ void execute_backward_weights(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx512_common_1x1_conv_kernel *kernel_;
+ cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
+ cpu_reducer_t<data_type::f32> *reducer_bias_;
+ jit_transpose4x16_src *trans_kernel_;
+ rtus_driver_t<avx512_common> *rtus_driver_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp
new file mode 100644
index 0000000000..235fb02fef
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp
@@ -0,0 +1,4539 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_barrier.hpp"
+
+#include "jit_avx512_common_conv_kernel.hpp"
+
+#define GET_OFF(field) offsetof(jit_conv_call_s, field)
+#define KNx_L2_EFFECTIVE_CAPACITY ((512-64)*1024)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+using namespace Xbyak;
+
+namespace {
+
+constexpr auto small_spatial = 14;
+unsigned int L1_cache_size = get_cache_size(1, true);
+
+inline void pick_loop_order(jit_conv_conf_t &jcp) {
+ using namespace prop_kind;
+ assert(one_of(jcp.prop_kind,
+ forward_training, forward_inference, backward_data));
+ auto w = (jcp.prop_kind == backward_data) ? jcp.iw : jcp.ow;
+ auto h = (jcp.prop_kind == backward_data) ? jcp.ih : jcp.oh;
+
+    // ow-threading is currently implemented for forward only.
+    // TODO: use a single code path for fwd and bwd once ow-threading is
+    //       supported for bwd (the meaningless switch was removed).
+ if (jcp.prop_kind == backward_data) {
+ jcp.loop_order = (w <= small_spatial && h <= small_spatial)
+ ? loop_cgn : loop_gnc;
+ } else {
+ jcp.loop_order = (w <= small_spatial && h <= small_spatial)
+ ? loop_cwgn : loop_gncw;
+ }
+}
+
+inline bool is_1stconv(const jit_conv_conf_t &jcp) {
+ if (mayiuse(avx512_core))
+ return (jcp.ic < 16 && jcp.ngroups == 1);
+ else
+ return one_of(jcp.ic, 1, 3);
+}
+
+inline bool is_ow_threading_on(const jit_conv_conf_t &jcp) {
+ return (jcp.nb_ow > 1);
+}
+
+inline bool is_owb_prefetching(const jit_conv_conf_t &jcp) {
+ return (jcp.ver == ver_4fma && is_ow_threading_on(jcp));
+}
+
+}
+
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::prepare_output(int ur_w)
+{
+ for (int k = 0; k < jcp.nb_oc_blocking; k++)
+ for (int j = 0; j < ur_w; j++) {
+ Vmm vmm = vmm_out(j, k);
+ vpxord(vmm, vmm, vmm);
+ if (!is_owb_prefetching(jcp)) {
+ size_t aux_output_offset = get_output_offset(j, k);
+ mic_prefetcht1(EVEX_compress_addr_safe(reg_out_prf,
+ aux_output_offset, reg_out_long_offt));
+ }
+ }
+}
+
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::store_output(int ur_w)
+{
+ Label no_update_label, store_label, eltwise_label;
+
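+    // reg_channel selects the input-channel chunk: the bias is added only on
+    // the first chunk, partial results already stored in the output are
+    // accumulated on later chunks (or always when a sum post-op is present),
+    // and the eltwise post-op is applied only after the last chunk.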
+ mov(reg_channel, ptr[param1 + GET_OFF(channel)]);
+ if (jcp.with_bias) {
+ mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
+ }
+
+ if (!jcp.with_sum) {
+ cmp(reg_channel, 0);
+ je(no_update_label, T_NEAR);
+ }
+
+ for (int k = 0; k < jcp.nb_oc_blocking; k++)
+ for (int j = 0; j < ur_w; j++) {
+ Vmm vmm = vmm_out(j, k);
+ size_t aux_output_offset = get_output_offset(j, k);
+ vaddps(vmm,
+ make_safe_addr(reg_out, aux_output_offset, reg_out_long_offt));
+ }
+
+ if (!jcp.with_sum) {
+ jmp(eltwise_label, T_NEAR);
+ } else {
+ cmp(reg_channel, 0);
+ jne(eltwise_label, T_NEAR);
+ }
+
+ L(no_update_label);
+ if (jcp.with_bias) {
+ for (int k = 0; k < jcp.nb_oc_blocking; k++) {
+ int bias_offset = jcp.typesize_out * k * jcp.oc_block;
+ for (int j = 0; j < ur_w; j++) {
+ Vmm vmm = vmm_out(j, k);
+ vaddps(vmm, EVEX_compress_addr(reg_bias, bias_offset));
+ }
+ mic_prefetcht1(EVEX_compress_addr(reg_bias, bias_offset + 64));
+ }
+ }
+
+ L(eltwise_label);
+ if (jcp.with_eltwise) {
+ cmp(reg_channel, jcp.nb_ic - 1);
+ jl(store_label, T_NEAR);
+
+ if (ur_w == jcp.ur_w) {
+ eltwise_injector_->compute_vector_range(0,
+ jcp.nb_oc_blocking * jcp.ur_w);
+ } else {
+ for (int k = 0; k < jcp.nb_oc_blocking; k++)
+ eltwise_injector_->compute_vector_range(k * jcp.ur_w,
+ k * jcp.ur_w + ur_w);
+ }
+ }
+
+ L(store_label);
+ for (int k = 0; k < jcp.nb_oc_blocking; k++)
+ for (int j = 0; j < ur_w; j++) {
+ Vmm vmm = vmm_out(j, k);
+ size_t aux_output_offset = (size_t)typesize *
+ ((size_t)k * jcp.od * jcp.oh * jcp.ow + j) * jcp.oc_block;
+ vmovups(EVEX_compress_addr_safe(reg_out, aux_output_offset,
+ reg_out_long_offt), vmm);
+ if (!is_owb_prefetching(jcp))
+ mic_prefetcht0(EVEX_compress_addr_safe(reg_out_prf,
+ aux_output_offset, reg_out_long_offt));
+ }
+}
+
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma_1st(int ur_w,
+ int pad_l, int pad_r)
+{
+}
+
+template<>
+void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma_1st(int ur_w,
+ int pad_l, int pad_r)
+{
+ assert(jcp.dilate_d == 0 && jcp.dilate_h == 0 && jcp.dilate_w == 0);
+
+ int iw = jcp.iw;
+ int ih = jcp.ih;
+ int kw = jcp.kw;
+ int stride_w = jcp.stride_w;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+
+ Label kh_label, kd_label;
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ mov(aux_reg_inp, reg_inp);
+ mov(aux_reg_ker, reg_ker);
+ mov(aux_reg_inp_prf, reg_inp_prf);
+ }
+
+ size_t max_input_offset = (size_t)jcp.typesize_in
+ * ((size_t)(kw + ur_w * stride_w - pad_l)
+ + (size_t)ic_block * iw * ih * jcp.id);
+ assert(reg_inp_prf == reg_long_offt);
+ if (max_input_offset > INT_MAX) push(reg_inp_prf);
+
+ if (jcp.ndims == 5) {
+ push(reg_out_prf);
+ push(reg_out);
+
+ mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
+ mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
+ mov(aux_reg_inp_d, reg_inp);
+ mov(aux_reg_inp_d_prf, reg_inp_prf);
+
+ L(kd_label);
+ }
+ mov(reg_kj, reg_kh);
+ if (jcp.ndims == 5) {
+ mov(aux_reg_inp, aux_reg_inp_d);
+ mov(aux_reg_ker, aux_reg_ker_d);
+ mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
+ }
+
+ L(kh_label);
+ for (int ki = 0; ki < kw; ki += 4) {
+ for (int ic = 0; ic < ic_block; ic++) {
+ for (int i = 0; i < 4; i++) {
+ int aux_ker_offset
+ = jcp.typesize_in
+ * ((ki + i) * oc_block
+ + ic * kw * jcp.kh * jcp.kd * oc_block);
+ if (ki + i < kw)
+ vmovups(vmm_ker(i),
+ EVEX_compress_addr(aux_reg_ker, aux_ker_offset));
+ else
+ vpxord(vmm_ker(i), vmm_ker(i), vmm_ker(i));
+ }
+
+ int j_start = get_ow_start(ki, pad_l);
+ int j_end = get_ow_end(ur_w, ki, pad_r);
+
+ for (int j = j_start, prf_count=0; j < j_end; j++) {
+ size_t aux_input_offset = (size_t)jcp.typesize_in
+ * ((size_t)(ki + j * stride_w
+ - pad_l) + (size_t)ic * iw * ih * jcp.id);
+ v4fmaddps(vmm_out(j, 0), vmm_ker(0),
+ EVEX_compress_addr_safe(aux_reg_inp, aux_input_offset,
+ reg_long_offt));
+ if (ki + prf_count < kw && prf_count < 4
+ && ((ki < 2 && j % 4) || j % 2)) {
+ int aux_ker_offset = jcp.typesize_in
+ * ((ki + prf_count) * oc_block
+ + ic * kw * jcp.kh * jcp.kd * oc_block + kw * oc_block);
+ mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
+ aux_ker_offset));
+ prf_count++;
+ }
+ if (ki == 0
+ && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
+ mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp_prf,
+ aux_input_offset, reg_long_offt));
+ }
+ if (ki == 1
+ && j % (64 / (stride_w * jcp.typesize_in)) == 0) {
+ mic_prefetcht0(EVEX_compress_addr_safe(aux_reg_inp,
+ aux_input_offset+jcp.typesize_in * iw, reg_long_offt));
+ }
+ }
+ }
+ }
+ add(aux_reg_ker, jcp.typesize_in * kw * oc_block);
+ add(aux_reg_inp, jcp.typesize_in * iw);
+ add(aux_reg_inp_prf, jcp.typesize_in * iw);
+
+ dec(reg_kj);
+ cmp(reg_kj, 0);
+ jg(kh_label, T_NEAR);
+
+ if (jcp.ndims == 5) {
+ add(aux_reg_inp_d, typesize * jcp.ih * jcp.iw);
+ add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block);
+ add(aux_reg_inp_d_prf, typesize * jcp.ih * jcp.iw);
+
+ dec(reg_ki);
+ cmp(reg_ki, 0);
+ jg(kd_label, T_NEAR);
+
+ pop(reg_out);
+ pop(reg_out_prf);
+ }
+
+ if (max_input_offset > INT_MAX) pop(reg_inp_prf);
+}
+
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_4fma(int ur_w,
+ int pad_l, int pad_r)
+{
+}
+
+template<>
+void _jit_avx512_common_conv_fwd_kernel<Zmm>::compute_loop_4fma(int ur_w,
+ int pad_l, int pad_r)
+{
+ int stride_w = jcp.stride_w;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ Label kh_label, last_iter_label, loop_end_label, kd_label;
+ int ker_load_number = 4;
+ int shift_kernel_ptr = typesize * jcp.kw * jcp.oc_block * jcp.ic_block;
+ int shift_input_ptr = typesize * (jcp.dilate_h + 1) * jcp.iw * jcp.ic_block;
+
+ bool check_last_kh = (jcp.kh > 3);
+ bool pref_current_inp = (jcp.iw < 14 || jcp.iw > 28);
+
+ int oi_ipref_t0 = get_ow_start(0, pad_l);
+ int ow_end_ipref = get_ow_end(ur_w, 0, pad_r);
+
+ assert(jcp.oc % jcp.nb_oc_blocking == 0);
+
+ auto kernel_offset = [=](int ocb, int ic, int ki) {
+ int blk_idx = ocb * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd + ki;
+ int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
+ int ic_offset = ic * jcp.oc_block;
+ return typesize * (blk_offset + ic_offset);
+ };
+ auto kernel_loads = [=](int ki, int ic, int kk) {
+ for (int ii = 0; ii < ker_load_number; ii++) {
+ int aux_kernel_offset = kernel_offset(kk, ic + ii, ki);
+ vmovups(vmm_ker(ii),
+ EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
+ }
+ };
+ auto prefetch_inp_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
+ if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
+ && ki >= ki_start && oi_ipref_t0 < ow_end_ipref) {
+ int aux_inp_offset
+ = typesize
+ * ((oi_ipref_t0 * stride_w - pad_l) * ic_block
+ + (jcp.dilate_h + 1) * jcp.iw * ic_block);
+ prefetcht0(EVEX_compress_addr(aux_reg_inp,
+ aux_inp_offset));
+ oi_ipref_t0++;
+ }
+ };
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ mov(aux_reg_inp, reg_inp);
+ mov(aux_reg_ker, reg_ker);
+ mov(aux_reg_ker_prf, reg_ker_prf);
+ mov(aux_reg_inp_prf, reg_inp_prf);
+ }
+
+ if (jcp.ndims == 5) {
+ push(reg_out_prf);
+ push(reg_out);
+
+ mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
+ mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
+ mov(aux_reg_inp_d, reg_inp);
+ mov(aux_reg_inp_d_prf, reg_inp_prf);
+ mov(aux_reg_ker_d_prf, reg_ker_prf);
+ L(kd_label);
+ mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
+ } else {
+ mov(reg_kj, reg_kh);
+ }
+ if (jcp.ndims == 5) {
+ mov(aux_reg_inp, aux_reg_inp_d);
+ mov(aux_reg_ker, aux_reg_ker_d);
+ mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
+ mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
+ }
+
+ align(16);
+ L(kh_label);
+ int kw = jcp.kw;
+ if (check_last_kh) {
+ for (int ki = 0; ki < kw; ki++)
+ for (int ic = 0; ic < ic_block; ic += 4)
+ for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
+ bool last_kernel_loads = (kk == jcp.nb_oc_blocking - 1
+ && ki == kw - 1 && (ic + 4) == ic_block);
+
+ if (last_kernel_loads) {
+ cmp(reg_kj, 1);
+ je(last_iter_label, T_NEAR);
+ }
+
+ kernel_loads(ki, ic, kk);
+ for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
+ prf_count_t0 = 0;
+ oi < get_ow_end(ur_w, ki, pad_r); oi++) {
+ int aux_input_offset = typesize
+ * ((ki * (jcp.dilate_w + 1) + oi * stride_w
+ - pad_l) * ic_block
+ + ic);
+ v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
+ EVEX_compress_addr(aux_reg_inp, aux_input_offset));
+
+ if (oi % 2) {
+ if (prf_count_t0 < 4) {
+ int aux_kernel_prf;
+ if (last_kernel_loads)
+ aux_kernel_prf= kernel_offset(0,
+ prf_count_t0 + ic + 4
+ - ic_block, 0) + typesize * kw
+ * oc_block * ic_block;
+ else
+ aux_kernel_prf = kernel_offset(kk, ic + 4
+ + prf_count_t0, ki);
+ mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
+ aux_kernel_prf));
+ prf_count_t0++;
+ } else if (prf_count_t1 < 4) {
+ mic_prefetcht1(EVEX_compress_addr(
+ aux_reg_ker_prf, kernel_offset(kk, ic
+ + prf_count_t1, ki)));
+ prf_count_t1++;
+ }
+ } else
+ prefetch_inp_next_kh(ki, 2, prf_count_t0,
+ prf_count_t1);
+ }
+
+ if (last_kernel_loads) {
+ jmp(loop_end_label, T_NEAR);
+
+ L(last_iter_label);
+
+ kernel_loads(ki, ic, kk);
+ for (int oi = get_ow_start(ki, pad_l), prf_count_t1 = 0,
+ prf_count_t0 = 0;
+ oi < get_ow_end(ur_w, ki, pad_r); oi++) {
+ int aux_input_offset = typesize
+ * ((ki * (jcp.dilate_w + 1) + oi * stride_w
+ - pad_l) * ic_block
+ + ic);
+ v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
+ EVEX_compress_addr(aux_reg_inp,
+ aux_input_offset));
+ if (oi % 2) {
+ if (prf_count_t0 < 4) {
+ mic_prefetcht0(EVEX_compress_addr(
+ aux_reg_ker_prf, kernel_offset(0,
+ prf_count_t0, 0)));
+ prf_count_t0++;
+ } else if (prf_count_t1 < 4) {
+ mic_prefetcht1(EVEX_compress_addr(
+ aux_reg_ker_prf, kernel_offset(kk,
+ ic + prf_count_t1, ki)));
+ prf_count_t1++;
+ }
+ }
+ }
+ L(loop_end_label);
+ }
+ }
+ } else {
+ for (int ki = 0; ki < kw; ki++)
+ for (int ic = 0; ic < ic_block; ic += 4)
+ for (int kk = 0; kk < jcp.nb_oc_blocking; kk++) {
+ kernel_loads(ki, ic, kk);
+ for (int oi = get_ow_start(ki, pad_l),
+ prf_count_t1 = 0, prf_count_t0 = 0;
+ oi < get_ow_end(ur_w, ki, pad_r); oi++) {
+ int aux_input_offset = typesize
+ * ((ki * (jcp.dilate_w + 1) + oi * stride_w
+ - pad_l) * ic_block + ic);
+ v4fmaddps(vmm_out(oi, kk), vmm_ker(0),
+ EVEX_compress_addr(aux_reg_inp,
+ aux_input_offset));
+
+ if (!is_owb_prefetching(jcp)) {
+ if ((oi % 2) && (prf_count_t1 < 4)) {
+ mic_prefetcht1(EVEX_compress_addr(
+ aux_reg_ker_prf, kernel_offset(kk,
+ ic + prf_count_t1, ki)));
+ prf_count_t1++;
+ }
+ } else {
+ if (!(ki == 0 && ic == 0)
+ && !(ki == kw-1 && ic == 0) &&
+ (oi % 2) && (prf_count_t1 < 4)
+ ) {
+ mic_prefetcht0(EVEX_compress_addr(
+ aux_reg_ker, kernel_offset(kk,
+ ic + 4 + prf_count_t0, ki)));
+ prf_count_t0++;
+ }
+ }
+ if (!is_owb_prefetching(jcp)) {
+ if (pref_current_inp) {
+ if (ki == 0 && ic == 0 && kk == 0)
+ mic_prefetcht0(EVEX_compress_addr(
+ aux_reg_inp,
+ aux_input_offset + shift_input_ptr));
+ } else {
+ if (ki == 1 && ic == 0 && kk == 0)
+ mic_prefetcht1(EVEX_compress_addr(
+ aux_reg_inp_prf, aux_input_offset));
+ }
+ } else {
+ int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
+ int inp_shift
+ = jcp.typesize_in * ur_w * stride_w * inp_mult;
+ bool kk_pref_slot = kk ? oi % 2 : !(oi % 2);
+ if (ki == 0 && ic == 0 && kk_pref_slot)
+ mic_prefetcht1(EVEX_compress_addr(
+ aux_reg_inp,
+ aux_input_offset + inp_shift));
+
+ if (ki == kw - 1 && ic == 0 && kk_pref_slot)
+ mic_prefetcht0(EVEX_compress_addr(
+ aux_reg_inp,
+ aux_input_offset + inp_shift));
+ }
+ }
+ }
+ }
+
+ add(aux_reg_ker, shift_kernel_ptr);
+ add(aux_reg_inp, shift_input_ptr);
+ add(aux_reg_ker_prf, shift_kernel_ptr);
+ add(aux_reg_inp_prf, shift_input_ptr);
+
+ dec(reg_kj);
+ cmp(reg_kj, 0);
+ jg(kh_label, T_NEAR);
+
+ if (jcp.ndims == 5) {
+ add(aux_reg_inp_d,
+ typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
+ add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
+ * jcp.ic_block);
+ add(aux_reg_inp_d_prf,
+ typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * jcp.ic_block);
+ add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
+ * jcp.ic_block);
+
+ dec(reg_ki);
+ cmp(reg_ki, 0);
+ jg(kd_label, T_NEAR);
+
+ pop(reg_out);
+ pop(reg_out_prf);
+ }
+}
+
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma(int ur_w,
+ int pad_l, int pad_r)
+{
+ bool prf_ker = true;
+ bool prf_inp = true;
+ int ih = jcp.ih;
+ int stride_w = jcp.stride_w;
+ int id = jcp.id;
+ int iw = jcp.iw;
+ int kw = jcp.kw;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ int nb_oc_block = jcp.nb_oc_blocking;
+ Label kh_label, kd_label;
+
+ int ker_pipeline_depth = 4;
+ assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
+ assert(oc_block >= ker_pipeline_depth);
+
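+    // Weights are streamed through a 4-register pipeline so every FMA reads a
+    // register loaded a few iterations earlier; prefetch instructions are
+    // interleaved among the FMAs at the fixed spacing computed below.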
+ int num_ker_loads = ic_block * nb_oc_block * kw;
+ int num_ker_prfs = prf_ker ? num_ker_loads : 0;
+ int num_inp_prfs = prf_inp ?
+ ur_w * nstl::min(kw, stride_w) + nstl::max(0, kw - stride_w) :
+ 0;
+ if (jcp.is_1stconv && prf_inp) {
+ num_inp_prfs = div_up(num_inp_prfs, jcp.simd_w) * ic_block;
+ }
+ int num_prfs = num_ker_prfs + num_inp_prfs;
+ int num_fmas = num_ker_loads * ur_w;
+ int prf_inst_spacing
+ = (prf_ker || prf_inp) ? nstl::max(1, num_fmas / num_prfs) : 1;
+ int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
+ int inp_mul = !jcp.is_1stconv ? ic_block : 1;
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ mov(aux_reg_inp, reg_inp);
+ mov(aux_reg_ker, reg_ker);
+ mov(aux_reg_inp_prf, reg_inp_prf);
+ mov(aux_reg_ker_prf, reg_ker_prf);
+ }
+
+ size_t max_input_offset = (size_t)jcp.typesize_in * ic_block * iw * ih * id;
+ assert(reg_inp_prf == reg_long_offt);
+ if (max_input_offset > INT_MAX) push(reg_inp_prf);
+
+
+ if (jcp.ndims == 5) {
+ push(reg_out_prf);
+ push(reg_out);
+
+ mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
+ mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
+ mov(aux_reg_inp_d, reg_inp);
+ mov(aux_reg_inp_d_prf, reg_inp_prf);
+ mov(aux_reg_ker_d_prf, reg_ker_prf);
+
+ L(kd_label);
+ mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
+ } else {
+ mov(reg_kj, reg_kh);
+ }
+
+ if (jcp.ndims == 5) {
+ mov(aux_reg_inp, aux_reg_inp_d);
+ mov(aux_reg_ker, aux_reg_ker_d);
+ mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
+ mov(aux_reg_inp_prf, aux_reg_inp_d_prf);
+ }
+
+ align(16);
+ L(kh_label);
+ {
+ int step = 0;
+ int ker_prfs = 0;
+ for (int ki = 0; ki < kw; ki++) {
+ for (int ic = 0; ic < ic_block; ic++) {
+ int aux_kernel_offset = 0;
+ if (step == 0) {
+ for (int i = 0; i < ker_pipeline_depth; i++) {
+ aux_kernel_offset = get_kernel_offset(ki, ic, 0, i);
+ vmovups(vmm_ker(i), EVEX_compress_addr(
+ aux_reg_ker, aux_kernel_offset));
+ }
+ } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
+ int load_offset = ker_pipeline_depth - 1;
+ int ker_load_reg_idx
+ = (step + load_offset) % ker_pipeline_depth;
+ aux_kernel_offset
+ = get_kernel_offset(ki, ic, 0, load_offset);
+ vmovups(vmm_ker(ker_load_reg_idx),
+ EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
+ }
+
+ bool ker_prf_inserted = false;
+ Vmm vmm_kernel = vmm_ker(step % ker_pipeline_depth);
+ int j_start = get_ow_start(ki, pad_l);
+ int j_end = get_ow_end(ur_w, ki, pad_r);
+ for (int j = j_start; j < j_end; j++) {
+ size_t aux_input_offset = get_input_offset(ki, ic, j, pad_l);
+ auto addr = EVEX_compress_addr_safe(aux_reg_inp,
+ aux_input_offset, reg_long_offt, true);
+ vfmadd231ps(vmm_out(j, 0), vmm_kernel, addr);
+ int fma_idx = step * ur_w + j;
+ int prf_slot_idx = fma_idx / prf_inst_spacing;
+ if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
+ if (prf_ker && !ker_prf_inserted
+ && ker_prfs < num_ker_prfs) {
+ int ker_prf_offset
+ = jcp.typesize_in * ker_prfs * jcp.oc_block;
+ mic_prefetcht2(EVEX_compress_addr(
+ aux_reg_ker_prf, ker_prf_offset));
+ ker_prf_inserted = true;
+ ker_prfs++;
+ } else if (prf_inp) {
+ int inp_prf_idx = prf_slot_idx - ker_prfs;
+ if (inp_prf_idx < num_inp_prfs) {
+ size_t inp_prf_stride = nstl::max(kw, stride_w);
+ size_t inp_prf_offset;
+ if (!jcp.is_1stconv) {
+ inp_prf_offset
+ = ic_block * jcp.typesize_in
+ * ((inp_prf_idx / kw)
+ * inp_prf_stride
+ + (inp_prf_idx % kw));
+ } else {
+ size_t ic_prf_stride =
+ (size_t)jcp.typesize_in * iw * ih * id;
+ size_t iw_prf_stride
+ = jcp.typesize_in * jcp.simd_w;
+ inp_prf_offset = ((inp_prf_idx / ic_block)
+ * iw_prf_stride
+ + (inp_prf_idx % ic_block)
+ * ic_prf_stride);
+ }
+ mic_prefetcht0(EVEX_compress_addr_safe(
+ aux_reg_inp_prf, inp_prf_offset,
+ reg_long_offt));
+ }
+ }
+ }
+ }
+ step++;
+ }
+ }
+ add(aux_reg_ker, jcp.typesize_in * kw * oc_block * ic_block);
+ if (prf_ker)
+ add(aux_reg_ker_prf, jcp.typesize_in * kw * oc_block * ic_block);
+ add(aux_reg_inp, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
+ if (prf_inp)
+ add(aux_reg_inp_prf,
+ jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
+ dec(reg_kj);
+ cmp(reg_kj, 0);
+ jg(kh_label, T_NEAR);
+ }
+
+
+ if (jcp.ndims == 5) {
+ add(aux_reg_inp_d,
+ typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
+ add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
+ * jcp.ic_block);
+ add(aux_reg_inp_d_prf,
+ typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
+ add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * jcp.oc_block
+ * jcp.ic_block);
+
+ dec(reg_ki);
+ cmp(reg_ki, 0);
+ jg(kd_label, T_NEAR);
+
+ pop(reg_out);
+ pop(reg_out_prf);
+ }
+ if (max_input_offset > INT_MAX) pop(reg_inp_prf);
+}
+
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop_fma_core(int ur_w,
+ int pad_l, int pad_r)
+{
+ int kw = jcp.kw;
+ int stride_w = jcp.stride_w;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ int nb_oc_block = jcp.nb_oc_blocking;
+ Label kh_label, kd_label;
+ int shift_kernel_ptr = jcp.typesize_in * jcp.kw * jcp.oc_block
+ * jcp.ic_block;
+ int inp_mul = !jcp.is_1stconv ? ic_block : 1;
+ int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
+ * inp_mul;
+
+
+ auto input_offset = [=](int oi, int ic, int ki) {
+ return (size_t)jcp.typesize_in
+ * ((size_t)(ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
+ * inp_mul + (size_t)ic
+ * (!jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id));
+ };
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ mov(aux_reg_inp, reg_inp);
+ mov(aux_reg_ker, reg_ker);
+ }
+
+ if (jcp.ndims == 5) {
+ push(reg_out);
+
+ mov(reg_ki, ptr[param1 + GET_OFF(kd_padding)]);
+ mov(aux_reg_ker_d, ptr[param1 + GET_OFF(filt)]);
+ mov(aux_reg_inp_d, reg_inp);
+
+ L(kd_label);
+ mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
+ } else {
+ mov(reg_kj, reg_kh);
+ }
+
+ if (jcp.ndims == 5) {
+ mov(aux_reg_inp, aux_reg_inp_d);
+ mov(aux_reg_ker, aux_reg_ker_d);
+ }
+
+ L(kh_label);
+ {
+ for (int ki = 0; ki < kw; ki++) {
+ int jj_start = get_ow_start(ki, pad_l);
+ int jj_end = get_ow_end(ur_w, ki, pad_r);
+ for (int ic = 0; ic < ic_block; ic++) {
+ if (jcp.kernel_kind == expl_bcast) {
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ size_t aux_input_offset = input_offset(jj, ic, ki);
+ vbroadcastss(vmm_inp(jj, nb_oc_block),
+ EVEX_compress_addr_safe(aux_reg_inp,
+ aux_input_offset, reg_long_offt));
+ }
+ }
+ for (int ii = 0; ii < nb_oc_block; ii++) {
+ int aux_kernel_offset = jcp.typesize_in
+ * (ii * jcp.nb_ic * jcp.kh * jcp.kw * jcp.kd * ic_block
+ * oc_block + ki * ic_block * oc_block + ic * oc_block);
+ if (jj_end - jj_start > 0)
+ vmovups(vmm_wei, EVEX_compress_addr(aux_reg_ker,
+ aux_kernel_offset));
+ for (int jj = jj_start; jj < jj_end; jj++)
+ if (jcp.kernel_kind == expl_bcast)
+ vfmadd231ps(vmm_out(jj, ii),
+ vmm_inp(jj, nb_oc_block), vmm_wei);
+ else {
+ size_t aux_input_offset = input_offset(jj, ic, ki);
+ vfmadd231ps(vmm_out(jj, ii), vmm_wei,
+ EVEX_compress_addr_safe(aux_reg_inp,
+ aux_input_offset, reg_long_offt, true));
+ }
+ }
+ }
+ }
+ add(aux_reg_ker, shift_kernel_ptr);
+ add(aux_reg_inp, shift_input_ptr);
+ dec(reg_kj);
+ cmp(reg_kj, 0);
+ jg(kh_label, T_NEAR);
+ }
+
+ if (jcp.ndims == 5) {
+ add(aux_reg_inp_d,
+ typesize * (jcp.dilate_d + 1) * jcp.ih * jcp.iw * inp_mul);
+ add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * jcp.oc_block
+ * jcp.ic_block);
+
+ dec(reg_ki);
+ cmp(reg_ki, 0);
+ jg(kd_label, T_NEAR);
+
+ pop(reg_out);
+ }
+}
+
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::compute_loop(int ur_w,
+ int pad_l, int pad_r)
+{
+ if (jcp.ndims == 5) push(reg_oi);
+
+ prepare_output(ur_w);
+
+ Label skip_compute_loop;
+ if (jcp.ndims == 5) {
+ if ((jcp.dilate_d >= jcp.id)
+ || (jcp.kd - 1) * (jcp.dilate_d + 1) < nstl::max(jcp.f_pad, jcp.back_pad)) {
+ mov(reg_kj, ptr[param1 + GET_OFF(kd_padding)]);
+ cmp(reg_kj, 0);
+ je(skip_compute_loop, T_NEAR);
+ }
+ }
+ if ((jcp.dilate_h >= jcp.ih)
+ || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
+ mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
+ cmp(reg_kj, 0);
+ je(skip_compute_loop, T_NEAR);
+ }
+
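+    // Pick the compute body: the 4fma variants target avx512_mic_4ops (see
+    // init_conf), while plain FMA uses either the software-pipelined loop or
+    // the explicit-broadcast core loop depending on jcp.kernel_kind.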
+ if (jcp.ver == ver_4fma)
+ if(jcp.is_1stconv)
+ compute_loop_4fma_1st(ur_w, pad_l, pad_r);
+ else
+ compute_loop_4fma(ur_w, pad_l, pad_r);
+ else if (jcp.ver == ver_fma)
+ if ((jcp.is_1stconv && jcp.kernel_kind != expl_bcast)
+ || mayiuse(avx512_mic))
+ compute_loop_fma(ur_w, pad_l, pad_r);
+ else
+ if (jcp.kernel_kind == embd_bcast && jcp.nb_oc_blocking == 1)
+ compute_loop_fma(ur_w, pad_l, pad_r);
+ else
+ compute_loop_fma_core(ur_w, pad_l, pad_r);
+ else
+ assert(!"unknown convolution version");
+
+ L(skip_compute_loop);
+ store_output(ur_w);
+ if (jcp.ndims == 5) pop(reg_oi);
+}
+
+template<typename Vmm>
+void _jit_avx512_common_conv_fwd_kernel<Vmm>::generate()
+{
+ int iw = jcp.iw;
+ int ow = jcp.ow;
+ int ow_block = jcp.ow_block;
+ int nb_ow = jcp.nb_ow;
+ int kw = jcp.kw;
+ int l_pad = jcp.l_pad;
+ int ur_w = jcp.ur_w;
+ int ur_w_tail = jcp.ur_w_tail;
+ int dilate_w = jcp.dilate_w + 1;
+ int stride_w = jcp.stride_w;
+
+ int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
+ int inp_shift_pad = jcp.typesize_in * (ur_w * stride_w - l_pad) * inp_mult;
+ int inp_shift = jcp.typesize_in * ur_w * stride_w * inp_mult;
+ int inp_shift_pad_second_block = -1 * jcp.typesize_in * l_pad * inp_mult;
+ int out_shift = jcp.typesize_out * ur_w * jcp.oc_block;
+
+ preamble();
+ mov(reg_inp, ptr[param1 + GET_OFF(src)]);
+ mov(reg_out, ptr[param1 + GET_OFF(dst)]);
+ mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
+ mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
+ mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
+
+ int r_pad = nstl::max(
+ 0, (ow - 1) * stride_w + (kw - 1) * dilate_w - (iw + l_pad - 1));
+ int n_oi = ow / ur_w;
+ int r_pad1 = (ur_w * n_oi - 1) * stride_w + (kw - 1) * dilate_w
+ - (iw + l_pad - 1);
+
+ if (!is_ow_threading_on(jcp)) {
+        // ow is processed as a whole, including the left and right padding
+ if (r_pad1 > 0) n_oi--;
+
+ if (ow == ur_w) {
+ mov(reg_inp_prf, ptr[param1 + GET_OFF(src_prf)]);
+ mov(reg_out_prf, ptr[param1 + GET_OFF(dst_prf)]);
+ compute_loop(ur_w, l_pad, r_pad);
+ } else {
+ mov(reg_inp_prf, reg_inp);
+ mov(reg_out_prf, reg_out);
+ if (n_oi == 0) {
+ add(reg_inp_prf, inp_shift_pad);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w, l_pad, r_pad1);
+ add(reg_inp, inp_shift_pad);
+ add(reg_out, out_shift);
+ if (ur_w_tail != 0) {
+ add(reg_inp_prf, inp_shift);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w_tail, 0, r_pad);
+ }
+ } else {
+ xor_(reg_oi, reg_oi);
+ if (l_pad > 0) {
+ add(reg_inp_prf, inp_shift_pad);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w, l_pad, 0);
+ add(reg_inp, inp_shift_pad);
+ add(reg_out, out_shift);
+ inc(reg_oi);
+ }
+ if ((l_pad <= 0 && n_oi > 0) || (l_pad > 0 && n_oi > 1)) {
+ Label ow_loop_label;
+ L(ow_loop_label);
+ {
+ add(reg_inp_prf, inp_shift);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w, 0, 0);
+ add(reg_inp, inp_shift);
+ add(reg_out, out_shift);
+ inc(reg_oi);
+ cmp(reg_oi, n_oi);
+ jl(ow_loop_label, T_NEAR);
+ }
+ }
+ if (r_pad1 > 0) {
+ add(reg_inp_prf, inp_shift);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w, 0, r_pad1);
+ add(reg_inp, inp_shift);
+ add(reg_out, out_shift);
+ }
+ if (ur_w_tail != 0) {
+ add(reg_inp_prf, inp_shift);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w_tail, 0, r_pad);
+ }
+ }
+ }
+ } else {
+        // Only a single ow block is processed per call.
+        // The index of the block is passed in the owb parameter,
+        // and the padding handling depends on it.
+
+ Label end_label, last_oi_label, middle_ow_blocks_label, tail_label;
+ Label oi_loop_label, oi_loop_start_label, oi_loop_end_label;
+
+ assert(ow_block % ur_w == 0);
+ int n_oi_not_last_ow_block = ow_block / ur_w;
+        // to simplify the code (and general-purpose register usage),
+        // the size of an ow block must be >= 2 * ur_w
+ assert(n_oi_not_last_ow_block > 1);
+ int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
+ int n_oi_first_ow_block = n_oi_not_last_ow_block;
+
+ int n_oi_last_ow_block = (ow - ow_block * (nb_ow-1)) / ur_w;
+
+ // prepare right padding
+ bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
+ bool first_ow_block_padded = next_last_ow_block_padded && jcp.nb_ow == 2;
+ bool last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block > 0;
+
+ if (last_ow_block_padded) n_oi_last_ow_block--;
+ else if (first_ow_block_padded) n_oi_first_ow_block--;
+ else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
+
+ mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
+ cmp(reg_owb, 0); // is that the first ow-block ?
+ jg(middle_ow_blocks_label, T_NEAR);
+
+ // the first ow block, compute left padding
+
+ mov(reg_oi, n_oi_first_ow_block);
+ mov(reg_inp_prf, reg_inp);
+ mov(reg_out_prf, reg_out);
+
+ if (l_pad > 0) {
+ mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
+ add(reg_inp_prf, inp_shift_pad);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w, l_pad, 0);
+ add(reg_inp, inp_shift_pad);
+ add(reg_out, out_shift);
+ dec(reg_oi);
+ }
+ jmp(oi_loop_label, T_NEAR);
+
+ // middle or last ow block entry
+
+ L(middle_ow_blocks_label);
+
+ if (l_pad > 0) {
+            // only account for the left padding here; nothing is computed
+ add(reg_inp, inp_shift_pad_second_block);
+ add(reg_inp_prf, inp_shift_pad_second_block);
+ }
+
+        // set the number of iterations for the oi-loop
+ cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
+ mov(reg_oi, n_oi_last_ow_block);
+ je(oi_loop_label, T_NEAR);
+ cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
+ mov(reg_oi, n_oi_next_last_ow_block);
+ je(oi_loop_label, T_NEAR);
+ mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
+
+ // oi loop w/o padding
+ L(oi_loop_label);
+ mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
+ L(oi_loop_start_label);
+ cmp(reg_oi, 0);
+ jle(oi_loop_end_label, T_NEAR);
+
+ add(reg_inp_prf, inp_shift);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w, 0, 0);
+ add(reg_inp, inp_shift);
+ add(reg_out, out_shift);
+ dec(reg_oi);
+ jmp(oi_loop_start_label, T_NEAR);
+ L(oi_loop_end_label);
+
+ mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
+
+ cmp(reg_owb, 0); // first ow-block ?
+ if (first_ow_block_padded) {
+ je(last_oi_label, T_NEAR);
+ } else {
+ je(end_label, T_NEAR);
+ }
+ cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
+ jl(end_label, T_NEAR);
+ if (next_last_ow_block_padded) {
+ je(last_oi_label, T_NEAR);
+ } else {
+ je(end_label, T_NEAR);
+ }
+        // this is the last block
+ if (!last_ow_block_padded) {
+ jmp(tail_label, T_NEAR);
+ }
+
+ // last oi block with right padding
+ L(last_oi_label);
+ mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
+ add(reg_inp_prf, inp_shift);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w, 0, r_pad1);
+ add(reg_inp, inp_shift);
+ add(reg_out, out_shift);
+
+ mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
+ cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
+ jl(end_label, T_NEAR);
+
+ L(tail_label);
+ mov(reg_ker_prf, ptr[param1 + GET_OFF(filt_prf)]);
+ if (ur_w_tail != 0) {
+ add(reg_inp_prf, inp_shift);
+ add(reg_out_prf, out_shift);
+ compute_loop(ur_w_tail, 0, r_pad);
+ }
+ L(end_label);
+ }
+ postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+}
+
+bool jit_avx512_common_conv_fwd_kernel::post_ops_ok(
+ jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+ switch (p.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
+ case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
+ default: return false;
+ }
+
+ return false;
+}
+
+status_t jit_avx512_common_conv_fwd_kernel::init_conf(
+ jit_conv_conf_t &jcp, const convolution_desc_t &cd,
+ memory_desc_t &src_md, memory_desc_t &weights_md,
+ memory_desc_t &dst_md, memory_desc_t &bias_md,
+ const primitive_attr_t &attr, int nthreads)
+{
+ using namespace prop_kind;
+
+ if (!mayiuse(avx512_common))
+ return status::unimplemented;
+
+ const memory_desc_wrapper src_d(&src_md);
+ const memory_desc_wrapper weights_d(&weights_md);
+ const memory_desc_wrapper dst_d(&dst_md);
+ const memory_desc_wrapper bias_d(&bias_md);
+
+ const int regs = 28;
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ int ndims = src_d.ndims();
+
+ jcp = zero<decltype(jcp)>();
+ jcp.ndims = ndims;
+ jcp.prop_kind = cd.prop_kind;
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
+ jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
+ jcp.iw = src_d.dims()[ndims-1];
+ jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
+ jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[ndims-2];
+ jcp.ow = dst_d.dims()[ndims-1];
+ jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
+ jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims-2];
+ jcp.kw = weights_d.dims()[with_groups + ndims-1];
+ jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
+ jcp.l_pad = cd.padding[0][ndims-3];
+ jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
+ jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
+ jcp.stride_w = cd.strides[ndims-3];
+
+ jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
+ jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
+ jcp.dilate_w = cd.dilates[ndims-3];
+
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+ - (jcp.ih + jcp.t_pad - 1);
+ jcp.back_pad = (jcp.od - 1) * jcp.stride_d
+ + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
+
+ jcp.is_1stconv = is_1stconv(jcp);
+
+ bool ok_to_pad_channels = true
+ && jcp.ngroups == 1
+ && src_d.data_type() == data_type::f32;
+
+ const int full_simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+ jcp.simd_w = full_simd_w;
+ bool ok_to_try_xmm = true
+ && mayiuse(avx512_core)
+ && src_d.data_type() == data_type::f32
+ && !jcp.is_1stconv
+ && !ok_to_pad_channels
+ && (jcp.ic % jcp.simd_w != 0 || jcp.oc % jcp.simd_w != 0)
+ && (jcp.ic % 8 != 0 || jcp.oc % 8 != 0);
+ if (ok_to_try_xmm)
+ jcp.simd_w = 4;
+
+ jcp.oc_block = jcp.simd_w;
+ jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
+ jcp.aligned_threads = 0;
+
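+ // With a single group and f32 data the channel counts can be rounded up to
+ // whole SIMD blocks; the extra channels are expected to be covered by the
+ // memory descriptors' padded dimensions (checked via padded_dims() below).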
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
+ jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
+ }
+ bool args_ok = true
+ && jcp.oc % jcp.oc_block == 0
+ && jcp.ic % jcp.ic_block == 0;
+ if (!args_ok)
+ return status::unimplemented;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ jcp.with_sum = p.find(primitive_kind::sum) != -1;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise) {
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+ if (dst_d.data_type() == data_type::s32) return status::unimplemented;
+ }
+
+ auto src_tag = jcp.is_1stconv
+ ? pick(ndims - 3, ncw, nchw, ncdhw)
+ : ((jcp.simd_w == 4)
+ ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
+ : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c));
+ auto dst_tag = (jcp.simd_w == 4)
+ ? pick(ndims - 3, nCw4c, nChw4c, nCdhw4c)
+ : pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
+ auto wei_tag = with_groups
+ ? ((jcp.simd_w == 4)
+ ? pick(ndims - 3, gOIw4i4o, gOIhw4i4o, gOIdhw4i4o)
+ : pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o))
+ : ((jcp.simd_w == 4)
+ ? pick(ndims - 3, OIw4i4o, OIhw4i4o, OIdhw4i4o)
+ : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o));
+
+ if (src_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(src_md, src_tag));
+ jcp.src_tag = src_tag;
+ } else {
+ jcp.src_tag = src_d.matches_one_of_tag(src_tag);
+ }
+ if (jcp.src_tag != src_tag)
+ return status::unimplemented;
+
+ if (dst_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(dst_md, dst_tag));
+ jcp.dst_tag = dst_tag;
+ } else {
+ jcp.dst_tag = dst_d.matches_one_of_tag(dst_tag);
+ }
+ if (jcp.dst_tag != dst_tag)
+ return status::unimplemented;
+
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+ if (jcp.with_bias) {
+ if (bias_d.format_kind() == format_kind::any)
+ CHECK(memory_desc_init_by_tag(bias_md, x));
+ }
+
+ if (mayiuse(avx512_common) &&
+ src_d.data_type() == data_type::f32
+ && weights_d.data_type() == data_type::f32
+ && dst_d.data_type() == data_type::f32) {
+ jcp.ver = ver_fma;
+ jcp.typesize_in = sizeof(float);
+ jcp.typesize_out = sizeof(float);
+ if (mayiuse(avx512_mic_4ops))
+ jcp.ver = ver_4fma;
+
+ if (jcp.is_1stconv) {
+ // TODO: fix & remove constraints below
+ bool not_for_4fma
+ = IMPLICATION(everyone_is(0, jcp.l_pad, jcp.t_pad),
+ nstl::max(jcp.kw, jcp.kh) < 7);
+ bool is_dilated
+ = !everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w);
+ if (one_of(true, not_for_4fma, is_dilated))
+ jcp.ver = ver_fma;
+ if (jcp.ver == ver_4fma) {
+ wei_tag = with_groups
+ ? ((jcp.simd_w == 4)
+ ? pick(ndims - 3, gOiw4o, gOihw4o, gOidhw4o)
+ : pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o))
+ : ((jcp.simd_w == 4)
+ ? pick(ndims - 3, Oiw4o, Oihw4o, Oidhw4o)
+ : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o));
+ } else {
+ wei_tag = with_groups
+ ? ((jcp.simd_w == 4)
+ ? pick(ndims - 3, gOwi4o, gOhwi4o, gOdhwi4o)
+ : pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o))
+ : ((jcp.simd_w == 4)
+ ? pick(ndims - 3, Owi4o, Ohwi4o, Odhwi4o)
+ : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o));
+ }
+ }
+ } else {
+ return status::unimplemented;
+ }
+
+ if (weights_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(weights_md, wei_tag));
+ jcp.wei_tag = wei_tag;
+ } else {
+ jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+ }
+ if (jcp.wei_tag != wei_tag)
+ return status::unimplemented;
+
+ if (jcp.is_1stconv) {
+ jcp.ur_w = nstl::min(jcp.ow, regs);
+ } else {
+ // avx512_core guard - just to avoid possible regression for other archs
+ if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
+ jcp.ur_w = nstl::min(jcp.ow, regs);
+ } else {
+ for (int ur_w = regs; ur_w > 0; --ur_w) {
+ if (jcp.ow % ur_w == 0) {
+ jcp.ur_w = ur_w;
+ break;
+ }
+ }
+ }
+ if ((ndims == 5 && jcp.ur_w <= 8) || (jcp.ur_w <= 1)) {
+ jcp.ur_w = nstl::min(jcp.ow, regs);
+ }
+ }
+ // TODO (Tanya): currently applied to Segnet convolutions only.
+ // Need to try for other topologies
+ if (jcp.ow > 150 && jcp.ur_w < regs/2)
+ jcp.ur_w = regs;
+
+ int n_oi = (jcp.ow / jcp.ur_w);
+ int r_pad = (jcp.ur_w * n_oi - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
+ if (jcp.l_pad > 0 && r_pad > 0)
+ n_oi--;
+
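+ // Guard against overly large generated kernels: if the padded border cases
+ // are emitted as extra fully unrolled bodies, shrink ur_w until the rough
+ // code-size estimate stays under the 24 KB budget below.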
+ bool large_code_size = jcp.ur_w != jcp.ow && jcp.l_pad > 0 && r_pad > 0
+ && ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1));
+ if (large_code_size) {
+ const int max_code_size = 24 * 1024;
+ const int num_ops_per_reg = 6 + jcp.ic_block * jcp.kw;
+ int mult = 1;
+ if (jcp.l_pad > 0) mult += 1;
+ if (r_pad > 0) mult += 1;
+ for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
+ if (ur_w * mult * num_ops_per_reg * 9.0 < max_code_size) {
+ jcp.ur_w = ur_w;
+ break;
+ }
+ }
+ }
+
+ /* Grouped channel offset to support 'non-blocked data' format for
+ * convolution sizes with '(input_channel / ngroups) < simd' */
+ jcp.nonblk_group_off
+ = (jcp.ngroups > 1 && one_of(jcp.src_tag, ncw, nchw, ncdhw)) ?
+ jcp.ic :
+ 1;
+
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+ jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
+
+ auto is_ow_threading_applicable = [=]() {
+ return (true && !jcp.is_1stconv && one_of(jcp.ndims, 3, 4)
+ && IMPLICATION(mayiuse(avx512_mic),
+ jcp.ver == ver_4fma
+ && IMPLICATION(jcp.mb != 1,
+ jcp.ih == 1 && jcp.kh == 1)));
+ };
+
+ if (jcp.ver == ver_4fma && !jcp.is_1stconv) {
+ if ((jcp.kw <= 5 && jcp.kh <= 5 && jcp.kw == jcp.kh && jcp.ow <= 8
+ && jcp.oh <= 8 && jcp.ow == jcp.oh)
+ || (jcp.stride_h != 1 && jcp.ur_w < jcp.ow)) {
+ if (jcp.nb_oc % 2 == 0) {
+ jcp.nb_oc_blocking = 2;
+ jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
+ }
+ } else {
+ for (int i = jcp.nb_oc; i > 0; i--)
+ if (i * jcp.ur_w <= regs && jcp.nb_oc % i == 0) {
+ jcp.nb_oc_blocking = i;
+ break;
+ }
+ }
+ if (jcp.ver == ver_4fma && is_ow_threading_applicable()) {
+ if (jcp.nb_oc % 2 == 0 && jcp.ur_w < jcp.ow
+ && jcp.ow != 2 * jcp.ur_w) {
+ jcp.nb_oc_blocking = 2;
+ jcp.ur_w = nstl::min(jcp.ow, regs / jcp.nb_oc_blocking);
+ }
+ }
+ }
+
+ jcp.ow_block = jcp.ow;
+
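+ // Rough thread-efficiency estimate for a given (nb_oc_blocking, ow_block):
+ // the useful fraction of ow after rounding up to whole ow blocks, times how
+ // evenly the resulting jobs divide across nthreads.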
+ auto get_thr_eff = [=](int nb_oc_blocking, int ow_block) {
+ int nb_ow = div_up(jcp.ow, ow_block);
+ int nb_oc_chunks = div_up(jcp.nb_oc, nb_oc_blocking);
+ int work_amount = jcp.mb * jcp.oh * nb_oc_chunks * nb_ow;
+ float disbalance = (float)jcp.ow / rnd_up(jcp.ow, ow_block);
+ float thr_eff = disbalance * (float)work_amount
+ / rnd_up(work_amount, nthreads);
+ return thr_eff;
+ };
+
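+ // Choose an ow block: start from a cache-based estimate (about 7/8 of L2
+ // split between src/dst chunks and two weight chunks), then scan nearby
+ // block counts for better thread efficiency, keeping ow_block >= 2 * ur_w
+ // as generate() requires.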
+ auto get_ow_block = [=](int nb_oc_blocking, int ur_w, float &eff) {
+ int res_ow_block = jcp.ow;
+ eff = get_thr_eff(nb_oc_blocking, res_ow_block);
+ if (!is_ow_threading_applicable())
+ return res_ow_block;
+
+ int L2_part = (get_cache_size(2) * 7 / 8) / typesize;
+ if (jcp.ver == ver_4fma)
+ L2_part /= 2;
+ int size_src_chunk = jcp.ic_block * ur_w * jcp.kh;
+ int size_dst_chunk = jcp.oc_block * nb_oc_blocking * ur_w;
+ int size_wei_chunk = jcp.oc_block * nb_oc_blocking * jcp.ic_block
+ * jcp.kw * jcp.kh;
+ int nurw_cache = (L2_part - 2 * size_wei_chunk)
+ / (2 * size_dst_chunk + 2 * size_src_chunk);
+ // current design of generate() requires ow_block >= 2 * ur_w
+ int ow_block_cache = ur_w * nstl::max(2, nurw_cache);
+
+ int ow_block_thr = ow_block_cache;
+ eff = get_thr_eff(nb_oc_blocking, ow_block_thr);
+
+ int max_nb_ow = div_up(jcp.ow, 2 * ur_w);
+ int start_nb_ow = div_up(jcp.ow, ow_block_thr);
+ for (int nb_ow = start_nb_ow; nb_ow <= max_nb_ow; nb_ow++) {
+ int ow_block
+ = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), ur_w), jcp.ow);
+ float eff_threshold = (jcp.ver == ver_4fma) ? 0.8f : 0.9f;
+ if (ow_block < nb_oc_blocking * jcp.oc_block && eff > eff_threshold)
+ break;
+ if (div_up(jcp.ow, ow_block) != nb_ow)
+ continue;
+ float thr_eff = get_thr_eff(nb_oc_blocking, ow_block);
+ float eff_step = (jcp.ver == ver_4fma) ? 1.1f : 1.f;
+ if (ow_block >= 2 * ur_w && thr_eff > eff_step * eff) {
+ ow_block_thr = ow_block;
+ eff = thr_eff;
+ }
+ eff_threshold = (jcp.ver == ver_4fma) ? 0.9f : 0.98f;
+ if (eff > eff_threshold)
+ break;
+ }
+ res_ow_block = nstl::min(jcp.ow, nstl::max(2 * ur_w, ow_block_thr));
+ eff = get_thr_eff(nb_oc_blocking, res_ow_block);
+ return res_ow_block;
+ };
+
+
+ if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
+ int try_nb_oc_blocking = 2;
+ unsigned int ker_inp_size = typesize * div_up(jcp.iw, jcp.stride_w)
+ * jcp.ic_block * jcp.kh * jcp.kd;
+ unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block
+ * try_nb_oc_blocking;
+ unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
+ * jcp.oc_block * try_nb_oc_blocking * jcp.kd;
+ unsigned int ker_total_size = ker_inp_size + ker_out_size
+ + ker_wei_size;
+
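+ // Two code-generation flavours are selected below: embd_bcast (single
+ // output-channel block, FMA with an embedded-broadcast memory operand) and
+ // expl_bcast (inputs broadcast into registers, several oc blocks in
+ // flight). Small kernels whose working set appears to fit in L1 lean
+ // towards embd_bcast.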
+ bool embd_bcast_condition = true
+ && (jcp.kw == 3 && jcp.ow <= 28 && ker_total_size < L1_cache_size)
+ && !(jcp.kw == 3 && jcp.ow == 13 && jcp.ic >= 192)
+ && !(jcp.kw == 3 && jcp.ow == 28 && jcp.ic >= 512);
+
+ if (jcp.mb == 1) {
+ unsigned int inp_size = jcp.mb * div_up(jcp.ih, jcp.stride_h)
+ * div_up(jcp.iw, jcp.stride_w) * jcp.ic;
+ unsigned int wei_size = jcp.ic * jcp.oc * jcp.kh * jcp.kw;
+
+ // Estimate whether we need to limit the number of threads
+ // and calculate this number. Includes some heuristic.
+ int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
+ int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh;
+ int job_size_min = work_amount / nthreads;
+ int job_size_max = div_up(work_amount, nthreads);
+ int ch_max = rnd_up(jcp.oh, job_size_max);
+ int ch_min = (job_size_min == 0)
+ ? jcp.oh
+ : rnd_up(jcp.oh, job_size_min);
+ bool not_aligned_max = ch_max % jcp.oh != 0 && ch_max / jcp.oh < 2
+ && (jcp.oh != 8 || ch_max / jcp.oh > 1);
+ bool not_aligned_min = ch_min % jcp.oh != 0 && ch_min / jcp.oh < 2
+ && (jcp.oh != 8 || ch_min / jcp.oh > 1);
+ bool eligible_case = (jcp.stride_h == 1 && jcp.stride_w == 1)
+ || nthreads > oc_chunks;
+ if (jcp.loop_order == loop_cgn && oc_chunks > 1 && nthreads > 1
+ && wei_size / inp_size > 24
+ && (not_aligned_max || not_aligned_min)
+ && eligible_case) {
+ // Try to find nthreads > mkldnn_get_max_threads() / 2 such
+ // that oc_chunks is a multiple of nthreads, or nthreads is a
+ // multiple of oc_chunks. Otherwise, keep default value.
+ // TODO: implement a task-based alternative without throttling.
+ jcp.aligned_threads = nthreads;
+ for (int i = nthreads; i > nthreads / 2; i--) {
+ if (oc_chunks % i == 0 || i % oc_chunks == 0) {
+ jcp.aligned_threads = i;
+ break;
+ }
+ }
+ }
+ }
+
+ if (jcp.kw > 3
+ || (jcp.stride_w == 1 && jcp.stride_h == 1
+ && embd_bcast_condition)
+ || ((jcp.stride_w != 1 || jcp.stride_h != 1)
+ && ((jcp.mb <= 16 && (jcp.oc <= 192 || jcp.oh <= 10)
+ && embd_bcast_condition)))
+ || (jcp.mb == 1
+ && (jcp.ur_w >= jcp.ow || jcp.is_1stconv
+ || (jcp.ow <= 147 && jcp.oc <= 96)))) {
+ jcp.kernel_kind = embd_bcast;
+ jcp.ur_w = nstl::min(jcp.ow, regs);
+ jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
+ if (ker_total_size < L1_cache_size && jcp.ow <= 8 && jcp.kh <= 3
+ && jcp.kw <= 3 && jcp.nb_oc % try_nb_oc_blocking == 0
+ && IMPLICATION(jcp.is_1stconv, jcp.mb == 1)
+ && IMPLICATION(jcp.mb == 1, jcp.ur_w < jcp.ow)) {
+ jcp.nb_oc_blocking = try_nb_oc_blocking;
+ jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
+ }
+ } else {
+ jcp.kernel_kind = expl_bcast;
+ jcp.nb_ic_blocking = 1;
+ if (IMPLICATION(jcp.is_1stconv, jcp.mb > 1)) {
+ float best_thr_eff = 0.f;
+ int best_nb_oc_blocking = 1;
+ for (int i = nstl::min(jcp.nb_oc, 5); i > 0; i--) {
+ if (jcp.nb_oc % i == 0) {
+ float thr_eff;
+ int ur_w = nstl::min(jcp.ow, 31 / (i + 1));
+ get_ow_block(i, ur_w, thr_eff);
+ if (thr_eff > 1.05f * best_thr_eff) {
+ best_nb_oc_blocking = i;
+ best_thr_eff = thr_eff;
+ }
+ }
+ }
+ jcp.nb_oc_blocking = best_nb_oc_blocking;
+ jcp.ur_w = nstl::min(jcp.ow, 31 / (jcp.nb_oc_blocking + 1));
+ }
+ }
+ }
+
+ jcp.ur_w_tail = jcp.ow % jcp.ur_w;
+
+ args_ok = true
+ && jcp.l_pad <= jcp.ur_w
+ && jcp.ic <= src_d.padded_dims()[1]
+ && jcp.oc <= dst_d.padded_dims()[1]
+ && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
+ && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
+ if (!args_ok)
+ return status::unimplemented;
+
+ int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1)
+ - (jcp.iw + jcp.l_pad - 1));
+ if (r_pad_no_tail > jcp.ur_w)
+ return status::unimplemented;
+
+ pick_loop_order(jcp);
+
+ jcp.nb_ic_L2 = jcp.nb_ic;
+
+ float thr_eff;
+ jcp.ow_block = get_ow_block(jcp.nb_oc_blocking, jcp.ur_w, thr_eff);
+ jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
+
+ const int L2_size = get_cache_size(2, true) / sizeof(float);
+ // Source and output data needs to fit in L2,
+ // leaving some space for weights and prefetching.
+ int h_L2 = int(((0.6f * L2_size) / jcp.simd_w
+ - nstl::min(0, jcp.kh - jcp.stride_h) * jcp.iw)
+ / (jcp.stride_h * jcp.iw + jcp.ow));
+ jcp.h_blocking = nstl::max(1, nstl::min(jcp.oh, h_L2));
+
+ if (jcp.ver == ver_4fma) {
+ if (!is_ow_threading_on(jcp)) {
+ for (int divf = 2, temp_nb = jcp.nb_ic_L2; divf <= jcp.nb_ic;
+ divf++) {
+ size_t l2_src
+ = (size_t)jcp.iw * jcp.ic_block * jcp.ih * temp_nb * jcp.id;
+ size_t l2_dst = (size_t)jcp.ow * jcp.oc_block * jcp.nb_oc_blocking
+ * jcp.oh * jcp.od;
+ size_t l2_filt = (size_t)jcp.kw * jcp.oc_block * jcp.ic_block
+ * jcp.kh * jcp.nb_oc_blocking * temp_nb * jcp.kd;
+ if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
+ if (jcp.kh == 3 && jcp.oh == 7) {
+ jcp.nb_ic_L2 = 1;
+ break;
+ }
+ temp_nb = (jcp.nb_ic_L2 % divf == 0 ? jcp.nb_ic_L2 / divf
+ : jcp.nb_ic_L2);
+ } else {
+ jcp.nb_ic_L2 = temp_nb;
+ break;
+ }
+ }
+ } else if (jcp.ic > 64) {
+ jcp.nb_ic_L2 = 2; /* according to performance data */
+ }
+ }
+
+ return status::success;
+}
+
+void jit_avx512_common_conv_fwd_kernel::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
+}
+
+void jit_avx512_common_conv_bwd_data_kernel_f32::prepare_output(int ur_w)
+{
+ for (int k = 0; k < jcp.nb_ic_blocking; k++) {
+ for (int j = 0; j < ur_w; j++) {
+ Zmm zmm = zmm_out(j, k);
+ vpxord(zmm, zmm, zmm);
+ size_t aux_src_offset
+ = (size_t)typesize * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j)
+ * jcp.ic_block;
+ mic_prefetcht1(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
+ reg_long_offt));
+ }
+ }
+}
+
+void jit_avx512_common_conv_bwd_data_kernel_f32::store_output(int ur_w)
+{
+ Label no_update_label;
+
+ mov(reg_channel, ptr[param + GET_OFF(channel)]);
+ cmp(reg_channel, 0);
+ je(no_update_label, T_NEAR);
+ for (int k = 0; k < jcp.nb_ic_blocking; k++) {
+ for (int j = 0; j < ur_w; j++) {
+ Zmm zmm = zmm_out(j, k);
+ size_t aux_src_offset = (size_t)typesize
+ * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
+ vaddps(zmm, EVEX_compress_addr_safe(reg_src, aux_src_offset,
+ reg_long_offt));
+ }
+ }
+
+ L(no_update_label);
+ for (int k = 0; k < jcp.nb_ic_blocking; k++) {
+ for (int j = 0; j < ur_w; j++) {
+ Zmm zmm = zmm_out(j, k);
+ size_t aux_src_offset = (size_t)typesize
+ * ((size_t)k * jcp.ih * jcp.iw * jcp.id + j) * jcp.ic_block;
+ vmovups(EVEX_compress_addr_safe(reg_src, aux_src_offset,
+ reg_long_offt), zmm);
+ mic_prefetcht0(EVEX_compress_addr_safe(reg_src_prf, aux_src_offset,
+ reg_long_offt));
+ }
+ }
+}
+
+void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_4fma(
+ int ur_w, int l_overflow, int r_overflow)
+{
+ int ow = jcp.ow;
+ int kw = jcp.kw;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ Label kh_label, last_iter_label, loop_end_label, kd_label;
+ int ker_load_number = 4;
+ int shift_ker_ptr = typesize * kw * oc_block * ic_block;
+ int shift_dst_ptr = typesize * ow * oc_block;
+ int ii_dpref_t0 = get_iw_start(0, l_overflow);
+ int iw_end_ipref = get_iw_end(ur_w, 0, r_overflow);
+
+ bool check_last_kh = (jcp.kh > 3);
+ auto kernel_offset = [=](int icb, int oc, int ki) {
+ int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
+ int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
+ int oc_offset = oc * jcp.oc_block;
+ return typesize * (blk_offset + oc_offset);
+ };
+ auto kernel_loads = [=](int ki, int oc, int kk) {
+ for (int ii = 0; ii < ker_load_number; ii++) {
+ int aux_kernel_offset = kernel_offset(kk, oc + ii, ki);
+ vmovups(zmm_ker(ii),
+ EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
+ }
+ };
+ auto prefetch_dst_next_kh = [&](int ki, int ki_start, int cnt0, int cnt1) {
+ if (cnt1 >= ker_load_number && cnt0 >= ker_load_number
+ && ki >= ki_start && ii_dpref_t0 < iw_end_ipref) {
+ int aux_dst_offset = typesize * ((ii_dpref_t0
+ + jcp.l_pad) * oc_block + jcp.ow * oc_block);
+ prefetcht0(EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
+ ii_dpref_t0++;
+ }
+ };
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ mov(aux_reg_dst, reg_dst);
+ mov(aux_reg_ker, reg_ker);
+ mov(aux_reg_dst_prf, reg_dst_prf);
+ mov(aux_reg_ker_prf, reg_ker_prf);
+ }
+
+ if (jcp.ndims == 5) {
+ push(reg_src_prf);
+ push(reg_src);
+
+ mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
+ mov(aux_reg_dst_d, reg_dst);
+ mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
+ mov(aux_reg_dst_d_prf, reg_dst_prf);
+ mov(aux_reg_ker_d_prf, reg_ker_prf);
+
+ L(kd_label);
+ mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
+ } else {
+ mov(reg_kj, reg_kh);
+ }
+
+ if (jcp.ndims == 5) {
+ mov(aux_reg_dst, aux_reg_dst_d);
+ mov(aux_reg_ker, aux_reg_ker_d);
+ mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
+ mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
+ }
+
+ align(16);
+ L(kh_label);
+ if (check_last_kh) {
+ for (int ki = 0; ki < kw; ki++)
+ for (int oc = 0; oc < oc_block; oc += 4)
+ for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
+ bool last_kernel_loads = (kk == jcp.nb_ic_blocking - 1
+ && ki == kw - 1 && (oc + 4) == oc_block);
+
+ if (last_kernel_loads) {
+ cmp(reg_kj, 1);
+ je(last_iter_label, T_NEAR);
+ }
+
+ kernel_loads(ki, oc, kk);
+ for (int ii = get_iw_start(ki, l_overflow),
+ prf_count_t0 = 0, prf_count_t1 = 0;
+ ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
+ int aux_dst_offset = typesize
+ * ((ii + jcp.l_pad - ki) * oc_block + oc);
+ v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
+ EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
+
+ if (ii % 2) {
+ if (prf_count_t0 < 4) {
+ int aux_kernel_prf;
+ if (last_kernel_loads)
+ aux_kernel_prf = kernel_offset(0, prf_count_t0
+ + oc + 4 - oc_block, 0) + typesize * kw
+ * oc_block * ic_block;
+ else
+ aux_kernel_prf = kernel_offset(kk, oc + 4
+ + prf_count_t0, ki);
+ mic_prefetcht0(EVEX_compress_addr(aux_reg_ker,
+ aux_kernel_prf));
+ prf_count_t0++;
+ } else if (prf_count_t1 < 4) {
+ mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
+ kernel_offset(kk, oc + prf_count_t1, ki)));
+ prf_count_t1++;
+ }
+ } else
+ prefetch_dst_next_kh(ki, 2, prf_count_t0, prf_count_t1);
+ }
+ if (last_kernel_loads) {
+ jmp(loop_end_label, T_NEAR);
+
+ L(last_iter_label);
+
+ kernel_loads(ki, oc, kk);
+ for (int ii = get_iw_start(ki, l_overflow),
+ prf_count_t0 = 0, prf_count_t1 = 0;
+ ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
+ int aux_dst_offset = typesize
+ * ((ii + jcp.l_pad - ki) * oc_block + oc);
+ v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
+ EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
+ if (ii % 2) {
+ if (prf_count_t0 < 4) {
+ mic_prefetcht0(EVEX_compress_addr(aux_reg_ker_prf,
+ kernel_offset(0, prf_count_t0, 0)));
+ prf_count_t0++;
+ } else if (prf_count_t1 < 4) {
+ mic_prefetcht1(EVEX_compress_addr(aux_reg_ker_prf,
+ kernel_offset(kk, oc + prf_count_t1, ki)));
+ prf_count_t1++;
+ }
+ }
+ }
+ L(loop_end_label);
+ }
+ }
+ } else {
+ for (int ki = 0; ki < kw; ki++)
+ for (int oc = 0; oc < oc_block; oc += 4)
+ for (int kk = 0; kk < jcp.nb_ic_blocking; kk++) {
+ kernel_loads(ki, oc, kk);
+
+ for (int ii = get_iw_start(ki, l_overflow), prf_count_t1 = 0;
+ ii < get_iw_end(ur_w, ki, r_overflow); ii++) {
+ int aux_dst_offset = typesize
+ * ((ii + jcp.l_pad - ki) * oc_block + oc);
+ v4fmaddps(zmm_out(ii, kk), zmm_ker(0),
+ EVEX_compress_addr(aux_reg_dst, aux_dst_offset));
+ if ((ii % 2) && (prf_count_t1 < 4)) {
+ mic_prefetcht1(EVEX_compress_addr(
+ aux_reg_ker_prf, kernel_offset(kk,
+ oc + prf_count_t1, ki)));
+ prf_count_t1++;
+ }
+ if (ki == 1 && oc == 0 && kk == 0)
+ mic_prefetcht1(EVEX_compress_addr(
+ aux_reg_dst_prf, aux_dst_offset));
+ }
+ }
+ }
+
+ add(aux_reg_ker, shift_ker_ptr);
+ sub(aux_reg_dst, shift_dst_ptr);
+ add(aux_reg_ker_prf, shift_ker_ptr);
+ sub(aux_reg_dst_prf, shift_dst_ptr);
+
+ dec(reg_kj);
+ cmp(reg_kj, 0);
+ jg(kh_label, T_NEAR);
+
+ if (jcp.ndims == 5) {
+ sub(aux_reg_dst_d, typesize * (jcp.oh * ow) * ic_block);
+ add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
+ sub(aux_reg_dst_d_prf, typesize * (jcp.oh * ow) * ic_block);
+ add(aux_reg_ker_d_prf, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
+
+ dec(reg_ki);
+ cmp(reg_ki, 0);
+ jg(kd_label, T_NEAR);
+
+ pop(reg_src);
+ pop(reg_src_prf);
+ }
+}
+
+void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma(
+ int ur_w, int l_overflow, int r_overflow)
+{
+ Label kh_label, kd_label;
+ int kw = jcp.kw;
+ int ow = jcp.ow;
+
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ int l_pad = jcp.l_pad;
+ int dilate_w = jcp.dilate_w + 1;
+ int stride_w = jcp.stride_w;
+ int stride_h = jcp.stride_h;
+
+ int ker_pipeline_depth = 4;
+ assert(ker_reg_base_idx + ker_pipeline_depth <= 32);
+ assert(oc_block >= ker_pipeline_depth);
+
+ int num_ker_loads = oc_block * kw;
+ int num_inp_prfs = ur_w * nstl::min(kw, stride_w)
+ + nstl::max(0, kw - stride_w);
+ int num_prfs = num_ker_loads + num_inp_prfs;
+ int num_fmas = num_ker_loads * ur_w / stride_w;
+ int prf_inst_spacing = nstl::max(1, num_fmas / num_prfs);
+ int prf_inst_trigger = (num_fmas % prf_inst_spacing) / 2;
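+ // Software prefetches are interleaved with the FMA stream: roughly one
+ // prefetch slot every prf_inst_spacing FMAs (kernel lines via prefetcht1
+ // first, then diff_dst lines via prefetcht0), offset by prf_inst_trigger.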
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ mov(aux_reg_dst, reg_dst);
+ mov(aux_reg_ker, reg_ker);
+
+ mov(aux_reg_dst_prf, reg_dst_prf);
+ mov(aux_reg_ker_prf, reg_ker_prf);
+ }
+
+ if (jcp.ndims == 5) {
+ push(reg_src_prf);
+ push(reg_src);
+
+ mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
+ mov(aux_reg_dst_d, reg_dst);
+ mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
+ mov(aux_reg_dst_d_prf, reg_dst_prf);
+ mov(aux_reg_ker_d_prf, reg_ker_prf);
+
+ L(kd_label);
+ mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
+ } else {
+ mov(reg_kj, reg_kh);
+ }
+
+ if (jcp.ndims == 5) {
+ mov(aux_reg_dst, aux_reg_dst_d);
+ mov(aux_reg_ker, aux_reg_ker_d);
+ mov(aux_reg_dst_prf, aux_reg_dst_d_prf);
+ mov(aux_reg_ker_prf, aux_reg_ker_d_prf);
+ }
+
+ L(kh_label); {
+ int step = 0;
+ int ker_prfs = 0;
+ for (int ki = 0; ki < kw; ki++) {
+ for (int oc = 0; oc < oc_block; oc++) {
+ if (step == 0) {
+ for (int i = 0; i < ker_pipeline_depth; i++) {
+ int aux_kernel_offset = typesize * ((oc + i) * oc_block
+ + ki * ic_block * oc_block);
+ vmovups(zmm_ker(i), EVEX_compress_addr(
+ aux_reg_ker, aux_kernel_offset));
+ }
+ } else if (step < num_ker_loads - ker_pipeline_depth + 1) {
+ int load_offset = ker_pipeline_depth - 1;
+ int ker_load_reg_idx
+ = (step + load_offset) % ker_pipeline_depth;
+ int aux_kernel_offset = typesize * ((oc + load_offset)
+ * oc_block + ki * ic_block * oc_block);
+ vmovups(zmm_ker(ker_load_reg_idx),
+ EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
+ }
+
+ bool ker_prf_inserted = false;
+ auto zmm_kernel = zmm_ker(step % ker_pipeline_depth);
+
+ int jj_start = get_iw_start(ki, l_overflow);
+ int jj_end = get_iw_end(ur_w, ki, r_overflow);
+ assert(stride_w != 1
+ || jj_start == nstl::max(0,
+ l_overflow - (kw - 1 - ki) * dilate_w));
+ assert(stride_w != 1
+ || jj_end == ur_w - nstl::max(0,
+ r_overflow - ki * dilate_w));
+
+ for (int jj = jj_start; jj < jj_end; jj += stride_w) {
+ assert((jj + l_pad - ki * dilate_w) % stride_w == 0);
+ int aux_dst_offset = typesize *
+ (((jj + l_pad - ki * dilate_w)
+ / stride_w) * jcp.oc_block + oc);
+ vfmadd231ps(zmm_out(jj, 0), zmm_kernel,
+ EVEX_compress_addr(aux_reg_dst, aux_dst_offset, true));
+
+ int fma_idx = (step * ur_w + jj) / stride_w;
+ int prf_slot_idx = fma_idx / prf_inst_spacing;
+ if (fma_idx % prf_inst_spacing == prf_inst_trigger) {
+ if (!ker_prf_inserted && ker_prfs < num_ker_loads) {
+ int ker_prf_offset = typesize
+ * ker_prfs * jcp.oc_block;
+ mic_prefetcht1(EVEX_compress_addr(
+ aux_reg_ker_prf, ker_prf_offset));
+ ker_prf_inserted = true;
+ ker_prfs++;
+ } else {
+ int inp_prf_idx = prf_slot_idx - ker_prfs;
+ if (inp_prf_idx < num_inp_prfs) {
+ int inp_prf_offset
+ = ic_block * typesize
+ * ((inp_prf_idx / kw) * kw
+ + (inp_prf_idx % kw));
+ mic_prefetcht0(EVEX_compress_addr(
+ aux_reg_dst_prf, inp_prf_offset));
+ }
+ }
+ }
+ }
+ step++;
+ }
+ }
+
+ add(aux_reg_ker, typesize * stride_h * kw * oc_block * ic_block);
+ sub(aux_reg_dst, typesize * (jcp.dilate_h + 1) * ow * oc_block);
+ add(aux_reg_ker_prf, typesize * stride_h * kw * oc_block * ic_block);
+ sub(aux_reg_dst_prf, typesize * (jcp.dilate_h + 1) * ow * oc_block);
+
+ dec(reg_kj);
+ cmp(reg_kj, 0);
+ jg(kh_label, T_NEAR);
+ }
+ if (jcp.ndims == 5) {
+ sub(aux_reg_dst_d,
+ typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
+ add(aux_reg_ker_d, typesize * jcp.stride_d * jcp.kw * jcp.kh
+ * oc_block * ic_block);
+ sub(aux_reg_dst_d_prf,
+ typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
+ add(aux_reg_ker_d_prf, typesize * jcp.stride_d * jcp.kw * jcp.kh
+ * oc_block * ic_block);
+
+ dec(reg_ki);
+ cmp(reg_ki, 0);
+ jg(kd_label, T_NEAR);
+ }
+
+ if (jcp.ndims == 5)
+ {
+ pop(reg_src);
+ pop(reg_src_prf);
+ }
+}
+
+void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop_fma_core(
+ int ur_w, int l_overflow, int r_overflow)
+{
+ int kw = jcp.kw;
+ int ow = jcp.ow;
+ int dilate_w = jcp.dilate_w + 1;
+ int stride_w = jcp.stride_w;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ int nb_ic_block = jcp.nb_ic_blocking;
+ Label kh_label, kd_label;
+
+ int shift_ker_ptr = typesize * kw * oc_block * ic_block;
+ int shift_dst_ptr = typesize * (jcp.dilate_h + 1) * ow * oc_block;
+
+ auto output_offset = [=](int oi, int oc, int ki) {
+ return typesize *
+ (((oi + jcp.l_pad - ki * dilate_w) / stride_w) * oc_block + oc);
+ };
+ auto kernel_offset = [=](int icb, int oc, int ki) {
+ int blk_idx = icb * jcp.kh * jcp.kw * jcp.kd + ki;
+ int blk_offset = blk_idx * jcp.oc_block * jcp.ic_block;
+ int oc_offset = oc * jcp.oc_block;
+ return typesize * (blk_offset + oc_offset);
+ };
+
+ if (one_of(jcp.ndims, 3, 4)) {
+ mov(aux_reg_dst, reg_dst);
+ mov(aux_reg_ker, reg_ker);
+ }
+
+ if (jcp.ndims == 5) {
+ push(reg_src_prf);
+ push(reg_src);
+
+ mov(reg_ki, ptr[param + GET_OFF(kd_padding)]);
+ mov(aux_reg_dst_d, reg_dst);
+ mov(aux_reg_ker_d, ptr[param + GET_OFF(filt)]);
+
+ L(kd_label);
+ mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
+ } else {
+ mov(reg_kj, reg_kh);
+ }
+
+ if (jcp.ndims == 5) {
+ mov(aux_reg_dst, aux_reg_dst_d);
+ mov(aux_reg_ker, aux_reg_ker_d);
+ }
+
+ L(kh_label);
+ {
+ for (int ki = 0; ki < kw; ki++) {
+ int jj_start = get_iw_start(ki, l_overflow);
+ int jj_end = get_iw_end(ur_w, ki, r_overflow);
+ for (int oc = 0; oc < oc_block; oc++) {
+ if (jcp.kernel_kind == expl_bcast) {
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ int aux_output_offset = output_offset(jj, oc, ki);
+ vbroadcastss(zmm_inp(jj, nb_ic_block),
+ ptr[aux_reg_dst + aux_output_offset]);
+ }
+ }
+ for (int ii = 0; ii < nb_ic_block; ii++) {
+ int aux_kernel_offset = kernel_offset(ii, oc, ki);
+ if (jj_end - jj_start > 0)
+ vmovups(zmm_wei, EVEX_compress_addr(aux_reg_ker,
+ aux_kernel_offset));
+ for (int jj = jj_start; jj < jj_end; jj += stride_w)
+ if (jcp.kernel_kind == expl_bcast)
+ vfmadd231ps(zmm_out(jj, ii),
+ zmm_inp(jj, nb_ic_block), zmm_wei);
+ else
+ vfmadd231ps(zmm_out(jj, ii), zmm_wei,
+ EVEX_compress_addr(aux_reg_dst,
+ output_offset(jj, oc, ki), true));
+ }
+ }
+ }
+ add(aux_reg_ker, shift_ker_ptr);
+ sub(aux_reg_dst, shift_dst_ptr);
+ dec(reg_kj);
+ cmp(reg_kj, 0);
+ jg(kh_label, T_NEAR);
+ }
+
+ if (jcp.ndims == 5) {
+ sub(aux_reg_dst_d,
+ typesize * (jcp.dilate_d + 1) * jcp.oh * ow * ic_block);
+ add(aux_reg_ker_d, typesize * jcp.kw * jcp.kh * oc_block * ic_block);
+
+ dec(reg_ki);
+ cmp(reg_ki, 0);
+ jg(kd_label, T_NEAR);
+
+ pop(reg_src);
+ pop(reg_src_prf);
+ }
+}
+
+inline void jit_avx512_common_conv_bwd_data_kernel_f32::compute_loop(
+ int ur_w, int l_overflow, int r_overflow)
+{
+ if (jcp.ndims == 5) push(reg_oi);
+
+ prepare_output(ur_w);
+
+ Label skip_compute_loop;
+ if (jcp.ndims == 5) {
+ mov(reg_kj, ptr[param + GET_OFF(kd_padding)]);
+ cmp(reg_kj, 0);
+ je(skip_compute_loop, T_NEAR);
+ }
+ mov(reg_kj, ptr[param + GET_OFF(kh_padding)]);
+ cmp(reg_kj, 0);
+ je(skip_compute_loop, T_NEAR);
+
+ if (jcp.ver == ver_4fma)
+ compute_loop_4fma(ur_w, l_overflow, r_overflow);
+ else if (jcp.ver == ver_fma)
+ if (mayiuse(avx512_mic))
+ compute_loop_fma(ur_w, l_overflow, r_overflow);
+ else
+ if (jcp.kernel_kind == embd_bcast && jcp.nb_ic_blocking == 1)
+ compute_loop_fma(ur_w, l_overflow, r_overflow);
+ else
+ compute_loop_fma_core(ur_w, l_overflow, r_overflow);
+ else
+ assert("!unknown convolution version");
+
+ L(skip_compute_loop);
+ store_output(ur_w);
+ if (jcp.ndims == 5) pop(reg_oi);
+}
+
+void jit_avx512_common_conv_bwd_data_kernel_f32::generate()
+{
+ int iw = jcp.iw;
+ int kw = jcp.kw;
+ int ur_w = jcp.ur_w;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ int ur_w_tail = jcp.ur_w_tail;
+ int dilate_w = jcp.dilate_w + 1;
+ int stride_w = jcp.stride_w;
+
+ int dst_shift = jcp.typesize_in * (ur_w / stride_w) * ic_block;
+ int src_shift = jcp.typesize_out * ur_w * oc_block;
+
+ preamble();
+
+ mov(reg_src, ptr[param + GET_OFF(src)]);
+ mov(reg_dst, ptr[param + GET_OFF(dst)]);
+ mov(reg_ker, ptr[param + GET_OFF(filt)]);
+
+ mov(reg_kh, ptr[param + GET_OFF(kh_padding)]);
+ mov(reg_src_prf, ptr[param + GET_OFF(src_prf)]);
+ mov(reg_dst_prf, ptr[param + GET_OFF(dst_prf)]);
+ mov(reg_ker_prf, ptr[param + GET_OFF(filt_prf)]);
+
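+ // l_overflow / r_overflow roughly count the border columns for which some
+ // kernel taps would fall outside diff_dst; r_overflow1 is the same quantity
+ // with the ur_w tail excluded, used for the last full-width block.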
+ int l_overflow = nstl::max(0, ((kw - 1) * dilate_w - jcp.l_pad) / stride_w);
+ int r_overflow = nstl::max(0, ((kw - 1) * dilate_w
+ - nstl::max(0, jcp.r_pad)) / stride_w);
+ int r_overflow1 = nstl::max(0, ((kw - 1) * dilate_w
+ - nstl::max(0, jcp.r_pad) - ur_w_tail) / stride_w);
+
+ int n_oi = iw / ur_w;
+ if (r_overflow1 > 0) n_oi--;
+
+ if (ur_w == iw) {
+ compute_loop(ur_w, l_overflow, r_overflow);
+ } else if (n_oi == 0) {
+ compute_loop(ur_w, l_overflow, r_overflow1);
+ add(reg_src, src_shift);
+ add(reg_dst, dst_shift);
+ add(reg_src_prf, src_shift);
+ add(reg_dst_prf, dst_shift);
+ if (ur_w_tail != 0)
+ compute_loop(ur_w_tail, 0, r_overflow);
+ } else {
+ xor_(reg_oi, reg_oi);
+ if (l_overflow > 0) {
+ compute_loop(ur_w, l_overflow, 0);
+ add(reg_src, src_shift);
+ add(reg_dst, dst_shift);
+ add(reg_src_prf, src_shift);
+ add(reg_dst_prf, dst_shift);
+
+ inc(reg_oi);
+ }
+ if ((l_overflow <= 0 && n_oi > 0)
+ || (l_overflow > 0 && n_oi > 1)) {
+ Label ow_loop_label;
+ L(ow_loop_label); {
+ compute_loop(ur_w, 0, 0);
+ add(reg_src, src_shift);
+ add(reg_dst, dst_shift);
+ add(reg_src_prf, src_shift);
+ add(reg_dst_prf, dst_shift);
+
+ inc(reg_oi);
+ cmp(reg_oi, n_oi);
+ jl(ow_loop_label, T_NEAR);
+ }
+ }
+ if (r_overflow1 > 0) {
+ compute_loop(ur_w, 0, r_overflow1);
+ add(reg_src, src_shift);
+ add(reg_dst, dst_shift);
+ add(reg_src_prf, src_shift);
+ add(reg_dst_prf, dst_shift);
+ }
+ if (ur_w_tail != 0) {
+ compute_loop(ur_w_tail, 0, r_overflow);
+ }
+ }
+
+ postamble();
+}
+
+status_t jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(
+ jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d)
+{
+ if (!mayiuse(avx512_common)) return status::unimplemented;
+
+ jcp = zero<decltype(jcp)>();
+
+ jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+ const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
+ int ndims = diff_src_d.ndims();
+
+ jcp.ndims = ndims;
+ jcp.prop_kind = cd.prop_kind;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = diff_src_d.dims()[0];
+
+ jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = diff_src_d.dims()[1] / jcp.ngroups;
+
+ jcp.id = (ndims == 5) ? diff_src_d.dims()[2] : 1;
+ jcp.ih = (ndims == 3) ? 1 : diff_src_d.dims()[ndims-2];
+ jcp.iw = diff_src_d.dims()[ndims-1];
+ jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
+ jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
+ jcp.ow = diff_dst_d.dims()[ndims-1];
+
+ jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
+ jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + ndims - 2];
+ jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+
+ jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
+ jcp.l_pad = cd.padding[0][ndims-3];
+
+ jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
+ jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
+ jcp.stride_w = cd.strides[ndims-3];
+
+ jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
+ jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
+ jcp.dilate_w = cd.dilates[ndims-3];
+ if ((jcp.dilate_w != 0 && jcp.stride_w != 1)
+ || (jcp.dilate_d != 0 && jcp.stride_d != 1)
+ || (jcp.dilate_h != 0 && jcp.stride_h != 1))
+ return status::unimplemented;
+
+ jcp.r_pad = (jcp.ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
+ - (jcp.iw + jcp.l_pad - 1);
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+ - (jcp.ih + jcp.t_pad - 1);
+ jcp.back_pad = (jcp.od - 1) * jcp.stride_d
+ + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1);
+
+ jcp.aligned_threads = 0;
+
+ jcp.is_1stconv = false;
+
+ jcp.oc_block = jcp.simd_w;
+ jcp.ic_block = jcp.is_1stconv ? jcp.ic : jcp.simd_w;
+
+ bool ok_to_pad_channels = true
+ && jcp.ngroups == 1
+ && diff_src_d.data_type() == data_type::f32;
+
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
+ jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
+ }
+
+ auto dat_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
+ auto wei_tag = with_groups
+ ? pick(ndims - 3, gOIw16o16i, gOIhw16o16i, gOIdhw16o16i)
+ : pick(ndims - 3, OIw16o16i, OIhw16o16i, OIdhw16o16i);
+ jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
+
+ bool args_ok = true
+ && jcp.oc % jcp.oc_block == 0
+ && jcp.ic % jcp.ic_block == 0
+ && jcp.src_tag == dat_tag
+ && jcp.dst_tag == dat_tag;
+ if (!args_ok)
+ return status::unimplemented;
+
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+
+ jcp.ur_w = jcp.stride_w;
+
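+ // Grow ur_w from stride_w: take the whole row when it fits in the register
+ // budget, otherwise the largest value <= regs that is a multiple of
+ // stride_w (re-checked near the end of init_conf for the multi-block case).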
+ int regs = 28;
+ if (jcp.iw <= regs)
+ jcp.ur_w = jcp.iw;
+ else {
+ for (int ur_w = regs; ur_w > 0; --ur_w)
+ if (ur_w % jcp.stride_w == 0) {
+ jcp.ur_w = ur_w;
+ break;
+ }
+ }
+ int l_overflow = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
+ - jcp.l_pad) / jcp.stride_w);
+ int r_overflow1 = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
+ - nstl::max(0, jcp.r_pad) - jcp.iw % jcp.ur_w) / jcp.stride_w);
+ int n_oi = jcp.iw / jcp.ur_w;
+ if (r_overflow1 > 0) n_oi--;
+
+ if (mayiuse(avx512_common)
+ && diff_dst_d.data_type() == data_type::f32
+ && weights_d.data_type() == data_type::f32
+ && diff_src_d.data_type() == data_type::f32) {
+ jcp.ver = ver_fma;
+ jcp.typesize_in = sizeof(float);
+ jcp.typesize_out = sizeof(float);
+ if (mayiuse(avx512_mic_4ops)
+ && jcp.stride_w == 1 && jcp.stride_h == 1 && jcp.stride_d == 1) {
+ jcp.ver = ver_4fma;
+ }
+ } else {
+ return status::unimplemented;
+ }
+
+ jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+ if (jcp.wei_tag != wei_tag)
+ return status::unimplemented;
+
+ if (!utils::everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
+ && jcp.ver != ver_fma)
+ return status::unimplemented;
+
+ jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
+ if (jcp.ver == ver_4fma) {
+ if (jcp.kw == 3 && jcp.kh == 3 && jcp.iw == 7 && jcp.ih == 7) {
+ jcp.nb_ic_blocking = 2;
+ } else {
+ for (int i = jcp.nb_ic; i > 0; i--)
+ if (i * jcp.ur_w <= regs && jcp.nb_ic % i == 0) {
+ jcp.nb_ic_blocking = i;
+ break;
+ }
+ }
+ }
+
+ jcp.loop_order = loop_gnc;
+
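+ // Same code-size guard as in the forward kernel: if border cases add extra
+ // unrolled bodies, shrink ur_w (keeping it a multiple of stride_w) until
+ // the estimate fits the 24 KB budget.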
+ bool large_code_size = (jcp.ur_w != jcp.ow)
+ && ((l_overflow <= 0 && n_oi > 0) ||(l_overflow > 0 && n_oi > 1))
+ && (r_overflow1 > 0) && (l_overflow > 0);
+ if (large_code_size) {
+ const int max_code_size = 24 * 1024;
+ const int num_ops_per_reg = 6 + jcp.oc_block * jcp.kw;
+ int mult = 1;
+ if (l_overflow > 0) mult += 1;
+ if (r_overflow1 > 0) mult += 1;
+ for (int ur_w = jcp.ur_w; ur_w > regs/2; --ur_w) {
+ if ((ur_w / jcp.stride_w) * mult * num_ops_per_reg * 9.2
+ < max_code_size) {
+ if (ur_w % jcp.stride_w == 0) {
+ jcp.ur_w = ur_w;
+ break;
+ }
+ }
+ }
+ }
+
+ if (jcp.ver == ver_fma && mayiuse(avx512_core)) {
+ int try_nb_ic_blocking = 2;
+ unsigned int ker_inp_size = typesize * jcp.iw * jcp.ic_block
+ * try_nb_ic_blocking * jcp.kh;
+ unsigned int ker_out_size = typesize * jcp.ow * jcp.oc_block;
+ unsigned int ker_wei_size = typesize * jcp.kh * jcp.kw * jcp.ic_block
+ * jcp.oc_block * try_nb_ic_blocking;
+ unsigned int ker_total_size = ker_inp_size + ker_out_size
+ + ker_wei_size;
+ if (!(jcp.kw == 1 || (jcp.kw == 5 && jcp.iw < 8)
+ || (jcp.kw < 5 && ((jcp.iw <= 5 || (jcp.iw > 8 && jcp.iw <= 13))
+ || ker_total_size > L1_cache_size )))
+ || jcp.stride_h > 1 || jcp.stride_d > 1) {
+ jcp.kernel_kind = embd_bcast;
+ jcp.ur_w = nstl::min(jcp.iw, regs);
+ jcp.nb_ic_blocking = jcp.nb_oc_blocking = 1;
+ if (!(jcp.kw > 3 || (jcp.kw == 3 && ker_total_size < L1_cache_size
+ && jcp.ow > 8)) && jcp.stride_h == 1)
+ if (jcp.nb_ic % try_nb_ic_blocking == 0) {
+ jcp.nb_ic_blocking = try_nb_ic_blocking;
+ jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
+ if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
+ }
+ } else {
+ jcp.kernel_kind = expl_bcast;
+ jcp.nb_oc_blocking = 1;
+ jcp.nb_ic_blocking = 4;
+ if (jcp.nb_ic < jcp.nb_ic_blocking) jcp.nb_ic_blocking = jcp.nb_ic;
+ if (jcp.nb_ic % jcp.nb_ic_blocking != 0)
+ for (int i = jcp.nb_ic_blocking; i > 0; i--)
+ if (jcp.nb_ic % i == 0) {
+ jcp.nb_ic_blocking = i;
+ break;
+ }
+ jcp.ur_w = 31 / (jcp.nb_ic_blocking + 1);
+ if (jcp.iw < jcp.ur_w) jcp.ur_w = jcp.iw;
+ }
+ }
+ jcp.ur_w_tail = jcp.iw % jcp.ur_w;
+
+ if (l_overflow * jcp.stride_w > jcp.ur_w)
+ return status::unimplemented;
+ int r_overflow_no_tail = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
+ - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail) / jcp.stride_w);
+ if (r_overflow_no_tail * jcp.stride_w > jcp.ur_w)
+ return status::unimplemented;
+ if ((jcp.iw > jcp.ur_w) && (jcp.ur_w % jcp.stride_w != 0))
+ return status::unimplemented;
+
+ pick_loop_order(jcp);
+
+ jcp.nb_oc_L2 = jcp.nb_oc;
+ if (jcp.ver == ver_4fma && (jcp.kh < 5 && jcp.kw < 5)) {
+ for (int divf = 2, temp_nb = jcp.nb_oc_L2; divf <= jcp.nb_oc;
+ divf++) {
+ size_t l2_src = jcp.iw * jcp.ic_block * jcp.nb_ic_blocking * jcp.ih
+ * jcp.id;
+ size_t l2_dst = jcp.ow * jcp.oc_block * temp_nb * jcp.oh * jcp.od;
+ size_t l2_filt = jcp.kw * jcp.oc_block * jcp.ic_block * jcp.kh
+ * jcp.kd * jcp.nb_ic_blocking * temp_nb;
+ if (4 * (l2_src + l2_dst + l2_filt) > KNx_L2_EFFECTIVE_CAPACITY) {
+ if (jcp.kh == 3 && jcp.ih == 7) {
+ jcp.nb_oc_L2 = 1;
+ break;
+ }
+ temp_nb = (jcp.nb_oc_L2 % divf == 0 ? jcp.nb_oc_L2 / divf
+ : jcp.nb_oc_L2);
+ } else {
+ jcp.nb_oc_L2 = temp_nb;
+ break;
+ }
+ }
+ }
+
+ args_ok = true
+ && jcp.ic <= diff_src_d.padded_dims()[1]
+ && jcp.oc <= diff_dst_d.padded_dims()[1]
+ && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
+ && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
+ if (!args_ok) return status::unimplemented;
+
+ return status::success;
+}
+
+void jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ UNUSED(scratchpad);
+ UNUSED(jcp);
+}
+
+const int jit_avx512_common_conv_bwd_weights_kernel_f32::max_ur_w = 28;
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::od_step_comeback_pointers()
+{
+ Label kd_comeback_label;
+
+ /* 'depth' loop count bound by 'kd_work_size' */
+ mov(kj, reg_kd_count);
+ L(kd_comeback_label); {
+ int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
+ int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
+ sub(reg_input,
+ jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mult);
+ sub(reg_kernel,
+ jcp.typesize_out * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block);
+ dec(kj);
+ cmp(kj, 0);
+ jg(kd_comeback_label, T_NEAR);
+ }
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::oh_step_comeback_pointers()
+{
+ Label kh_comeback_label, kd_comeback_label;
+ mov(kj, reg_kh);
+ L(kh_comeback_label); {
+ int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
+ int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
+ sub(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mult);
+ sub(reg_kernel,
+ jcp.typesize_out * jcp.kw * jcp.ic_block * jcp.oc_block);
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_comeback_label, T_NEAR);
+ }
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_fma(
+ int ur_w, int pad_l, int pad_r,
+ int ic_block_step, int input_offset, int kernel_offset,
+ int output_offset, bool input_wraparound)
+{
+
+ int kw = jcp.kw;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ for (int i_kw = 0; i_kw < kw; i_kw++)
+ for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
+ vmovups(Zmm(i_kw * ic_block_step + i_ic),
+ EVEX_compress_addr(reg_kernel, typesize * (i_kw * ic_block
+ + i_ic) * jcp.oc_block + kernel_offset));
+
+ for (int i_ur = 0; i_ur < ur_w; i_ur++) {
+ if (i_ur == 0) {
+ vmovups(Zmm(kw * ic_block_step + (i_ur + 0) % 4),
+ EVEX_compress_addr(reg_output, typesize * (i_ur + 0)
+ * oc_block + output_offset));
+ if (ur_w > 1) vmovups(Zmm(kw * ic_block_step + (i_ur + 1) % 4),
+ EVEX_compress_addr(reg_output, typesize * (i_ur + 1) * oc_block
+ + output_offset));
+ if (ur_w > 2) vmovups(Zmm(kw * ic_block_step + (i_ur + 2) % 4),
+ EVEX_compress_addr(reg_output, typesize * (i_ur + 2) * oc_block
+ + output_offset));
+ if (ur_w > 3) vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
+ EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
+ + output_offset));
+ } else if (i_ur + 3 < ur_w)
+ vmovups(Zmm(kw * ic_block_step + (i_ur + 3) % 4),
+ EVEX_compress_addr(reg_output, typesize * (i_ur + 3) * oc_block
+ + output_offset));
+
+ for (int i_kw = 0; i_kw < kw; i_kw++) {
+ int i_iw = i_ur * jcp.stride_w + i_kw * (jcp.dilate_w + 1);
+ if (i_iw - pad_l < 0 || i_iw > (ur_w - 1) * jcp.stride_w +
+ (kw - 1) * (jcp.dilate_w + 1) - pad_r) continue;
+ for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
+ const size_t i_offset = (size_t)input_offset
+ + (size_t)typesize * (jcp.ver == ver_4fma
+ ? (i_iw - pad_l + i_ic * jcp.tr_iw)
+ : (jcp.is_1stconv
+ ? (i_iw - pad_l) + (size_t)i_ic
+ * ((size_t)jcp.ih*jcp.iw*jcp.id)
+ : (i_iw - pad_l) * ic_block + i_ic));
+ vfmadd231ps(Zmm(i_kw * ic_block_step + i_ic),
+ Zmm(kw * ic_block_step + i_ur % 4),
+ EVEX_compress_addr_safe(reg_input, i_offset, reg_long_offt,
+ true));
+ }
+ }
+ }
+
+ for (int i_kw = 0; i_kw < kw; i_kw++)
+ for (int i_ic = 0; i_ic < ic_block_step; i_ic++)
+ vmovups(EVEX_compress_addr(reg_kernel, typesize
+ * (i_kw * ic_block + i_ic) * jcp.oc_block + kernel_offset),
+ Zmm(i_kw * ic_block_step + i_ic));
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step_4fma(
+ int ur_w, int pad_l, int pad_r,
+ int ic_block_step, int input_offset, int kernel_offset,
+ int output_offset, bool input_wraparound)
+{
+ // TODO: add prefetches to fma version as well
+
+ assert(jcp.ver == ver_4fma);
+
+ int kw = jcp.kw;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+
+ auto zmm_ker = [=](int i_kw, int i_ic) {
+ return Zmm(i_kw * ic_block_step + i_ic);
+ };
+
+ auto ker_addr = [=](int i_kw, int i_ic) {
+ size_t local_offset
+ = jcp.typesize_out * (i_kw * ic_block + i_ic) * jcp.oc_block;
+ return EVEX_compress_addr(reg_kernel, local_offset + kernel_offset);
+ };
+
+ auto inp_addr = [=](int i_iw, int i_ic, ptrdiff_t extra_offset = 0) {
+ int stride = jcp.tr_iw * (jcp.is_1stconv ? jcp.ih : 1);
+ int local_offset = jcp.typesize_in * (i_iw + i_ic * stride);
+ return EVEX_compress_addr(reg_input,
+ local_offset + input_offset + extra_offset);
+ };
+
+ auto zmm_out = [=](int i_iw) {
+ // TODO: move reg calc to global member funcs
+ const int out_zmm_base_idx = 28;
+ return Zmm(out_zmm_base_idx + i_iw % 4);
+ };
+
+ auto out_addr = [=](int i_ur) {
+ return EVEX_compress_addr(reg_output,
+ jcp.typesize_in * i_ur * oc_block + output_offset);
+ };
+
+ auto pf_callback = [=](int i_ur, int i_kw, int i_ic) {
+ assert(i_ur % 4 == 0);
+ if (i_ur == 0)
+ prefetcht1(ker_addr(i_kw, i_ic));
+ if (i_ur + 4 >= ur_w)
+ prefetcht0(ker_addr(i_kw, i_ic));
+
+ const ptrdiff_t next_input_block_offset
+ = jcp.typesize_in * ic_block_step * jcp.tr_iw;
+ if (i_ur % 16 == 4 && i_kw == 0) {
+ if (i_ur + 16 < ur_w)
+ prefetcht0(inp_addr(i_ur + 16, i_ic));
+ else
+ prefetcht0(inp_addr(0, i_ic, next_input_block_offset));
+ }
+ if (i_ur % 16 == 4 && i_kw == 1) {
+ if (input_wraparound)
+ prefetcht1(inp_addr(i_ur, i_ic, -input_offset));
+ else
+ prefetcht1(inp_addr(i_ur, i_ic, next_input_block_offset));
+ }
+ };
+
+ for (int i_kw = 0; i_kw < kw; i_kw++)
+ for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
+ auto zmm = zmm_ker(i_kw, i_ic);
+ vpxord(zmm, zmm, zmm);
+ }
+
+ for (int i_ur = 0; i_ur < ur_w; i_ur += 4) {
+
+ for (int i = 0; i < 4; i++) {
+ auto zmm = zmm_out(i_ur + i);
+ if (i_ur + i < ur_w)
+ vmovups(zmm, out_addr(i_ur + i));
+ else
+ vpxord(zmm, zmm, zmm);
+ prefetcht0(out_addr(i_ur + i + 4));
+ }
+
+ for (int i_kw = 0; i_kw < kw; i_kw++)
+ for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
+ int i_iw = i_ur + i_kw;
+ v4fmaddps(zmm_ker(i_kw, i_ic),
+ zmm_out(i_ur), inp_addr(i_iw, i_ic));
+ pf_callback(i_ur, i_kw, i_ic);
+ }
+ }
+
+ for (int i_kw = 0; i_kw < kw; i_kw++)
+ for (int i_ic = 0; i_ic < ic_block_step; i_ic++) {
+ auto addr = ker_addr(i_kw, i_ic);
+ auto zmm = zmm_ker(i_kw, i_ic);
+ vaddps(zmm, zmm, addr);
+ vmovups(addr, zmm);
+ }
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_ic_block_step(
+ int ur_w, int pad_l, int pad_r,
+ int ic_block_step, int input_offset, int kernel_offset,
+ int output_offset, bool input_wraparound)
+{
+ if (jcp.ver == ver_4fma)
+ compute_ic_block_step_4fma(ur_w, pad_l, pad_r,
+ ic_block_step, input_offset, kernel_offset, output_offset,
+ input_wraparound);
+ else if (jcp.ver == ver_fma)
+ compute_ic_block_step_fma(ur_w, pad_l, pad_r,
+ ic_block_step, input_offset, kernel_offset, output_offset,
+ input_wraparound);
+ else
+ assert(!"unknown convolution version");
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32
+ ::compute_oh_step_unroll_ow_icblock(
+ int ic_block_step, int max_ur_w)
+{
+ UNUSED(max_ur_w);
+
+ Label kh_label, kd_label;
+
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ int inp_mul = !jcp.is_1stconv ? ic_block : 1;
+ int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
+ int ow = jcp.ow;
+
+ int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+ int l_pad = jcp.l_pad;
+
+ if (jcp.ndims == 5) {
+ L(kd_label);
+ mov(reg_input, aux_reg_input);
+ mov(reg_kernel, aux_reg_kernel);
+ }
+
+ mov(kj, reg_kh);
+ L(kh_label);
+ {
+ for (int i_b_ic = 0; i_b_ic < jcp.ic_block; i_b_ic += ic_block_step) {
+ const int input_offset = jcp.typesize_in
+ * (jcp.ver == ver_4fma ? i_b_ic * iw : i_b_ic);
+ compute_ic_block_step(jcp.ur_w, l_pad, r_pad, ic_block_step,
+ input_offset, jcp.typesize_out * i_b_ic * jcp.oc_block, 0,
+ i_b_ic + ic_block_step >= jcp.ic_block);
+ }
+ add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * iw * inp_mul);
+ add(reg_kernel, jcp.typesize_out * jcp.kw * ic_block * oc_block);
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_label, T_NEAR);
+ }
+
+ if (jcp.ndims == 5) {
+ add(aux_reg_input,
+ jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih * iw * inp_mul);
+ add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
+ * oc_block);
+ dec(ki);
+ cmp(ki, 0);
+ jg(kd_label, T_NEAR);
+ }
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32
+ ::compute_oh_step_unroll_ow(
+ int ic_block_step, int max_ur_w)
+{
+ Label kh_label, ic_block_label, kd_label;
+
+ UNUSED(max_ur_w);
+
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+
+ int ow = jcp.ow;
+
+ int r_pad = nstl::max(0,
+ (ow - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
+ - (jcp.iw + jcp.l_pad - 1));
+ int l_pad = jcp.l_pad;
+
+ if (jcp.ndims == 5) {
+ L(kd_label);
+ mov(reg_input, aux_reg_input);
+ mov(reg_kernel, aux_reg_kernel);
+ }
+
+ mov(kj, reg_kh);
+ L(kh_label);
+ {
+ xor_(b_ic, b_ic);
+ L(ic_block_label); {
+ compute_ic_block_step(ow, l_pad, r_pad, ic_block_step,
+ 0, 0, 0);
+ size_t inp_icblk_stride = jcp.is_1stconv
+ ? (size_t)jcp.ih * jcp.iw * jcp.id
+ : (jcp.ver == ver_4fma ? jcp.tr_iw : 1);
+ size_t input_offset
+ = inp_icblk_stride * jcp.typesize_in * ic_block_step;
+ safe_add(reg_input, input_offset, reg_long_offt);
+ add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
+ add(b_ic, ic_block_step);
+ cmp(b_ic, jcp.ic_block);
+ jl(ic_block_label, T_NEAR);
+ }
+
+ if (jcp.is_1stconv) {
+ size_t input_offset
+ = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
+ safe_sub(reg_input, input_offset, reg_long_offt);
+ add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
+ } else if (jcp.ver != ver_4fma) {
+ add(reg_input, jcp.typesize_in
+ * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
+ }
+ add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_label, T_NEAR);
+ }
+ if (jcp.ndims == 5) {
+ add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
+ * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
+ add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
+ * oc_block);
+ dec(ki);
+ cmp(ki, 0);
+ jg(kd_label, T_NEAR);
+ }
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32
+ ::compute_oh_step_common(
+ int ic_block_step, int max_ur_w)
+{
+ Label kh_label, ic_block_label, ow_block_label, kd_label;
+
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+
+ int ow = jcp.ow;
+ int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+ int l_pad = jcp.ver == ver_4fma ? 0 : jcp.l_pad;
+
+ int ur_w = nstl::min(ow, max_ur_w);
+ int ur_w_trips = ow / ur_w;
+ int ur_w_tail = ow % ur_w;
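+ // Make sure the right padding is handled entirely by the tail pass: if the
+ // tail is missing or too small, either fold one full trip into it or halve
+ // ur_w so padded columns never fall inside a full-width trip.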
+ if ((ur_w_tail == 0 && r_pad != 0)
+ || r_pad >= ur_w_tail) {
+ if (ur_w_trips > 1) {
+ ur_w_tail += ur_w;
+ ur_w_trips--;
+ } else {
+ ur_w_tail += (ur_w - ur_w / 2);
+ ur_w = ur_w / 2;
+ }
+ }
+
+ int inp_mult = (jcp.is_1stconv || jcp.ver == ver_4fma) ? 1 : ic_block;
+ int input_comeback = (ur_w_trips * ur_w * jcp.stride_w - l_pad) * inp_mult;
+ int output_comeback = ur_w_trips * ur_w * oc_block;
+
+ if (jcp.ndims == 5) {
+ L(kd_label);
+ mov(reg_input, aux_reg_input);
+ mov(reg_kernel, aux_reg_kernel);
+ }
+
+ mov(kj, reg_kh);
+ L(kh_label); {
+ xor_(b_ic, b_ic);
+ L(ic_block_label); {
+ if (l_pad != 0) {
+ ur_w_trips--;
+ compute_ic_block_step(ur_w, l_pad, 0, ic_block_step, 0, 0, 0);
+ add(reg_input, jcp.typesize_in * (ur_w * jcp.stride_w - l_pad)
+ * inp_mult);
+ add(reg_output, jcp.typesize_in * ur_w * oc_block);
+ }
+
+ if (ur_w_trips > 0) {
+ xor_(reg_ur_w_trips, reg_ur_w_trips);
+ L(ow_block_label); {
+ compute_ic_block_step(ur_w, 0, 0, ic_block_step, 0, 0, 0);
+ add(reg_input, jcp.typesize_in * ur_w * jcp.stride_w
+ * inp_mult);
+ add(reg_output, jcp.typesize_in * ur_w * oc_block);
+
+ inc(reg_ur_w_trips);
+ cmp(reg_ur_w_trips, ur_w_trips);
+ jl(ow_block_label, T_NEAR);
+ }
+ }
+
+ if (ur_w_tail > 0) compute_ic_block_step(ur_w_tail, 0, r_pad,
+ ic_block_step, 0, 0, 0);
+
+ sub(reg_input, jcp.typesize_in * input_comeback);
+ sub(reg_output, jcp.typesize_in * output_comeback);
+ int inp_icblk_stride = jcp.is_1stconv
+ ? jcp.ih * jcp.iw * jcp.id
+ : (jcp.ver == ver_4fma ? jcp.tr_iw : 1);
+ size_t input_offset
+ = inp_icblk_stride * jcp.typesize_in * ic_block_step;
+ safe_add(reg_input, input_offset, reg_long_offt);
+ add(reg_kernel, jcp.typesize_out * ic_block_step * oc_block);
+
+ add(b_ic, ic_block_step);
+ cmp(b_ic, jcp.ic_block);
+ jl(ic_block_label, T_NEAR);
+ }
+ if (jcp.is_1stconv) {
+ size_t input_offset
+ = (size_t)jcp.typesize_in * jcp.id * jcp.ih * jcp.iw * ic_block;
+ safe_sub(reg_input, input_offset, reg_long_offt);
+ add(reg_input, jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw);
+ } else if (jcp.ver != ver_4fma) {
+ add(reg_input, jcp.typesize_in
+ * ((jcp.dilate_h + 1) * jcp.iw - 1) * ic_block);
+ }
+ add(reg_kernel, jcp.typesize_out * (jcp.kw - 1) * ic_block * oc_block);
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_label, T_NEAR);
+ }
+ if (jcp.ndims == 5) {
+ add(aux_reg_input, jcp.typesize_in * (jcp.dilate_d + 1) * jcp.ih
+ * jcp.iw * (jcp.is_1stconv ? 1 : ic_block));
+ add(aux_reg_kernel, jcp.typesize_out * jcp.kh * jcp.kw * ic_block
+ * oc_block);
+ dec(ki);
+ cmp(ki, 0);
+ jg(kd_label, T_NEAR);
+ }
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32
+ ::compute_oh_step_disp()
+{
+ int ic_block_step = jcp.kw <= 3 ? 8 : (jcp.kw <= 7 ? 4 : 2);
+ if (jcp.is_1stconv) {
+ bool large_code = jcp.kw >= 7 && (jcp.l_pad > 0 || jcp.t_pad > 0);
+ ic_block_step
+ = (jcp.kw * jcp.ic_block <= 28 && !large_code) ? jcp.ic_block : 1;
+ }
+
+ bool too_large_to_unroll
+ = (jcp.kw > 1 || jcp.kh > 1 || jcp.kd > 1)
+ && (jcp.stride_w > 1 || jcp.stride_h > 1 || jcp.stride_d > 1);
+
+ int ow = jcp.ow;
+ if (jcp.ndims == 5) {
+ /* NOTE: reg_kd_count = aux_reg_input = r12. The following order of
+ * 'movs' must be guaranteed. */
+ mov(ki, reg_kd_count);
+ push(reg_kd_count);
+ mov(aux_reg_input, reg_input);
+ mov(aux_reg_kernel, reg_kernel);
+ }
+
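+ // Dispatch: short rows with small kernels use the variant with the ic-block
+ // loop unrolled at compile time; rows that fit in max_ur_w are processed in
+ // a single unrolled pass with a run-time ic loop; larger rows fall back to
+ // the generic blocked row loop.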
+ if (jcp.kw <= 3 && ow <= 16 && !too_large_to_unroll)
+ compute_oh_step_unroll_ow_icblock(ic_block_step, max_ur_w);
+ else if (ow <= max_ur_w)
+ compute_oh_step_unroll_ow(ic_block_step, max_ur_w);
+ else
+ compute_oh_step_common(ic_block_step, max_ur_w);
+
+ if (jcp.ndims == 5) {
+ mov(reg_input, aux_reg_input);
+ mov(reg_kernel, aux_reg_kernel);
+ pop(reg_kd_count);
+ od_step_comeback_pointers();
+ } else {
+ oh_step_comeback_pointers();
+ }
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::maybe_zero_kernel()
+{
+ Label skip_zeroing, zeroing_loop;
+
+ mov(reg_tmp, ptr[param + GET_OFF(channel)]);
+ cmp(reg_tmp, 0);
+ jz(skip_zeroing, T_NEAR);
+
+ Zmm zero = Zmm(0);
+ vpxord(zero, zero, zero);
+ xor_(reg_tmp, reg_tmp);
+ L(zeroing_loop); {
+ assert(jcp.oc_block * jcp.typesize_out
+ == cpu_isa_traits<avx512_common>::vlen);
+ for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
+ vmovups(ptr[reg_kernel + reg_tmp + ic1 * jcp.oc_block
+ * jcp.typesize_out], zero);
+ add(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
+ cmp(reg_tmp, jcp.ic_block * jcp.oc_block * jcp.kw * jcp.kh * jcp.kd
+ * jcp.typesize_out);
+ jnz(zeroing_loop);
+ }
+
+ L(skip_zeroing);
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::bias_kernel()
+{
+ Label skip_bias, bias_loop, skip_load_bias;
+
+ mov(reg_tmp, ptr[param + GET_OFF(flags)]);
+ test(reg_tmp, reg_tmp);
+ jne(skip_bias, T_NEAR);
+
+ mov(reg_bias, ptr[param + GET_OFF(bias)]);
+ mov(reg_output, ptr[param + GET_OFF(dst)]);
+ vpxord(Zmm(1), Zmm(1), Zmm(1));
+
+ mov(reg_tmp, ptr[param + GET_OFF(channel)]);
+ cmp(reg_tmp, 0);
+ jne(skip_load_bias, T_NEAR);
+ vmovups(Zmm(1), ptr[reg_bias]);
+
+ L(skip_load_bias);
+
+ mov(reg_oi, ptr[param + GET_OFF(d_worksize)]);
+ sub(reg_oi, ptr[param + GET_OFF(d_index)]);
+ mov(reg_tmp, jcp.oc_block * jcp.ow * jcp.oh * jcp.typesize_out);
+ imul(reg_oi, reg_tmp);
+
+ xor_(reg_tmp, reg_tmp);
+ L(bias_loop); {
+ vmovups(Zmm(0), ptr[reg_output + reg_tmp]);
+ vaddps(Zmm(1), Zmm(1), Zmm(0));
+ add(reg_tmp, jcp.oc_block * jcp.typesize_out);
+ cmp(reg_tmp, reg_oi);
+ jl(bias_loop);
+ }
+ vmovups(EVEX_compress_addr(reg_bias, 0), Zmm(1));
+
+ L(skip_bias);
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32
+ ::compute_oh_loop_common()
+{
+ int b_pad = jcp.b_pad;
+ int t_pad = jcp.t_pad;
+ bool is_dilated = jcp.dilate_h != 0;
+ int dilate_h = jcp.dilate_h + 1;
+ int stride_h = jcp.stride_h;
+ const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
+ int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
+ Label oh_label, oh_label_end, oh_tpad_label, oh_tpad_tail_label,
+ oh_bpad_label, oh_bpad_label_end, od_label, od_label_end,
+ oh_dilate_label_shift, oh_dilate_label_noshift, oh_dilate_label_end;
+
+ int ow = jcp.ow;
+
+ mov(reg_kh, jcp.kh);
+ xor_(reg_ih_count, reg_ih_count);
+ xor_(reg_oj, reg_oj);
+ /* Compute 'top' edge */
+ if (t_pad > 0) {
+ const int kh_range = 1 + (jcp.kh - 1) * dilate_h;
+ const int overflow
+ = nstl::max(0, jcp.kh - div_up(t_pad + jcp.ih, dilate_h));
+ const int underflow = div_up(t_pad, dilate_h);
+ const int initial_inp_ker_overlap = jcp.kh - overflow - underflow;
+ mov(reg_kh, initial_inp_ker_overlap);
+ add(reg_kernel, jcp.typesize_out * underflow * jcp.kw * jcp.ic_block
+ * jcp.oc_block);
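+        // The first output rows see only part of the kernel: reg_kh starts
+        // at the reduced overlap height and the kernel pointer skips the
+        // 'underflow' rows that lie entirely in the top padding.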
+ // generate loop to process kernel while it remains within t_pad + ih
+ if (kh_range < t_pad + jcp.ih) {
+ if (is_dilated) {
+ const int tail = t_pad % dilate_h;
+ const int shift = tail == 0 ? 0 : dilate_h - tail;
+ mov(reg_tmp, shift);
+ if (tail != 0)
+ add(reg_input, jcp.typesize_in * shift * iw * inp_mult);
+ }
+ L(oh_tpad_label); {
+ compute_oh_step_disp();
+ add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
+ if (is_dilated) {
+ inc(reg_tmp);
+ cmp(reg_tmp, dilate_h);
+ jl(oh_dilate_label_shift, T_NEAR);
+ // unshift input as new kernel element enters
+ sub(reg_input, jcp.typesize_in * (dilate_h - 1) * iw * inp_mult);
+ xor_(reg_tmp, reg_tmp);
+ }
+ // kernel overlap only changes when (t_pad + oj) % dilate_h == 0
+ sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
+ * jcp.ic_block * jcp.oc_block);
+ add(reg_kh, stride_h);
+ if (is_dilated) {
+ jmp(oh_dilate_label_noshift, T_NEAR);
+ L(oh_dilate_label_shift);
+ // shift input as old kernel element progresses
+ add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
+ L(oh_dilate_label_noshift);
+ }
+ inc(reg_oj);
+ add(reg_ih_count, stride_h);
+
+ // final number of kernel elements that overlap with input
+ const int final_inp_ker_overlap
+ = nstl::min(jcp.kh, div_up(jcp.ih, dilate_h));
+ cmp(reg_kh, final_inp_ker_overlap);
+ jl(oh_tpad_label, T_NEAR);
+ }
+ }
+ // need second loop to process kernel if it is larger than the input
+ // (does not apply to dilations as they must have unit stride)
+ if (kh_range >= jcp.ih + (t_pad % stride_h == 0 ? stride_h :
+ t_pad % stride_h)) {
+ assert(!is_dilated);
+ mov(reg_kh, jcp.ih);
+ L(oh_tpad_tail_label); {
+ compute_oh_step_disp();
+ add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
+ sub(reg_kernel, jcp.typesize_out * stride_h * jcp.kw
+ * jcp.ic_block * jcp.oc_block);
+
+ inc(reg_oj);
+ add(reg_ih_count, stride_h);
+
+ cmp(reg_ih_count, nstl::min(t_pad, jcp.oh * stride_h));
+ jl(oh_tpad_tail_label, T_NEAR);
+ }
+ }
+ // correct any excess shifts to kernel and input
+ // (does not apply to dilations as they must have unit stride,
+ // kernel must fit inside input, and padding is smaller than input)
+ if (t_pad <= jcp.oh * stride_h) {
+ // kernel has moved beyond padding (adjust for stride effects)
+ if (t_pad % stride_h != 0) {
+ assert(!is_dilated);
+ int inp_corr = stride_h - t_pad % stride_h;
+ add(reg_kernel, jcp.typesize_out * inp_corr * jcp.kw
+ * jcp.ic_block * jcp.oc_block);
+ add(reg_input, jcp.typesize_in * inp_corr * iw * inp_mult);
+ }
+ } else {
+ // kernel still overlaps padding (complete reset)
+ assert(!is_dilated);
+ sub(reg_kernel, jcp.typesize_out * (t_pad - jcp.oh * stride_h)
+ * jcp.kw * jcp.ic_block * jcp.oc_block);
+ }
+ }
+
+ cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
+ jge(oh_label_end, T_NEAR);
+ cmp(reg_oj, jcp.oh);
+ jge(oh_label, T_NEAR);
+
+ /* Compute middle block(s) */
+ mov(reg_kh, jcp.kh);
+ L(oh_label); {
+ compute_oh_step_disp();
+ add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
+ add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
+
+ inc(reg_oj);
+ add(reg_ih_count, stride_h);
+
+ cmp(reg_ih_count, jcp.ihp - b_pad - (jcp.kh - 1) * dilate_h);
+ jge(oh_label_end, T_NEAR);
+
+ cmp(reg_oj, jcp.oh);
+ jl(oh_label, T_NEAR);
+ }
+ L(oh_label_end);
+
+ /* Compute bottom edge */
+ if (b_pad > 0) {
+ cmp(reg_oj, jcp.oh);
+ jge(oh_bpad_label_end, T_NEAR);
+
+ if (is_dilated) {
+ mov(reg_kh, jcp.kh - 1); // assumes unit stride for dilations
+ mov(reg_tmp, 0);
+ } else {
+ mov(reg_kh, jcp.ihp - b_pad);
+ sub(reg_kh, reg_ih_count);
+ }
+ L(oh_bpad_label);
+ {
+ compute_oh_step_disp();
+ add(reg_input, jcp.typesize_in * stride_h * iw * inp_mult);
+ add(reg_output, jcp.typesize_in * ow * jcp.oc_block);
+ if (is_dilated) {
+ inc(reg_tmp);
+ cmp(reg_tmp, dilate_h);
+ jl(oh_dilate_label_end, T_NEAR);
+ xor_(reg_tmp, reg_tmp);
+ }
+ sub(reg_kh, stride_h);
+ cmp(reg_kh, 0);
+ jle(oh_bpad_label_end, T_NEAR);
+ if (is_dilated)
+ L(oh_dilate_label_end);
+
+ inc(reg_oj);
+ cmp(reg_oj, jcp.oh);
+ jl(oh_bpad_label, T_NEAR);
+ }
+ L(oh_bpad_label_end);
+ }
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_d_loop_common() {
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
+ int iw = jcp.ver == ver_4fma ? jcp.tr_iw : jcp.iw;
+ int ow = jcp.ow;
+ const int input_backpad_overlap
+ = div_up(jcp.id + jcp.f_pad - (jcp.kd - 1), jcp.stride_d);
+
+ const size_t filter_shift
+ = jcp.typesize_out * jcp.kh * jcp.kw * ic_block * oc_block;
+ const size_t input_shift = jcp.typesize_in * jcp.ih * iw * inp_mult;
+ const size_t output_shift = jcp.typesize_in * jcp.oh * ow * jcp.oc_block;
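+    // Byte strides for one step along the depth axis: one kd slice of the
+    // filter, one input depth plane, and one output depth plane.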
+
+ Label d_loop_label, loop_end_label, common_block_label, fpad_end_label,
+ backpad_end_label, backpad_label;
+
+ if (jcp.with_bias) bias_kernel();
+
+ /* initially offset 'kd' by f_pad */
+ add(reg_kernel, ptr[param + GET_OFF(kd_offset)]);
+
+ mov(reg_input_d, ptr[param + GET_OFF(src)]);
+ mov(reg_output_d, ptr[param + GET_OFF(dst)]);
+ mov(reg_d_index, ptr[param + GET_OFF(d_index)]);
+ mov(reg_kd_count, ptr[param + GET_OFF(kd_padding)]);
+
+ cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
+ jge(loop_end_label, T_NEAR);
+
+ L(d_loop_label);
+
+ mov(reg_input, reg_input_d);
+ mov(reg_output, reg_output_d);
+
+ push(reg_input_d);
+ push(reg_output_d);
+ push(reg_d_index);
+
+ compute_oh_loop_common();
+
+ pop(reg_d_index);
+ pop(reg_output_d);
+ pop(reg_input_d);
+
+ /* Compute 'front' edge */
+ if (jcp.f_pad > 0) {
+
+ /* Check if within fpad region */
+ cmp(reg_d_index, div_up(jcp.f_pad, jcp.stride_d));
+ jge(fpad_end_label, T_NEAR);
+
+ /* Fpad steps */
+ sub(reg_kernel, filter_shift * jcp.stride_d);
+ add(reg_kd_count, jcp.stride_d);
+
+ /* Final number of kernel elements that overlap with input */
+ const int inp_ker_overlap = nstl::min(jcp.kd, jcp.id);
+ cmp(reg_kd_count, inp_ker_overlap);
+ jl(common_block_label, T_NEAR);
+
+ /* Correct any excess shifts to kernel and input */
+ if (jcp.f_pad <= jcp.od * jcp.stride_d) {
+ /* Filter has moved beyond padding (adjust for stride effects) */
+ if (jcp.f_pad % jcp.stride_d != 0) {
+ int inp_corr = jcp.stride_d - jcp.f_pad % jcp.stride_d;
+ add(reg_kernel, filter_shift * inp_corr);
+ add(reg_input_d, input_shift * inp_corr);
+ }
+ } else {
+ /* Filter still overlaps padding (complete reset) */
+ sub(reg_kernel, (jcp.f_pad - jcp.od * jcp.stride_d) * filter_shift);
+ }
+
+ /* Apply correction */
+ mov(reg_kd_count, jcp.kd);
+ jmp(common_block_label);
+
+ L(fpad_end_label);
+ }
+
+    /* Compute 'back' edge */
+ if (jcp.back_pad > 0) {
+
+ /* Check if within back_pad region */
+ cmp(reg_d_index, input_backpad_overlap - 1);
+ jl(backpad_end_label, T_NEAR);
+ jg(backpad_label, T_NEAR);
+
+ /* Execute overlap correction between the filter and the initial
+ * back_pad region. */
+ mov(reg_kd_count,
+ jcp.id + jcp.f_pad - input_backpad_overlap * jcp.stride_d);
+ jmp(backpad_end_label, T_NEAR);
+
+ L(backpad_label);
+ sub(reg_kd_count, jcp.stride_d);
+ cmp(reg_kd_count, 0);
+ jle(loop_end_label, T_NEAR);
+
+ L(backpad_end_label);
+ }
+
+ /* Compute middle block */
+ add(reg_input_d, input_shift * jcp.stride_d);
+
+ /* Execute common block and loop */
+ L(common_block_label);
+ add(reg_output_d, output_shift);
+ inc(reg_d_index);
+ cmp(reg_d_index, ptr[param + GET_OFF(d_worksize)]);
+ jl(d_loop_label, T_NEAR);
+
+ L(loop_end_label);
+}
+
+bool jit_avx512_common_conv_bwd_weights_kernel_f32::compute_full_spat_loop() {
+ // FIXME: use register mapping from the class declaration
+ bool ok = jcp.ver == ver_4fma
+ && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
+ && everyone_is(1, jcp.stride_h, jcp.stride_w);
+ if (!ok) return false;
+ if (jcp.l_pad != jcp.kw / 2 || jcp.t_pad != jcp.kh / 2)
+ return false;
+
+ // General code layout:
+ //
+ // Blocking over OH -- top level
+ // (Reduces L2 pressure; not very useful right now)
+ // Loop over all KHxKW kernel -- emit_kh_kw_loop()
+ // Loop over OH block -- emit_h_loop()
+ // Loop over OW blocks -- emit_fma_block()
+ // (Supports both fully unrolled and partially unrolled versions to
+ // reduce code size)
+ // Loop over OW block -- emit_fma_step()
+
+ int max_working_set_size = 128 * 1024;
+ int pad_ow = jcp.ow;
+
+ int inp_row_size = jcp.ic_block * jcp.tr_iw * jcp.typesize_in;
+ int out_row_size = jcp.oc_block * pad_ow * jcp.typesize_in;
+ int row_size = inp_row_size + out_row_size;
+
+ int h_block_size = jcp.oh;
+ int working_set_size = row_size * h_block_size;
+
+ if (working_set_size > max_working_set_size) {
+ int opt_working_set_size = 48 * 1024;
+ assert(opt_working_set_size < max_working_set_size);
+
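+        // Shrink h_block_size (by its smallest divisor, or by halving it)
+        // until the input and output rows of one block fit in ~48 KB.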
+ while (working_set_size > opt_working_set_size) {
+ for (int i = 2; i <= h_block_size; i++)
+ if (i == h_block_size)
+ h_block_size = h_block_size / 2;
+ else if (h_block_size % i == 0) {
+ h_block_size = h_block_size / i;
+ break;
+ }
+ working_set_size = row_size * h_block_size;
+
+ if (h_block_size == 1 && working_set_size > opt_working_set_size)
+ return false;
+ }
+ }
+
+ // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size (see below)
+ if (h_block_size < nstl::max(1, jcp.t_pad)
+ || jcp.b_pad > (jcp.oh % h_block_size == 0 ? h_block_size
+ : jcp.oh % h_block_size))
+ return false;
+
+ // check that we can use simple arithmetic for prefetch address
+ // calculations
+ // TODO: we need some traits for this check (Roma)
+ int cache_line_size = 64;
+ assert(jcp.ic_block * typesize == 64);
+ assert(jcp.oc_block * typesize == 64);
+
+ int num_inp_l2_pfs = jcp.tr_iw * h_block_size;
+ int avg_h_loop_len = h_block_size;
+ int num_inp_l2_pfs_per_fma_block
+ = div_up(num_inp_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
+ int num_out_l2_pfs = pad_ow * h_block_size;
+ int num_out_l2_pfs_per_fma_block
+ = div_up(num_out_l2_pfs, avg_h_loop_len * jcp.kw * jcp.kh);
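+    // Spread the per-block L2 prefetches evenly across all kh * kw * oh
+    // iterations so no single FMA step issues a burst of prefetches.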
+
+ Opmask reg_h_block = k1; // 32-bit only on Intel(R) Xeon Phi(TM) processors
+ Reg64 reg_kh = rax;
+ Reg64 reg_kw = rbx;
+ Reg64 reg_tmp = abi_not_param1;
+ Reg32 reg_tmp_w = reg_tmp.cvt32();
+ Reg64 reg_ohs = rdx;
+ Reg64 reg_ihs = rsi;
+ Reg64 reg_h = r8;
+ Reg64 reg_i = r9;
+ Reg64 reg_j = r10;
+
+ Reg64 reg_inp = r13;
+ Reg64 reg_out = r14;
+ Reg64 reg_ker = r15;
+
+ Reg64 reg_inp_pf_l1 = rbp;
+
+ Reg64 reg_inp_pf_l2 = r11;
+ Reg64 reg_out_pf_l2 = r12;
+
+ Xmm reg_inp_pf_save = xmm17;
+ Xmm reg_out_pf_save = xmm18;
+
+ Reg64 reg_inp_save = abi_param1;
+ Reg64 reg_out_save = reg_tmp;
+
+ auto zmm_out = [&](int oi) { return Zmm(24 + oi % 8); };
+ auto zmm_ker = [&](int ic1) { return Zmm(ic1); };
+ auto inp_addr = [&](int oi, int ic1) {
+ return ptr[reg_inp + (ic1 * jcp.tr_iw + oi) * jcp.typesize_in];
+ };
+ auto out_addr = [&](int oi, int oj = 0) {
+ assert(jcp.ver == ver_4fma);
+ return ptr[reg_out
+ + ((oi + oj * jcp.ow) * jcp.oc_block) * jcp.typesize_in];
+ };
+ auto ker_addr = [&](int ic1) {
+ return ptr[reg_ker + ic1 * jcp.oc_block * jcp.typesize_out];
+ };
+
+ auto emit_block = [&](int h_block_size,
+ bool is_last_block, bool is_last_kh_kw_iter, bool is_last_row)
+ {
+ // TODO: add an fma version (Roma)
+ auto pad_ow = jcp.ow;
+
+ int ow4u = rnd_up(pad_ow, 4);
+ int def_step_size = 16;
+
+ bool has_w_tail = (pad_ow % def_step_size != 0
+ || pad_ow % 4 != 0);
+ bool full_w_unroll = pad_ow / def_step_size < 2 + has_w_tail;
+
+ auto emit_step = [&](int ur_ow,
+ int num_inp_l1_pfs_per_fma_step,
+ int num_inp_l2_pfs_per_fma_step,
+ int num_out_l2_pfs_per_fma_step, bool is_w_tail)
+ {
+ bool block_wraparound = is_w_tail && is_last_row;
+
+ assert(ur_ow % 4 == 0);
+ int tail_size = ow4u % ur_ow;
+ int this_ur_ow
+ = (is_w_tail && tail_size) ? tail_size : ur_ow;
+ int ow_last_chunk4 = pad_ow % 4;
+ int ow_zero_tail4 = ow_last_chunk4
+ ? 4 - ow_last_chunk4 : 0;
+
+ auto emit_out_pf = [&](int oi) {
+#if 1
+ if (oi + def_step_size < ur_ow || !block_wraparound)
+ mic_prefetcht0(ptr[reg_out
+ + ((def_step_size + oi)
+ * jcp.oc_block * jcp.typesize_in)]);
+ else {
+ assert(block_wraparound);
+ assert(oi + def_step_size >= ur_ow);
+ mic_prefetcht0(ptr[reg_out_save
+ + ((oi + def_step_size - ur_ow)
+ * jcp.oc_block * jcp.typesize_in)]);
+ }
+#else
+ // XXX: This is an alternative prefetching strategy that
+ // always prefetches the next row. Keeping it here for
+ // future experiments (Roma)
+ if (!block_wraparound)
+ mic_prefetcht0(ptr[reg_out
+ + (jcp.ow + oi) * jcp.oc_block * jcp.typesize_in]);
+ else
+ mic_prefetcht0(ptr[reg_out + reg_ohs
+ - ((h_block_size - 1) * jcp.ow
+ - oi) * jcp.oc_block * jcp.typesize_in]);
+#endif
+ if (oi < num_out_l2_pfs_per_fma_step)
+ mic_prefetcht1(ptr[reg_out_pf_l2
+ + oi * jcp.oc_block * jcp.typesize_in]);
+ };
+
+ auto emit_inp_pf = [&](int oi4, int ic1) {
+ int pf_slot_idx = ic1 + oi4 / 4 * jcp.ic_block;
+ int num_pf_slots = jcp.ic_block * ur_ow / 4;
+
+ int num_pfs = num_inp_l1_pfs_per_fma_step
+ + num_inp_l2_pfs_per_fma_step;
+ int pf_freq = nstl::max(1, num_pf_slots / num_pfs);
+
+ if (pf_slot_idx % pf_freq)
+ return;
+
+ int pf_idx = pf_slot_idx / pf_freq;
+
+ if (pf_idx < num_inp_l2_pfs_per_fma_step)
+ mic_prefetcht1(ptr[reg_inp_pf_l2
+ + pf_idx * jcp.ic_block * jcp.typesize_in]);
+ else {
+ pf_idx -= num_inp_l2_pfs_per_fma_step;
+ // prefetch the 'tail' of the cache line because most of
+ // the accesses are not aligned
+ mic_prefetcht0(ptr[reg_inp_pf_l1
+ + pf_idx * jcp.ic_block * jcp.typesize_in
+ + cache_line_size - jcp.typesize_in]);
+ }
+ };
+
+ auto numloads = 4;
+
+ int steps = this_ur_ow;
+ for (int oi4 = 0; oi4 < steps; oi4 += numloads) {
+ for (int oi1 = 0; oi1 < numloads; oi1++) {
+ int oi = oi4 + oi1;
+ if (!is_w_tail || oi < (this_ur_ow - ow_zero_tail4)) {
+ vmovups(zmm_out(oi), out_addr(oi));
+ emit_out_pf(oi);
+ } else {
+ auto zmm = zmm_out(oi);
+ vpxord(zmm, zmm, zmm);
+ }
+ }
+
+ for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
+ if (jcp.ver == ver_4fma) {
+ v4fmaddps(zmm_ker(ic1),
+ zmm_out(oi4), inp_addr(oi4, ic1));
+ } else {
+ assert(!"unknown convolution version");
+ }
+ emit_inp_pf(oi4, ic1);
+ }
+ }
+ };
+
+ // Input is transposed and padded but we only access about jcp.iw
+ // elements so use that to compute the # of cache lines in each 'row'
+ int num_inp_l1_pfs
+ = div_up(jcp.iw * jcp.typesize_in, cache_line_size) * jcp.ic_block;
+
+ if (full_w_unroll) {
+ emit_step(ow4u, num_inp_l1_pfs,
+ num_inp_l2_pfs_per_fma_block,
+ num_out_l2_pfs_per_fma_block, true);
+ add(reg_inp_pf_l2, num_inp_l2_pfs_per_fma_block * cache_line_size);
+ add(reg_out_pf_l2, num_out_l2_pfs_per_fma_block * cache_line_size);
+ } else {
+ Label w_loop;
+ int num_w_iters = pad_ow / def_step_size;
+ int num_w_iters_full = num_w_iters + has_w_tail;
+ int num_inp_l1_pfs_per_fma_step
+ = div_up(num_inp_l1_pfs, num_w_iters_full);
+ int num_inp_l2_pfs_per_fma_step
+ = div_up(num_inp_l2_pfs_per_fma_block, num_w_iters_full);
+ int num_out_l2_pfs_per_fma_step
+ = div_up(num_out_l2_pfs_per_fma_block, num_w_iters_full);
+ mov(reg_i, num_w_iters);
+ L(w_loop); {
+ emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
+ num_inp_l2_pfs_per_fma_step,
+ num_out_l2_pfs_per_fma_step, false);
+ add(reg_inp, def_step_size * jcp.typesize_in);
+ add(reg_out, def_step_size * jcp.oc_block * jcp.typesize_in);
+ add(reg_inp_pf_l1,
+ num_inp_l1_pfs_per_fma_step * cache_line_size);
+ add(reg_inp_pf_l2,
+ num_inp_l2_pfs_per_fma_step * cache_line_size);
+ add(reg_out_pf_l2,
+ num_out_l2_pfs_per_fma_step * cache_line_size);
+ sub(reg_i, 1);
+ jnz(w_loop);
+ }
+ if (has_w_tail) {
+ emit_step(def_step_size, num_inp_l1_pfs_per_fma_step,
+ num_inp_l2_pfs_per_fma_step,
+ num_out_l2_pfs_per_fma_step, true);
+ add(reg_inp_pf_l2,
+ num_inp_l2_pfs_per_fma_step * cache_line_size);
+ add(reg_out_pf_l2,
+ num_out_l2_pfs_per_fma_step * cache_line_size);
+ }
+ // reset reg_inp and reg_out because emit_h_loop expects
+ // unmodified pointers
+ int w_offset = num_w_iters * def_step_size;
+ sub(reg_inp, w_offset * jcp.typesize_in);
+ sub(reg_out, w_offset * jcp.oc_block * jcp.typesize_in);
+ }
+ };
+
+ auto emit_h_loop = [&](int h_block_size,
+ bool is_last_block, bool is_last_kh_kw_iter)
+ {
+ Label h_loop, skip_h_loop;
+ mov(reg_j, 1);
+ cmp(reg_j, reg_h);
+ je(skip_h_loop, T_NEAR);
+ L(h_loop); {
+
+ lea(reg_inp_pf_l1,
+ ptr[reg_inp + jcp.tr_iw * jcp.ic_block * jcp.typesize_in]);
+ emit_block(h_block_size,
+ is_last_block, is_last_kh_kw_iter, false);
+
+ add(reg_inp, jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
+ add(reg_out, pad_ow * jcp.oc_block * jcp.typesize_in);
+ add(reg_j, 1);
+ cmp(reg_j, reg_h);
+ jb(h_loop);
+ }
+
+ L(skip_h_loop);
+
+ for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
+ mic_prefetcht0(ker_addr(ic1));
+
+ lea(reg_inp_pf_l1, ptr[reg_inp_save + reg_kw * jcp.typesize_in]);
+ emit_block(h_block_size, is_last_block, is_last_kh_kw_iter, true);
+ };
+
+ auto emit_kh_kw_loop = [&](bool is_first_block, bool is_last_block,
+ int h_block_size)
+ {
+ xor_(reg_kh, reg_kh);
+ Label kh_loop, kh_loop_end;
+
+ int last_oh_block_size
+ = jcp.oh - rnd_up(jcp.oh - h_block_size, h_block_size);
+ int oh_block_size = (is_last_block) ? last_oh_block_size : h_block_size;
+ // NB1: t_pad <= oh_block_size and b_pad <= last_oh_block_size
+ int ih_block_size = oh_block_size - 1 + jcp.kh
+ - is_first_block * jcp.t_pad - is_last_block * jcp.b_pad;
+
+ L(kh_loop); {
+ // determine starting indices for this block
+ if (is_first_block) {
+ xor_(reg_tmp, reg_tmp);
+ mov(reg_ohs, jcp.t_pad);
+ sub(reg_ohs, reg_kh);
+ cmovb(reg_ohs, reg_tmp);
+
+ mov(reg_ihs, reg_ohs);
+ sub(reg_ihs, jcp.t_pad);
+ add(reg_ihs, reg_kh);
+ } else {
+ xor_(reg_ohs, reg_ohs);
+ mov(reg_ihs, reg_kh);
+ }
+
+ // determine effective size of block based on padding
+ mov(reg_tmp, oh_block_size);
+ sub(reg_tmp, reg_ohs);
+ mov(reg_h, ih_block_size);
+ sub(reg_h, reg_ihs);
+ cmp(reg_tmp, reg_h);
+ cmovb(reg_h, reg_tmp);
+
+ Label kh_loop_work;
+ cmp(reg_h, 0);
+ jg(kh_loop_work, T_NEAR);
+
+ // empty h loop for this jcp.kh:
+            // - zero the kernel output if necessary
+            // - advance the kernel pointer
+ // - jump to the end
+ sub(reg_h, 1);
+ Label skip_ker_zeroing;
+
+            // The reg_ker ptr has its lowest bit set if the output needs to
+            // be zeroed. Those who have merely byte-aligned their data will
+            // suffer the consequences :(
+ // TODO: move the flag to a mask register? (Roma)
+ test(reg_ker, 1);
+ jz(skip_ker_zeroing, T_NEAR);
+
+ Label zeroing_loop;
+ vpxord(zmm0, zmm0, zmm0);
+ and_(reg_ker, ~1); // temporarily clear the zeroing flag
+ mov(reg_tmp, jcp.kw);
+ L(zeroing_loop); {
+ for (int ic1 = 0; ic1 < jcp.ic_block; ic1++)
+ vmovups(ker_addr(ic1), zmm0);
+ add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.typesize_out);
+ sub(reg_tmp, 1);
+ jnz(zeroing_loop, T_NEAR);
+ }
+ // restore the zeroing flag (it will be cleared after the end of
+ // emit_kh_kw_loop, but we may need it until then)
+ or_(reg_ker, 1);
+ jmp(kh_loop_end, T_NEAR);
+
+ L(skip_ker_zeroing);
+ add(reg_ker, jcp.oc_block * jcp.ic_block * jcp.kw
+ * jcp.typesize_out);
+ jmp(kh_loop_end, T_NEAR);
+
+ L(kh_loop_work);
+
+ mul_by_const(reg_ihs, reg_tmp,
+ jcp.tr_iw * jcp.ic_block * jcp.typesize_in);
+ mul_by_const(reg_ohs, reg_tmp,
+ pad_ow * jcp.oc_block * jcp.typesize_in);
+
+ add(reg_inp, reg_ihs);
+ add(reg_out, reg_ohs);
+
+ Label kw_loop;
+ xor_(reg_kw, reg_kw);
+ L(kw_loop); {
+ for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
+ auto zmm = zmm_ker(ic1);
+ vpxord(zmm, zmm, zmm);
+ mic_prefetcht1(ker_addr(ic1));
+ }
+
+ mov(reg_out_save, reg_out);
+ mov(reg_inp_save, reg_inp);
+ lea(reg_inp, ptr[reg_inp + reg_kw * jcp.typesize_in]);
+
+#if 0
+ // XXX: Generate code with special prefetches when switching
+ // blocks or at the end of the last block. Disabled to reduce
+ // code size and because there's no performance benefit (Roma)
+ Label regular_h_loop, end_h_loop;
+ cmp(reg_kw, jcp.kw - 1);
+ jne(regular_h_loop, T_NEAR);
+ cmp(reg_kh, jcp.kh - 1);
+ jne(regular_h_loop, T_NEAR);
+
+ emit_h_loop(oh_block_size, is_last_block, true);
+ jmp(end_h_loop, T_NEAR);
+
+ L(regular_h_loop);
+ emit_h_loop(oh_block_size, is_last_block, false);
+
+ L(end_h_loop);
+#else
+ emit_h_loop(oh_block_size, is_last_block, false);
+#endif
+
+ mov(reg_out, reg_out_save);
+ mov(reg_inp, reg_inp_save);
+
+ Label do_store;
+                    // The reg_ker ptr has its lowest bit set if the output
+                    // needs to be zeroed. Those who have merely byte-aligned
+                    // their data will suffer the consequences :(
+ mov(reg_tmp, reg_ker);
+ and_(reg_ker, ~1);
+ test(reg_tmp, 1);
+ jnz(do_store, T_NEAR);
+
+ for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
+ auto zmm = zmm_ker(ic1);
+ if (jcp.ver == ver_4fma) {
+ vaddps(zmm, ker_addr(ic1));
+ } else {
+ assert(!"unknown convolution version");
+ }
+ }
+
+ L(do_store);
+ for (int ic1 = 0; ic1 < jcp.ic_block; ic1++) {
+ auto zmm = zmm_ker(ic1);
+ vmovups(ker_addr(ic1), zmm);
+ }
+
+ mov(reg_ker, reg_tmp);
+ add(reg_ker, jcp.ic_block * jcp.oc_block * jcp.typesize_out);
+ add(reg_kw, 1);
+ cmp(reg_kw, jcp.kw);
+ jl(kw_loop);
+ }
+
+ sub(reg_inp, reg_ihs);
+ sub(reg_out, reg_ohs);
+
+
+ L(kh_loop_end);
+ add(reg_kh, 1);
+ cmp(reg_kh, jcp.kh);
+ jl(kh_loop);
+ }
+ };
+
+ mov(reg_inp, ptr[param + GET_OFF(src)]);
+ mov(reg_out, ptr[param + GET_OFF(dst)]);
+ mov(reg_ker, ptr[param + GET_OFF(filt)]);
+ mov(reg_inp_pf_l2, ptr[param + GET_OFF(src_prf)]);
+ mov(reg_out_pf_l2, ptr[param + GET_OFF(dst_prf)]);
+ mov(reg_tmp, ptr[param + GET_OFF(channel)]);
+ or_(reg_ker, reg_tmp);
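+    // Stash the 'channel' flag in the low bit of reg_ker; the kh/kw loop
+    // below uses it to decide whether the weights still need zeroing.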
+
+ bool single_kh_kw_loop = (h_block_size == jcp.oh);
+
+ size_t inp_row_step = jcp.tr_iw * jcp.ic_block * jcp.typesize_in;
+ size_t first_inp_block_step = inp_row_step * (h_block_size - jcp.t_pad);
+ size_t inp_block_step = inp_row_step * h_block_size;
+ size_t out_block_step = pad_ow * jcp.oc_block * jcp.typesize_in
+ * h_block_size;
+
+ if (!single_kh_kw_loop) {
+ // Save the original prefetch pointers from the OpenMP driver
+ vmovq(reg_inp_pf_save, reg_inp_pf_l2);
+ vmovq(reg_out_pf_save, reg_out_pf_l2);
+ mov(reg_inp_pf_l2, reg_inp);
+ add(reg_inp_pf_l2, first_inp_block_step);
+ mov(reg_out_pf_l2, reg_out);
+ add(reg_out_pf_l2, out_block_step);
+ }
+ emit_kh_kw_loop(true, single_kh_kw_loop, h_block_size);
+
+ if (!single_kh_kw_loop) {
+ size_t ker_reset_offset
+ = jcp.oc_block * jcp.ic_block * jcp.typesize_out * jcp.kw * jcp.kh;
+ sub(reg_ker, ker_reset_offset);
+ and_(reg_ker, ~1); // Clear the zeroing flag for subsequent updates
+
+ add(reg_inp, first_inp_block_step);
+ add(reg_out, out_block_step);
+ mov(reg_inp_pf_l2, reg_inp);
+ add(reg_inp_pf_l2, inp_block_step);
+ mov(reg_out_pf_l2, reg_out);
+ add(reg_out_pf_l2, out_block_step);
+
+ int num_innermost_iters = div_up(jcp.oh, h_block_size) - 2;
+ if (num_innermost_iters > 0) {
+ Label h_block_loop;
+
+ mov(reg_tmp_w, num_innermost_iters);
+ kmovw(reg_h_block, reg_tmp_w);
+ L(h_block_loop); {
+ emit_kh_kw_loop(false, false, h_block_size);
+ sub(reg_ker, ker_reset_offset);
+ add(reg_inp, inp_row_step * h_block_size);
+ add(reg_out, out_block_step);
+ mov(reg_inp_pf_l2, reg_inp);
+ add(reg_inp_pf_l2, inp_block_step);
+ mov(reg_out_pf_l2, reg_out);
+ add(reg_out_pf_l2, out_block_step);
+ kmovw(reg_tmp_w, reg_h_block);
+ sub(reg_tmp_w, 1);
+ kmovw(reg_h_block, reg_tmp_w);
+ jnz(h_block_loop);
+ }
+ }
+
+ // Restore the original prefetch pointers that came from the OpenMP
+ // driver
+ vmovq(reg_inp_pf_l2, reg_inp_pf_save);
+ vmovq(reg_out_pf_l2, reg_out_pf_save);
+ emit_kh_kw_loop(false, true, h_block_size);
+ }
+
+ return true;
+}
+
+bool jit_avx512_common_conv_bwd_weights_kernel_f32
+ ::flat_4ops_compute() {
+ const auto &j = jcp;
+ const bool ok = j.ver == ver_4fma && j.is_1stconv
+ && everyone_is(0, j.dilate_h, j.dilate_w);
+ if (!ok) return false;
+
+ Reg64 reg_ptr_tr_src = r8;
+ Reg64 reg_ptr_dst = r9;
+ Reg64 reg_ptr_wei = r10;
+ Reg64 reg_ptr_bia = r11;
+
+ Reg64 reg_kh_step = rax;
+ Reg64 reg_oh = abi_not_param1;
+ Reg64 reg_kh = rdx;
+
+ Reg32 reg_flag_save = ebx;
+ Reg32 reg_flag = esi;
+
+ Zmm vbia(31);
+
+ auto zmm_wei = [&](int kh, int kw) {
+ return Zmm(8 + kh * j.kw + kw);
+ };
+ auto zmm_dst = [&](int ow) {
+ return Zmm(ow % 8);
+ };
+
+ auto addr_tr_src = [&](int kh, int iw) {
+ return ptr[reg_ptr_tr_src
+ + (kh * j.stride_w * j.tr_ld + iw) * jcp.typesize_in];
+ };
+ auto addr_dst = [&](int ow) {
+ return ptr[reg_ptr_dst + ow * jcp.oc_block * jcp.typesize_in];
+ };
+ auto addr_wei = [&](int kh, int kw) {
+ return ptr[reg_ptr_wei + (kh * j.kw + kw) * j.oc_block
+ * jcp.typesize_out];
+ };
+
+ auto emit_fma_block = [&](int kh_step) {
+ for (int kh = 0; kh < kh_step; ++kh) {
+ for (int kw = 0; kw < j.kw; ++kw) {
+ auto vwei = zmm_wei(kh, kw);
+ vpxord(vwei, vwei, vwei);
+ }
+ }
+
+ for (int ow = 0; ow < j.ow; ow += 4) {
+ for (int _ow = ow; _ow < ow + 4; ++_ow) {
+ auto vdst = zmm_dst(_ow);
+ if (_ow < j.ow)
+ vmovups(vdst, addr_dst(_ow));
+ else
+ vpxord(vdst, vdst, vdst);
+ }
+
+ for (int kh = 0; kh < kh_step; ++kh) {
+ for (int kw = 0; kw < j.kw; ++kw) {
+ const int iw = ow + (kw % j.stride_w) * j.tr_ld
+ + (kw / j.stride_w);
+ v4fmaddps(zmm_wei(kh, kw), zmm_dst(ow),
+ addr_tr_src(kh, iw));
+                    if (kh == 0 && kw < 4) {
+ prefetcht1(ptr[reg_ptr_dst
+ + (j.ow + ow + kw) * jcp.oc_block
+ * jcp.typesize_in]);
+ }
+ if (j.with_bias && kh_step == 1) { /* [bwd_w:b:r1] */
+ const int off = kw + 4 - j.kw;
+ if (off >= 0 && ow + off < j.ow)
+ vaddps(vbia, vbia, zmm_dst(ow + off));
+ }
+ }
+ }
+ }
+
+ Label l_store;
+ test(reg_flag, FLAG_MB_FIRST);
+ jnz(l_store, T_NEAR);
+ for (int kh = 0; kh < kh_step; ++kh) {
+ for (int kw = 0; kw < j.kw; ++kw)
+ vaddps(zmm_wei(kh, kw), addr_wei(kh, kw));
+ }
+ L(l_store);
+ for (int kh = 0; kh < kh_step; ++kh) {
+ for (int kw = 0; kw < j.kw; ++kw)
+ vmovups(addr_wei(kh, kw), zmm_wei(kh, kw));
+ }
+ };
+
+ auto emit_kh_loop = [&]() {
+ const int kh_step_rem = j.kh % j.kh_step;
+ xor_(reg_kh, reg_kh);
+ mov(reg_kh_step, j.kh_step);
+
+ Label l_kh_loop;
+ L(l_kh_loop); {
+ Label l_done;
+
+ if (kh_step_rem != 0) {
+ Label l_keep_kh_step;
+ cmp(reg_kh, j.kh - j.kh_step);
+ jle(l_keep_kh_step, T_NEAR);
+
+ mov(reg_kh_step, kh_step_rem);
+ emit_fma_block(kh_step_rem);
+ jmp(l_done, T_NEAR);
+
+ L(l_keep_kh_step);
+ }
+
+ emit_fma_block(j.kh_step);
+
+ L(l_done);
+
+ add(reg_ptr_tr_src, j.kh_step * j.stride_w * j.tr_ld
+ * jcp.typesize_in);
+ add(reg_ptr_wei, j.kh_step * j.kw * j.oc_block * jcp.typesize_out);
+ add(reg_kh, j.kh_step);
+
+ cmp(reg_kh, j.kh);
+ jl(l_kh_loop, T_NEAR);
+ }
+
+ const int kh_steps = rnd_up(j.kh, j.kh_step);
+ sub(reg_ptr_tr_src, kh_steps * j.stride_w * j.tr_ld * jcp.typesize_in);
+ sub(reg_ptr_wei, kh_steps * j.kw * j.oc_block * jcp.typesize_out);
+ };
+
+ auto emit_oh_loop = [&]() {
+ mov(reg_oh, j.oh);
+
+ Label l_oh_loop;
+ L(l_oh_loop); {
+ Label l_restore_mb_flag, l_jump;
+
+ cmp(reg_oh, j.oh);
+ je(l_restore_mb_flag, T_NEAR);
+
+ and_(reg_flag, ~FLAG_MB_FIRST);
+ jmp(l_jump, T_NEAR);
+
+ L(l_restore_mb_flag);
+ mov(reg_flag, reg_flag_save);
+
+ L(l_jump);
+
+ emit_kh_loop();
+
+ add(reg_ptr_tr_src, j.stride_h * j.stride_w * j.tr_ld
+ * jcp.typesize_in);
+ add(reg_ptr_dst, j.ow * j.oc_block * jcp.typesize_in);
+
+ dec(reg_oh);
+ jnz(l_oh_loop, T_NEAR);
+ }
+ };
+
+ auto emit_bia_store = [&]() {
+ if (!j.with_bias) return;
+
+ Label l_bia_store, l_bia_skip;
+ test(reg_flag, FLAG_IC_FIRST);
+ jz(l_bia_skip);
+
+ test(reg_flag, FLAG_MB_FIRST);
+ jnz(l_bia_store, T_NEAR);
+ vaddps(vbia, ptr[reg_ptr_bia]);
+ L(l_bia_store);
+ vmovups(ptr[reg_ptr_bia], vbia);
+ L(l_bia_skip);
+ };
+
+ mov(reg_ptr_tr_src, ptr[param + GET_OFF(src)]);
+ mov(reg_ptr_dst, ptr[param + GET_OFF(dst)]);
+ mov(reg_ptr_wei, ptr[param + GET_OFF(filt)]);
+ mov(reg_ptr_bia, ptr[param + GET_OFF(bias)]);
+ mov(reg_flag_save, ptr[param + GET_OFF(flags)]);
+
+ vpxord(vbia, vbia, vbia);
+ emit_oh_loop();
+ emit_bia_store();
+
+ return true;
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::compute_loop()
+{
+ if (flat_4ops_compute())
+ return;
+ if (compute_full_spat_loop())
+ return;
+
+ maybe_zero_kernel();
+
+ if (jcp.ndims == 5) compute_d_loop_common();
+ else compute_oh_loop_common();
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::generate()
+{
+ preamble();
+
+ mov(reg_input, ptr[param + GET_OFF(src)]);
+ mov(reg_output, ptr[param + GET_OFF(dst)]);
+ mov(reg_kernel, ptr[param + GET_OFF(filt)]);
+
+ compute_loop();
+
+ postamble();
+}
+
+status_t jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf(
+ jit_conv_conf_t &jcp, const convolution_desc_t &cd,
+ memory_desc_t &src_md, memory_desc_t &diff_weights_md,
+ memory_desc_t &diff_bias_md, memory_desc_t &diff_dst_md) {
+ if (!mayiuse(avx512_common))
+ return status::unimplemented;
+
+ const memory_desc_wrapper src_d(&src_md);
+ const memory_desc_wrapper diff_weights_d(&diff_weights_md);
+ const memory_desc_wrapper diff_bias_d(&diff_bias_md);
+ const memory_desc_wrapper diff_dst_d(&diff_dst_md);
+
+ const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
+ int ndims = src_d.ndims();
+
+ jcp = zero<decltype(jcp)>();
+
+ jcp.simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
+ jcp.ndims = ndims;
+ jcp.prop_kind = cd.prop_kind;
+
+ jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+
+ jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+
+ jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
+ jcp.ih = (ndims == 3) ? 1 : src_d.dims()[ndims-2];
+ jcp.iw = src_d.dims()[ndims-1];
+ jcp.od = (ndims == 5) ? diff_dst_d.dims()[2] : 1;
+ jcp.oh = (ndims == 3) ? 1 : diff_dst_d.dims()[ndims-2];
+ jcp.ow = diff_dst_d.dims()[ndims-1];
+
+ jcp.kd = (ndims == 5) ? diff_weights_d.dims()[with_groups + 2] : 1;
+ jcp.kh = (ndims == 3) ? 1 : diff_weights_d.dims()[with_groups + ndims-2];
+ jcp.kw = diff_weights_d.dims()[with_groups + ndims-1];
+
+ jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][ndims-4];
+ jcp.l_pad = cd.padding[0][ndims-3];
+
+ jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
+ jcp.stride_h = (ndims == 3) ? 1 : cd.strides[ndims-4];
+ jcp.stride_w = cd.strides[ndims-3];
+
+ jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
+ jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[ndims-4];
+ jcp.dilate_w = cd.dilates[ndims-3];
+
+ const int kh_range = 1 + (jcp.kh - 1) * (jcp.dilate_h + 1);
+ bool ok = true
+ // general condition to simplify dilations
+ && IMPLICATION(jcp.dilate_d != 0, jcp.stride_d == 1)
+ && IMPLICATION(jcp.dilate_h != 0, jcp.stride_h == 1)
+ && IMPLICATION(jcp.dilate_w != 0, jcp.stride_w == 1)
+ // special condition to simplify dilations in compute_oh_loop_common
+ && IMPLICATION(jcp.dilate_h != 0, kh_range <= jcp.ih);
+ if (!ok)
+ return status::unimplemented;
+
+ jcp.r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+ jcp.b_pad = nstl::max(0, (jcp.oh - 1) * jcp.stride_h
+ + (jcp.kh - 1) * (jcp.dilate_h + 1) - (jcp.ih + jcp.t_pad - 1));
+ jcp.back_pad = nstl::max(0, (jcp.od - 1) * jcp.stride_d
+ + (jcp.kd - 1) * (jcp.dilate_d + 1) - (jcp.id + jcp.f_pad - 1));
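+    // These are the implicit right/bottom/back paddings required so that
+    // the last output element still sees a full (dilated) kernel window.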
+
+    /* XXX: dilation_d > 0 is currently not supported */
+ if (ndims == 5)
+ if (jcp.dilate_d > 0)
+ return status::unimplemented;
+
+ jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
+ jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
+ jcp.ohp = jcp.oh;
+ jcp.owp = jcp.ow;
+ jcp.aligned_threads = 0;
+
+ /* check for the 1st convolution */
+ jcp.is_1stconv = is_1stconv(jcp);
+
+ jcp.oc_block = jcp.simd_w;
+
+ bool ok_to_pad_channels = true
+ && jcp.ngroups == 1
+ && src_d.data_type() == data_type::f32;
+
+ if (ok_to_pad_channels)
+ jcp.oc = rnd_up(jcp.oc, jcp.simd_w);
+
+ if (jcp.oc % jcp.oc_block)
+ return status::unimplemented;
+
+ auto dst_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
+ auto wei_tag = with_groups
+ ? pick(ndims - 3, gOIw16i16o, gOIhw16i16o, gOIdhw16i16o)
+ : pick(ndims - 3, OIw16i16o, OIhw16i16o, OIdhw16i16o);
+
+ if (diff_dst_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(diff_dst_md, dst_tag));
+ jcp.dst_tag = dst_tag;
+ } else {
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(dst_tag);
+ }
+ if (jcp.dst_tag != dst_tag)
+ return status::unimplemented;
+
+ /* conditions on bias memory */
+ jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
+ if (jcp.with_bias) {
+ if (diff_bias_d.format_kind() == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_bias_md, x));
+ }
+
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+
+ /* kernel applicability check wrt boundaries
+ * the conditions are quite general across the kernels we have,
+ * but ideally the check should belong to a specific kernel... */
+ const int max_pad = ((jcp.kh - 1) * (jcp.dilate_h + 1) + 1) / 2;
+ const bool boundaries_ok = true
+ && jcp.t_pad <= max_pad
+ && jcp.b_pad <= max_pad
+ && IMPLICATION(jcp.f_pad > 0, jcp.kd < jcp.id + jcp.f_pad)
+ && jcp.f_pad < jcp.kd;
+ if (!boundaries_ok)
+ return status::unimplemented;
+
+ /* yet another common check */
+ if (jcp.kw > 14)
+ return status::unimplemented;
+
+ /* setting register strategy */
+ for (int ur_w = nstl::min(max_ur_w, jcp.ow); ur_w > 0; --ur_w) {
+ if (jcp.ow % ur_w == 0) { jcp.ur_w = ur_w; break; }
+ }
+
+ if (jcp.is_1stconv) {
+ auto src_tag = pick(ndims - 3, ncw, nchw, ncdhw);
+ if (src_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(src_md, src_tag));
+ jcp.src_tag = src_tag;
+ } else {
+ jcp.src_tag = src_d.matches_one_of_tag(src_tag);
+ if (jcp.ic == 1 && jcp.src_tag != src_tag)
+ jcp.src_tag = src_d.matches_one_of_tag(
+ pick(ndims - 3, nwc, nhwc, ndhwc));
+ }
+ if (jcp.src_tag == format_tag::undef)
+ return status::unimplemented;
+
+ const bool src_ok = true
+ && utils::everyone_is(data_type::f32,
+ src_d.data_type(), diff_weights_d.data_type(),
+ diff_dst_d.data_type())
+ && one_of(jcp.ic, 1, 2, 3)
+ && jcp.ngroups == 1;
+ if (!src_ok)
+ return status::unimplemented;
+
+ const int tr_ld = rnd_up(div_up(jcp.iw + jcp.l_pad + jcp.r_pad,
+ jcp.stride_w), 16);
+ const int kh_step = nstl::max((28 - jcp.with_bias) / jcp.kw, 1);
+ const int kh_step_rem = jcp.kh % kh_step;
+
+ const auto wei_4fma_tag = with_groups
+ ? pick(ndims - 3, gOiw16o, gOihw16o, gOidhw16o)
+ : pick(ndims - 3, Oiw16o, Oihw16o, Oidhw16o);
+
+ auto current_wei_tag = format_tag::undef;
+ if (diff_weights_d.format_kind() != format_kind::any)
+ current_wei_tag = diff_weights_d.matches_one_of_tag(wei_4fma_tag);
+
+ const bool use_4fma = true
+ && one_of(ndims, 3, 4)
+ && mayiuse(avx512_mic_4ops)
+ && mkldnn_thr_syncable()
+ && everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w)
+ && everyone_is(0, jcp.l_pad, jcp.r_pad, jcp.t_pad, jcp.b_pad)
+ && jcp.kw <= 28 - jcp.with_bias
+ && jcp.stride_w == 4
+ && tr_ld / jcp.simd_w <= 4 /* [bwd_w:tr_src:r1] */
+ && IMPLICATION(jcp.with_bias, kh_step_rem == 1) /* [bwd_w:b:r1] */
+ && IMPLICATION(diff_weights_d.format_kind() != format_kind::any,
+ current_wei_tag == wei_4fma_tag);
+
+ if (use_4fma) {
+ jcp.ver = ver_4fma;
+ jcp.kh_step = kh_step;
+ jcp.tr_ld = tr_ld;
+ jcp.ic_block = 1;
+ if (diff_weights_d.format_kind() == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_weights_md, wei_4fma_tag));
+ jcp.wei_tag = wei_4fma_tag;
+ } else {
+ jcp.ver = ver_fma;
+ jcp.ic_block = jcp.ic;
+
+ wei_tag = with_groups
+ ? pick(ndims - 3, gOwi16o, gOhwi16o, gOdhwi16o)
+ : pick(ndims - 3, Owi16o, Ohwi16o, Odhwi16o);
+
+ if (diff_weights_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
+ jcp.wei_tag = wei_tag;
+ } else {
+ jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
+ }
+ if (jcp.wei_tag != wei_tag)
+ return status::unimplemented;
+ }
+
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+ } else {
+ auto src_tag = pick(ndims - 3, nCw16c, nChw16c, nCdhw16c);
+ if (src_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(src_md, src_tag));
+ jcp.src_tag = src_tag;
+ } else {
+ jcp.src_tag = src_d.matches_one_of_tag(src_tag);
+ }
+ if (jcp.src_tag != src_tag)
+ return status::unimplemented;
+
+ if (diff_weights_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(diff_weights_md, wei_tag));
+ jcp.wei_tag = wei_tag;
+ } else {
+ jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
+ }
+ if (jcp.wei_tag != wei_tag)
+ return status::unimplemented;
+
+ jcp.ic_block = jcp.simd_w;
+ if (ok_to_pad_channels)
+ jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+ if ((mayiuse(avx512_mic) || mayiuse(avx512_core))
+ && utils::everyone_is(data_type::f32,
+ src_d.data_type(), diff_weights_d.data_type(),
+ diff_dst_d.data_type())) {
+ jcp.ver = ver_fma;
+ if (one_of(ndims, 3, 4) && mayiuse(avx512_mic_4ops) && jcp.stride_w == 1 &&
+ everyone_is(0, jcp.dilate_d, jcp.dilate_h, jcp.dilate_w) &&
+ mkldnn_thr_syncable()) {
+ jcp.ver = ver_4fma;
+ }
+ } else {
+ return status::unimplemented;
+ }
+ if (jcp.ver == ver_4fma) {
+ jcp.ur_w = jcp.ow;
+ // XXX, BUGBUGBUG, but not a FIXME: this assumes that it's OK to
+ // cross the right boundary. The only requirement is not to have
+ // NaNs there because another multiplicand is always guaranteed to
+ // be zero. This also may require the top-level driver to allocate
+ // four extra guarding elements at the very end of the buffer.
+ // I'm not proud of this hack, but it improves performance by
+ // about 5-10% depending on the dimensions (Roma)
+
+ const int tr_round = 4;
+
+ jcp.tr_iw = rnd_up(jcp.iw + jcp.kw - 1, tr_round);
+ jcp.tr_src_num_guard_elems = tr_round; // upper bound
+ }
+ }
+
+ if (utils::one_of(jcp.ver, ver_4fma, ver_fma)) {
+ jcp.typesize_in = sizeof(float);
+ jcp.typesize_out = sizeof(float);
+ } else
+ return status::unimplemented;
+
+ bool args_ok = true
+ && jcp.ic % jcp.ic_block == 0
+ && jcp.oc % jcp.oc_block == 0
+ && jcp.ic <= src_d.padded_dims()[1]
+ && jcp.oc <= diff_dst_d.padded_dims()[1]
+ && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
+ && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
+ if (!args_ok) return status::unimplemented;
+
+ { // balancing
+ int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
+ balance(jcp, nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b);
+ jcp.nthr = nthr;
+ jcp.nthr_mb = nthr_mb;
+ jcp.nthr_g = nthr_g;
+ jcp.nthr_oc_b = nthr_oc_b;
+ jcp.nthr_ic_b = nthr_ic_b;
+ }
+
+ return status::success;
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ if (jcp.ver == ver_4fma) {
+ if (jcp.is_1stconv) {
+ const size_t tr_src_size =
+ jcp.nthr / jcp.nthr_oc_b * jcp.ih * jcp.stride_w * jcp.tr_ld;
+ scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
+ } else {
+ // XXX: See the comment about tr_iw and guarding elements in
+ // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
+ const size_t max_nthr = jcp.nthr_mb * jcp.ngroups * jcp.nb_ic;
+ const size_t min_tr_src_size_per_thr
+ = jcp.ih * jcp.ic_block * jcp.tr_iw;
+ const size_t tr_src_size = max_nthr * min_tr_src_size_per_thr
+ + jcp.tr_src_num_guard_elems;
+ scratchpad.book(key_conv_tr_src, jcp.typesize_in * tr_src_size);
+ }
+
+ /* prepare synchronization contexts */
+ if (jcp.nthr_oc_b > 1) {
+ const int tr_src_bctx_size = jcp.nthr / jcp.nthr_oc_b;
+ scratchpad.book(key_conv_tr_src_bctx,
+ sizeof(simple_barrier::ctx_t) * tr_src_bctx_size);
+ }
+ }
+
+ if (jcp.nthr_mb > 1) {
+ const int wei_size = jcp.ngroups * jcp.oc * jcp.ic
+ * jcp.kh * jcp.kw * jcp.kd;
+ const int bia_size = jcp.ngroups * jcp.oc;
+ const size_t wei_bia_reduction_size = wei_size + bia_size;
+
+ scratchpad.book(key_conv_wei_bia_reduction,
+ jcp.typesize_out * wei_bia_reduction_size * (jcp.nthr_mb - 1));
+ scratchpad.book(key_conv_wei_bia_reduction_bctx,
+ sizeof(simple_barrier::ctx_t));
+ }
+
+ if (jcp.with_bias && jcp.oc != jcp.oc_without_padding)
+ scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
+}
+
+void jit_avx512_common_conv_bwd_weights_kernel_f32::balance(
+ const jit_conv_conf_t &j, int &nthr_, int &nthr_mb_, int &nthr_g_,
+ int &nthr_oc_b_, int &nthr_ic_b_)
+{
+ nthr_ = nthr_mb_ = nthr_g_ = nthr_oc_b_ = nthr_ic_b_ = 1;
+
+ const int max_threads = mkldnn_get_max_threads();
+
+ if (max_threads < j.ngroups) {
+ /* simplification... fortunately it doesn't hurt much */
+ return;
+ }
+
+ if (!mkldnn_thr_syncable() && j.ver == ver_4fma) {
+ // should not happen -- the driver is not ready
+ // for TBB-like non-synchronous threading yet
+ return;
+ }
+
+ if (j.ver == ver_4fma && j.is_1stconv) {
+ nthr_g_ = 1;
+ nthr_oc_b_ = 1;
+ nthr_ic_b_ = nstl::min(j.nb_ic, max_threads);
+ nthr_mb_ = nstl::min(max_threads / nthr_ic_b_, j.mb);
+ nthr_ = nthr_mb_ * nthr_oc_b_ * nthr_ic_b_ * nthr_g_;
+ return;
+ }
+
+ nthr_g_ = j.ngroups;
+ const int nthr = max_threads / nthr_g_;
+
+ auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
+ /* calculate per thread memory cost (read/write). high level optimizer
+ * tries to minimize memory consumption. few notes:
+ * (n1) unclear why, but that essentially helps first convolution...
+ * (n2) assuming the reduction over minibatch is always there:
+ * - instead of 8 it should be 5 here (write ~= 2 read):
+ * kernel: temporal workspace 1 write
+ * reduction: 1 read from workspace and 1 write to the diff_wei
+ * - but experiments showed 8 works better than 5 or 6... */
+
+ const int src_coef = j.ver == ver_4fma ? 4 : 1;
+ const int dst_coef = 1;
+ const int wei_coef = 8;
+
+ return 0
+ + src_coef
+ * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
+ * div_up(j.nb_ic, nthr_ic_b) * j.ic_block * j.ih * j.iw * j.id
+ / j.stride_d / j.stride_h / j.stride_w /* (n1) */
+ + dst_coef
+ * div_up(j.mb, nthr_mb) * div_up(j.ngroups, nthr_g_)
+ * div_up(j.nb_oc, nthr_oc_b) * j.oc_block * j.oh * j.ow * j.od
+ + wei_coef /* (n2) */
+ * div_up(j.ngroups, nthr_g_)
+ * div_up(j.nb_oc, nthr_oc_b) * div_up(j.nb_ic, nthr_ic_b)
+ * j.kh * j.kw * j.kd * j.ic_block * j.oc_block;
+ };
+
+ int best_mem_cost = calc_mem_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
+
+ /* step 1: find the best thread distribution with lowest memory cost */
+ const int nthr_mb_max = nstl::min(nthr, j.mb * j.od);
+ for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
+ const int nthr_par = nthr / nthr_mb;
+ const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
+ for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
+ int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
+
+ int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+ if (mem_cost <= best_mem_cost) {
+ best_mem_cost = mem_cost;
+ nthr_mb_ = nthr_mb;
+ nthr_oc_b_ = nthr_oc_b;
+ nthr_ic_b_ = nthr_ic_b;
+ }
+ }
+
+ if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
+ }
+
+ if (!mayiuse(avx512_mic)) {
+ auto calc_comp_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
+ return 1
+ * div_up(j.mb, nthr_mb)
+ * div_up(j.ngroups, nthr_g_)
+ * div_up(j.nb_oc, nthr_oc_b)
+ * div_up(j.nb_ic, nthr_ic_b);
+ };
+
+ /* step 2: search for a thread distribution with lower compute cost.
+         * the constraints:
+         * - memory cost cannot exceed 110% of the best found in step 1
+         * - unless the compute cost is at least 25% lower than the current
+         *   best case (i.e. 4 * comp_cost <= 3 * best_comp_cost)
+ * note: both constants were found empirically */
+ int best_comp_cost = calc_comp_cost(nthr_mb_, nthr_oc_b_, nthr_ic_b_);
+ for (int nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
+ const int nthr_par = nthr / nthr_mb;
+ const int nthr_oc_b_max = nstl::min(nthr_par, j.nb_oc);
+ for (int nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
+ int nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, j.nb_ic);
+ int mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+ int comp_cost = calc_comp_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
+
+ const bool opt1 = comp_cost <= best_comp_cost
+ && mem_cost < 1.1 * best_mem_cost;
+ const bool opt2 = 4 * comp_cost <= 3 * best_comp_cost;
+
+ if (opt1 || opt2) {
+ best_comp_cost = comp_cost;
+ nthr_mb_ = nthr_mb;
+ nthr_oc_b_ = nthr_oc_b;
+ nthr_ic_b_ = nthr_ic_b;
+ }
+ }
+
+ if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
+ }
+ }
+
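+
+    // If the minibatch/depth split already uses more than half of the
+    // threads, extend it to take as many as it can.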
+ if (nthr_mb_ > max_threads/2 && nthr_mb_ < max_threads)
+ nthr_mb_ = nstl::min(j.mb * j.od, max_threads);
+ nthr_ = nthr_mb_ * nthr_g_ * nthr_oc_b_ * nthr_ic_b_;
+
+ assert(nthr_ <= max_threads);
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_mb_ == 1));
+}
+
+template struct _jit_avx512_common_conv_fwd_kernel<Zmm>;
+template struct _jit_avx512_common_conv_fwd_kernel<Xmm>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp
new file mode 100644
index 0000000000..f76770797a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.hpp
@@ -0,0 +1,423 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP
+#define JIT_AVX512_COMMON_CONV_KERNEL_F32_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template<typename Vmm>
+struct _jit_avx512_common_conv_fwd_kernel : public jit_generator {
+
+ _jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
+ const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
+ {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
+ this, jcp.eltwise);
+
+ generate();
+ jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
+ }
+
+ ~_jit_avx512_common_conv_fwd_kernel() {
+ delete eltwise_injector_;
+ }
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_fwd_kernel)
+
+ jit_conv_conf_t jcp;
+ const primitive_attr_t &attr_;
+ void (*jit_ker_)(jit_conv_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ enum {
+ typesize = sizeof(float),
+ ker_reg_base_idx = 28,
+ };
+
+ reg64_t param = abi_param1;
+ reg64_t reg_inp = r8;
+ reg64_t reg_ker = r9;
+ reg64_t reg_out = r10;
+
+ reg64_t reg_inp_prf = r11;
+ reg64_t reg_ker_prf = r12;
+ reg64_t reg_out_prf = r13;
+ reg64_t reg_owb = r12;
+
+ reg64_t aux_reg_inp = r14;
+ reg64_t aux_reg_ker = r15;
+
+ reg64_t aux_reg_inp_prf = rsi;
+ reg64_t aux_reg_ker_prf = rdx;
+
+ reg64_t reg_channel = rsi;
+ reg64_t reg_bias = rdx;
+
+ reg64_t aux_reg_ker_d = r9;
+ reg64_t aux_reg_inp_d = rbx;
+ reg64_t aux_reg_inp_d_prf = r13;
+ reg64_t aux_reg_ker_d_prf = abi_not_param1;
+ reg64_t reg_ki = r10;
+
+ reg64_t reg_kj = rax;
+ reg64_t reg_relu_ns = rax;
+ reg64_t reg_oi = rbx;
+ reg64_t reg_kh = abi_not_param1;
+
+ reg64_t reg_tmp = rbp;
+
+ reg64_t reg_ic_loop = rdx;
+ reg64_t reg_inp_loop = rsi;
+
+ reg64_t reg_init_flag = r13;
+ reg64_t reg_bias_ptr = param;
+
+ reg64_t aux_reg_ic = r12;
+ reg64_t reg_binp = rax;
+ reg64_t reg_bout = r11;
+ reg64_t aux1_reg_inp = rbx;
+ reg64_t aux_reg_out = abi_not_param1;
+
+ reg64_t reg_long_offt = r11;
+ reg64_t reg_out_long_offt = r14;
+
+ inline Vmm vmm_ker(int i_ic) {
+ assert(i_ic < 4);
+ return Vmm(ker_reg_base_idx + i_ic);
+ }
+
+ inline Vmm vmm_out(int i_ur, int i_oc) {
+ int idx = i_ur + i_oc * jcp.ur_w;
+ assert(idx < ker_reg_base_idx);
+ return Vmm(idx);
+ }
+
+ inline Vmm vmm_inp(int i_ic, int nb_x_blocking) {
+ int idx = i_ic + nb_x_blocking * jcp.ur_w;
+ assert(idx < 31);
+ return Vmm(idx);
+ }
+
+ Xbyak::Reg64 imm_addr64 = r15;
+ Vmm vmm_wei = Vmm(31);
+
+ jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
+
+ inline void prepare_output(int ur_w);
+ inline void store_output(int ur_w);
+ inline void compute_loop_fma(int ur_w, int pad_l, int pad_r);
+ inline void compute_loop_fma_core(int ur_w, int pad_l, int pad_r);
+ inline void compute_loop_4fma(int ur_w, int pad_l, int pad_r);
+ inline void compute_loop_4fma_1st(int ur_w, int pad_l, int pad_r);
+ inline void compute_loop(int ur_w, int pad_l, int pad_r);
+
+ void generate();
+
+ inline size_t get_output_offset(int oi, int n_oc_block) {
+ return (size_t)jcp.typesize_out * ((size_t)n_oc_block * jcp.oh
+ * jcp.ow * jcp.od + oi) * jcp.oc_block;
+ }
+
+ inline size_t get_input_offset(int ki, int ic, int oi, int pad_l) {
+ size_t iw_str = !jcp.is_1stconv ? jcp.ic_block : 1;
+ size_t ic_str = !jcp.is_1stconv ? 1 : (size_t)jcp.iw * jcp.ih * jcp.id;
+ return (size_t)jcp.typesize_in * ((size_t)(ki * (jcp.dilate_w + 1)
+ + oi * jcp.stride_w - pad_l) * iw_str + ic * ic_str);
+ }
+
+    inline int get_kernel_offset(int ki, int ic, int n_oc_block, int ker_number) {
+ return jcp.typesize_in * jcp.oc_block
+ * (n_oc_block * jcp.nb_ic * jcp.ic_block * jcp.kh * jcp.kw * jcp.kd
+ + (ic + ker_number) + ki * jcp.ic_block);
+ }
+
+ inline int get_ow_start(int ki, int pad_l) {
+ return nstl::max(0,
+ utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
+ }
+
+ inline int get_ow_end(int ur_w, int ki, int pad_r) {
+ return ur_w - nstl::max(0, utils::div_up(pad_r
+ - (jcp.kw - 1 - ki)
+ * (jcp.dilate_w + 1),
+ jcp.stride_w));
+ }
+};
+
+struct jit_avx512_common_conv_fwd_kernel {
+
+ jit_avx512_common_conv_fwd_kernel(jit_conv_conf_t ajcp,
+ const primitive_attr_t &attr) :
+ jit_ker(nullptr),
+ zmm_kernel_(nullptr),
+ xmm_kernel_(nullptr) {
+ int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.oc_block;
+ switch (ch_block) {
+ case 16:
+ zmm_kernel_ =
+ new _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm>(
+ ajcp, attr);
+ jit_ker = zmm_kernel_->jit_ker_;
+ return;
+ case 4:
+ xmm_kernel_ =
+ new _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm>(
+ ajcp, attr);
+ jit_ker = xmm_kernel_->jit_ker_;
+ return;
+ default:
+ assert(!"invalid channel blocking");
+ }
+ }
+
+ ~jit_avx512_common_conv_fwd_kernel() {
+ delete xmm_kernel_;
+ delete zmm_kernel_;
+ }
+
+ enum {
+ typesize = sizeof(float)
+ };
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ memory_desc_t &src_pd,
+ memory_desc_t &weights_pd,
+ memory_desc_t &dst_pd,
+ memory_desc_t &bias_pd,
+ const primitive_attr_t &attr,
+ int nthreads);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ void(*jit_ker)(jit_conv_call_s *);
+ _jit_avx512_common_conv_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
+ _jit_avx512_common_conv_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
+};
+
+struct jit_avx512_common_conv_bwd_data_kernel_f32: public jit_generator {
+
+ jit_avx512_common_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp)
+ {
+ generate();
+ jit_ker = (void (*)(jit_conv_call_s *))getCode();
+ }
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_data_kernel_f32)
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ jit_conv_conf_t jcp;
+ void (*jit_ker)(jit_conv_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ enum {
+ typesize = sizeof(float),
+ ker_reg_base_idx = 28,
+ };
+
+ reg64_t param = abi_param1;
+ reg64_t reg_dst = r8;
+ reg64_t reg_ker = r9;
+ reg64_t reg_src = r10;
+
+ reg64_t reg_dst_prf = r11;
+ reg64_t reg_ker_prf = r12;
+ reg64_t reg_src_prf = r13;
+
+ reg64_t aux_reg_dst = r14;
+ reg64_t aux_reg_ker = r15;
+
+ reg64_t aux_reg_dst_prf = rsi;
+ reg64_t aux_reg_ker_prf = rdx;
+
+ reg64_t aux_reg_dst_d_prf = r13;
+ reg64_t aux_reg_dst_d = rbx;
+ reg64_t aux_reg_ker_d_prf = abi_not_param1;
+ reg64_t aux_reg_ker_d = r9;
+ reg64_t reg_ki = r10;
+
+ reg64_t reg_kj = rax;
+ reg64_t reg_oi = rbx;
+ reg64_t reg_kh = abi_not_param1;
+
+ reg64_t reg_channel = rsi;
+
+ reg64_t reg_tmp = rbp;
+ reg64_t reg_long_offt = r14;
+
+ inline Xbyak::Zmm zmm_ker(int i_ic) {
+ assert(i_ic < 4);
+ return Xbyak::Zmm(ker_reg_base_idx + i_ic);
+ }
+ inline Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
+ int idx = i_ic + nb_x_blocking * jcp.ur_w;
+ assert(idx < 31);
+ return Xbyak::Zmm(idx);
+ }
+ inline Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
+ int idx = i_ur + i_oc * jcp.ur_w;
+ assert(idx < ker_reg_base_idx);
+ return Xbyak::Zmm(idx);
+ }
+
+ Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
+
+ inline void prepare_output(int ur_w);
+ inline void store_output(int ur_w);
+ inline void compute_loop_4fma(int ur_w, int l_overflow, int r_overflow);
+ inline void compute_loop_fma(int ur_w, int l_overflow, int r_overflow);
+ inline void compute_loop_fma_core(int ur_w, int l_overflow, int r_overflow);
+ inline void compute_loop(int ur_w, int l_overflow, int r_overflow);
+ void generate();
+
+ inline int get_iw_start(int ki, int l_overflow)
+ {
+ int res = (jcp.iw - 1 + jcp.r_pad) % jcp.stride_w
+ + l_overflow * jcp.stride_w
+ - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
+ while (res < 0)
+ res += jcp.stride_w;
+
+ return res;
+ }
+
+ inline int get_iw_end(int ur_w, int ki, int r_overflow)
+ {
+ if (utils::one_of(ur_w, jcp.iw, jcp.ur_w_tail))
+ ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
+ int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
+ + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
+ while (res < 0)
+ res += jcp.stride_w;
+
+ return ur_w - res;
+ }
+};
+
+struct jit_avx512_common_conv_bwd_weights_kernel_f32 : public jit_generator {
+
+ jit_avx512_common_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp)
+ : jcp(ajcp)
+ {
+ generate();
+ jit_ker = (void (*)(jit_conv_call_s *))getCode();
+ }
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_conv_bwd_weights_kernel_f32)
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ memory_desc_t &src_md,
+ memory_desc_t &diff_weights_md,
+ memory_desc_t &diff_bias_md,
+ memory_desc_t &diff_dst_md);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ jit_conv_conf_t jcp;
+ void (*jit_ker)(jit_conv_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ enum {typesize = sizeof(float)};
+ static const int max_ur_w;
+
+ reg64_t param = abi_param1;
+ reg64_t reg_input = rax;
+ reg64_t reg_kernel = rdx;
+ reg64_t reg_output = rsi;
+ reg64_t b_ic = abi_not_param1;
+ reg64_t kj = r8;
+ reg64_t reg_kh = r9;
+ reg64_t reg_ur_w_trips = r10;
+ reg64_t reg_oj = r15;
+ reg64_t reg_ih_count = rbx;
+ reg64_t reg_tmp = r14;
+ reg64_t reg_long_offt = r14;
+
+ reg64_t ki = r11;
+ reg64_t reg_kd_count = r12;
+ reg64_t reg_oi = r12;
+ reg64_t reg_d_index = r13;
+ reg64_t reg_input_d = r15;
+ reg64_t reg_output_d = rbx;
+ reg64_t aux_reg_input = r12;
+ reg64_t aux_reg_kernel = r13;
+ reg64_t reg_bias = rbx;
+
+ inline void bias_kernel();
+ inline void maybe_zero_kernel();
+ inline void compute_oh_step_unroll_ow_icblock(int ic_block_step,
+ int max_ur_w);
+ inline void od_step_comeback_pointers();
+ inline void oh_step_comeback_pointers();
+ inline void compute_oh_step_unroll_ow(int ic_block_step, int max_ur_w);
+ inline void compute_ic_block_step(int ur_w,
+ int pad_l, int pad_r, int ic_block_step,
+ int input_offset, int kernel_offset, int output_offset,
+ bool input_wraparound = false);
+ inline void compute_ic_block_step_fma(int ur_w,
+ int pad_l, int pad_r, int ic_block_step,
+ int input_offset, int kernel_offset, int output_offset,
+ bool input_wraparound);
+ inline void compute_ic_block_step_4fma(int ur_w,
+ int pad_l, int pad_r, int ic_block_step,
+ int input_offset, int kernel_offset, int output_offset,
+ bool input_wraparound);
+ inline void compute_oh_step_common(int ic_block_step, int max_ur_w);
+ inline void compute_oh_step_disp();
+ inline void compute_oh_loop_common();
+ inline void compute_d_loop_common();
+
+ inline bool compute_full_spat_loop();
+ inline bool flat_4ops_compute();
+
+ inline void compute_loop();
+
+ void generate();
+
+ static void balance(const jit_conv_conf_t &j, int &nthr, int &nthr_mb,
+ int &nthr_g, int &nthr_oc_b, int &nthr_ic_b);
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp
new file mode 100644
index 0000000000..1bdcd0d6a8
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp
@@ -0,0 +1,1163 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "cpu_memory.hpp"
+
+#include <math.h>
+
+#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
+
+#ifndef KERNEL_SIZE_THRESHOLD
+#define KERNEL_SIZE_THRESHOLD 16
+#endif
+
+#define MIN_REQUIRED_DIMN_REG_BLOCK 14
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace {
+
+using namespace mkldnn::impl::utils;
+
+unsigned int L1_cache_size = get_cache_size(1, true);
+unsigned int L2_cache_size = get_cache_size(2, true);
+unsigned int LLC_data_size = get_cache_size(3, false);
+
+// The test function takes jcp, the candidate and the current best.
+// It returns true if the new candidate is better.
+int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
+ int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
+{
+ int best_divisor = default_best;
+ auto test_num
+ = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
+ if (test(jcp, num, best_divisor)) {
+ best_divisor = num;
+ }
+ };
+
+ for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
+ if (number % divisor == 0) {
+ test_num(jcp, divisor);
+ test_num(jcp, number / divisor);
+ }
+ }
+
+ return best_divisor;
+}
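+
+// Illustrative usage (a sketch mirroring the calls further below, not part of
+// the original kernel): pick the largest divisor of dimN that still fits in
+// the available vector registers.
+//
+//     auto fits_in_regs = [](jit_conv_winograd_conf_t &jcp, int cand, int best) {
+//         return (cand < jcp.nb_reg) && (cand > best);
+//     };
+//     jcp.dimN_reg_block
+//         = get_divisor_satisfying_cond(jcp, jcp.dimN, 1, fits_in_regs);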
+
+namespace {
+bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
+ if (jcp.ver == ver_4fma)
+ return jcp.mb >= 32;
+ else
+ return jcp.mb >= 16;
+}
+}
+
+/* assumes 512 bits registers */
+/* TODO: add support for strides */
+/* TODO: handle the prefetch distance automatically */
+typedef enum cache_t_ { L1, L2, L3 } cache_t;
+
+template <typename data_t>
+struct prefetcher_t {
+ prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
+ cache_t cache_type, size_t block_size, /* in number of elements*/
+ int nb_instructions_in_block, int fma_ipc)
+ : cg_(generator)
+ , reg_base_addr_(reg_base_addr)
+ , cache_type_(cache_type)
+ , cache_block_size_(block_size)
+ {
+ nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
+ prefetch_spread_
+ = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
+ prefetch_blk_
+ = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
+
+ /* assumption: when fetching from Li, the data is already in L(i+1) */
+ int cache_latency;
+ switch (cache_type_) {
+ case L1: cache_latency = 14; break;
+ case L2:
+ case L3:
+ default: cache_latency = 250; break;
+ }
+
+ prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
+ }
+
+ void prefetch(int instruction_number)
+ {
+ if (instruction_number % prefetch_spread_ == 0) {
+ for (int i = 0; (i < prefetch_blk_)
+ && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
+ i++, prefetches_issued_++) {
+ prefetch_inst_(cg_->EVEX_compress_addr(
+ reg_base_addr_, (cache_block_size_ * prefetch_distance_)
+ * sizeof(data_t)
+ + (prefetches_issued_ * 64)));
+ }
+ }
+ }
+
+private:
+ void prefetch_inst_(const Xbyak::Address &addr)
+ {
+ switch (cache_type_) {
+ case L1: cg_->prefetcht0(addr); break;
+ case L2: cg_->prefetcht1(addr); break;
+ case L3: cg_->prefetcht2(addr); break;
+ default:
+ break; // TODO: raise an exception or put an assert
+ }
+ }
+
+ jit_generator *cg_;
+ Xbyak::Reg64 reg_base_addr_;
+ cache_t cache_type_;
+ int cache_block_size_ = 0;
+ int nb_cache_lines_to_prefetch_ = 0;
+ int prefetches_issued_ = 0;
+ int prefetch_spread_ = 0;
+ int prefetch_blk_ = 0;
+ int prefetch_distance_ = 0;
+};
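+
+// Example of the scheduling arithmetic above (a sketch, assuming float data
+// and the 4fma gemm loop below): a block of 16 * 14 = 224 elements spans
+// 224 / (64 / 4) = 14 cache lines. With 16 * 14 / 4 = 56 instructions in the
+// block, prefetch_spread_ = div_up(56, 14) = 4 and prefetch_blk_ =
+// div_up(14, 56) = 1, i.e. one cache line is prefetched every fourth
+// instruction until all 14 lines have been requested.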
+
+// utilities to support kernel parameter selection
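+// check_cond1/check_cond1_bis bound the estimated working set of the blocked
+// GEMM operands against a fraction C of the L1 cache; check_cond2 does the
+// same against the L2 cache.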
+bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
+ int dimM_block, int dimM_simd_block, float C)
+{
+ float lhs = (dimM_block * dimN_reg_block * dimM_simd_block
+ + dimM_block * dimK_block * dimK_reg_block
+ * dimM_simd_block
+ + dimK_block * dimN_reg_block * dimK_reg_block)
+ * (float)sizeof(float);
+ float rhs = C * L1_cache_size;
+ return (lhs < rhs);
+}
+
+bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
+ int dimM_block, int dimM_simd_block, float C)
+{
+ float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block
+ + dimK_block * dimN_reg_block * dimK_reg_block)
+ * (float)sizeof(float);
+ float rhs = C * L1_cache_size;
+ return (lhs < rhs);
+}
+
+bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
+ int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block,
+ float C)
+{
+ float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block
+ + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
+ * dimM_simd_block
+ + nb_dimN_reg_block * dimK_nb_block * dimK_block
+ * dimN_reg_block * dimK_reg_block)
+ * (float)sizeof(float);
+ float rhs = C * L2_cache_size;
+ return (lhs < rhs);
+}
+}
+
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::utils;
+using namespace Xbyak;
+
+void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate(
+ bool is_beta_zero)
+{
+ // const int dimK_simd_block = jcp.dimK_reg_block;
+
+ // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
+ // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
+ // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
+ // dimK_reg_block++)
+ // for (int tile =0; tile < jcp.dimN_reg_block; tile++)
+ // C[dimM_block][tile] +=
+ // A[dimM_block][dimK_block][dimK_reg_block] *
+ // broadcast(B[dimK_block][tile][dimK_reg_block]);
+ // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block],
+ // so we load it before the loop on tile
+ // 2) The loop on tile must be fully unrolled. We are not sure about the
+ // one on dimK_reg_block, but it probably should be unrolled as well.
+
+ auto inner_loops = [=]() {
+ Label dimM_block_loop, dimK_block_loop;
+ const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1;
+ const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;
+
+ prefetcher_t<float> L1_pf(this, reg_srcB, L1,
+ jcp.dimN_reg_block * jcp.dimK_reg_block,
+ jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
+ fma_ipc);
+ prefetcher_t<float> L2_pf(this, reg_srcB, L2,
+ jcp.dimN_reg_block * jcp.dimK_reg_block,
+ jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
+ fma_ipc);
+
+ if (jcp.dimM_block > 1) {
+ mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
+ L(dimM_block_loop);
+ }
+ {
+ // First, we zero the accumulators on the first nb_ic iteration;
+ // otherwise we load them.
+ for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
+ Zmm zmm(jcp.zmm_start + tile);
+ if (is_beta_zero)
+ vpxord(zmm, zmm, zmm);
+ else
+ vmovups(zmm, zword[reg_dstC + 64 * tile]);
+ }
+
+ if (jcp.dimK_block > 1) {
+ mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
+ L(dimK_block_loop);
+ }
+ {
+ auto load_A = [=](int reg_idx, int offset) {
+ for (int i = 0; i < inc_dimK_reg_block; i++)
+ vmovups(Zmm(reg_idx + i),
+ zword[reg_srcA + 64 * (offset + i)]);
+ };
+
+ // Used when doing double buffering
+ int next = 0;
+ if (jcp.double_buffering) {
+ load_A(next, 0);
+ }
+ for (int dimK_reg_block = 0;
+ dimK_reg_block < jcp.dimK_reg_block;
+ dimK_reg_block += inc_dimK_reg_block) {
+ int current;
+ /* Loading the next vector from A */
+ current = next;
+ if (jcp.double_buffering) {
+ next = (dimK_reg_block + inc_dimK_reg_block)
+ % (2 * inc_dimK_reg_block);
+ load_A(next, dimK_reg_block + inc_dimK_reg_block);
+ } else {
+ next = 0;
+ load_A(next, dimK_reg_block);
+ }
+ /* Performing the fmas */
+ for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
+ Zmm zmm(jcp.zmm_start + tile);
+ if (jcp.ver != ver_avx512_core)
+ L1_pf.prefetch(
+ dimK_reg_block * jcp.dimN_reg_block + tile);
+ if (jcp.ver == ver_4fma)
+ v4fmaddps(zmm, Zmm(current),
+ EVEX_compress_addr(reg_srcB,
+ 64 * tile + dimK_reg_block * 4));
+ else
+ vfmadd231ps(zmm, Zmm(current),
+ EVEX_compress_addr(reg_srcB,
+ 64 * tile + dimK_reg_block * 4,
+ true));
+ if (jcp.ver != ver_avx512_core)
+ L2_pf.prefetch(
+ dimK_reg_block * jcp.dimN_reg_block + tile);
+ }
+ }
+
+ add(reg_srcA, jcp.dimK_reg_block * 64);
+ add(reg_srcB, jcp.dimN_reg_block * 64);
+ if (jcp.dimK_block > 1) {
+ sub(reg_dimK_block_loop_cnt, 1);
+ jnz(dimK_block_loop);
+ }
+ }
+
+
+ auto store_output = [=](bool output_is_aligned) {
+ for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
+ Zmm zmm(jcp.zmm_start + tile);
+ if (output_is_aligned
+ && jcp.dimK_nb_block == 1
+ && (jcp.dimN * jcp.dimM * alpha * alpha
+ * sizeof(float) > 2 * LLC_data_size))
+ vmovntps(zword[reg_dstC + 64 * tile], zmm);
+ else
+ vmovups(zword[reg_dstC + 64 * tile], zmm);
+ }
+ };
+
+ Label unaligned_store, end_store;
+ test(reg_dstC, cpu_isa_traits<avx512_common>::vlen - 1);
+ jnz(unaligned_store, T_NEAR);
+ store_output(true);
+ jmp(end_store, T_NEAR);
+ L(unaligned_store); {
+ store_output(false);
+ }
+ L(end_store);
+
+ if (jcp.dimM_block > 1) {
+ sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
+ add(reg_dstC, jcp.dimN_reg_block * 64);
+ sub(reg_dimM_block_loop_cnt, 1);
+ jnz(dimM_block_loop);
+ }
+ }
+ };
+
+ /* Preamble */
+ preamble();
+
+ /* kernel */
+ inner_loops();
+
+ /* Postamble */
+ postamble();
+ ret();
+}
+
+status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common(
+ jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d)
+{
+
+ if (mayiuse(avx512_core))
+ return status::unimplemented;
+ else if (!mayiuse(avx512_common))
+ return status::unimplemented;
+ else if (mayiuse(avx512_mic_4ops))
+ jcp.ver = ver_4fma;
+ else
+ jcp.ver = ver_fma;
+
+ jcp.nthr = mkldnn_get_max_threads();
+
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.ih = src_d.dims()[2];
+ jcp.iw = src_d.dims()[3];
+ jcp.oh = dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[3];
+ jcp.kh = weights_d.dims()[with_groups + 2];
+ jcp.kw = weights_d.dims()[with_groups + 3];
+ jcp.t_pad = cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][1];
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+ jcp.dilate_h = cd.dilates[0];
+ jcp.dilate_w = cd.dilates[1];
+ jcp.r_pad = nstl::max(
+ 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
+ jcp.b_pad = nstl::max(
+ 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
+ jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
+ jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
+ jcp.ohp = jcp.oh;
+ jcp.owp = jcp.ow;
+
+ bool ok_to_pad_channels = jcp.ngroups == 1;
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+ }
+
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
+ // Check for configurations not supported by these kernels
+ if (jcp.ngroups != 1)
+ return status::unimplemented;
+ if ((jcp.kh != 3) || (jcp.kw != 3))
+ return status::unimplemented;
+ if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
+ return status::unimplemented;
+ if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
+ return status::unimplemented;
+ if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
+ return status::unimplemented;
+
+ format_tag_t dat_tag = nChw16c;
+ format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+
+ if (jcp.src_tag != dat_tag) return status::unimplemented;
+ if (jcp.wei_tag != wei_tag) return status::unimplemented;
+ if (jcp.dst_tag != dat_tag) return status::unimplemented;
+
+ bool layout_consistency = true
+ && jcp.ic <= src_d.padded_dims()[1]
+ && jcp.oc <= dst_d.padded_dims()[1]
+ && jcp.ic <= weights_d.padded_dims()[with_groups + 1]
+ && jcp.oc <= weights_d.padded_dims()[with_groups + 0];
+ if (!layout_consistency) return status::unimplemented;
+
+ return status::success;
+}
+
+
+status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) {
+
+ auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
+ int dimN_reg_block, int current_best) {
+ return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK)
+ && (dimN_reg_block < jcp.nb_reg)
+ && (dimN_reg_block < current_best);
+ };
+ jcp.dimN_reg_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block);
+
+ if (jcp.dimN_reg_block >= jcp.nb_reg) {
+ auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
+ int dimN_reg_block, int current_best) {
+ return (dimN_reg_block < jcp.nb_reg)
+ && (dimN_reg_block > current_best);
+ };
+
+ jcp.dimN_reg_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimN, 1, test_cond_dimN_reg_block);
+ }
+
+ //********************* Choosing dimK_block **********************//
+ auto test_cond1_dimK_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
+ return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
+ 1, jcp.dimM_simd_block, .75f)
+ && (dimK_block > current_best);
+ };
+
+ auto test_cond1_bis_dimK_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
+ return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
+ jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f)
+ && (dimK_block > current_best);
+ };
+
+ jcp.dimK_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
+ // If we are not able to use streams, we fall back to condition [1]
+ if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
+ jcp.dimK_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
+ jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
+
+ //********************* Choosing dimM_block **********************//
+ jcp.dimM_simd_block = 16;
+ /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/
+ auto test_cond1_dimM_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
+ return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
+ jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f)
+ && (dimM_block > current_best);
+ };
+
+ auto test_cond1_bis_dimM_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
+ return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
+ jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f)
+ && (dimM_block > current_best);
+ };
+
+ if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
+ jcp.dimM_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
+ else
+ jcp.dimM_block = get_divisor_satisfying_cond(jcp,
+ jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block);
+ jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
+
+ //******************* Choosing dimN_block *******************//
+ auto test_cond2_dimN_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
+ return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
+ jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
+ jcp.dimM_simd_block, .5f)
+ && (dimN_block > current_best);
+ };
+
+ jcp.dimN_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
+ jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
+ jcp.sched_policy = WSCHED_DATA_W_S_G_D;
+ return status::success;
+}
+
+status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel(
+ jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
+{
+ jcp.dimK_reg_block = 16;
+ jcp.dimM_simd_block = 16;
+
+ // TODO: replace double buffering with n-tuple buffering to maximize
+ // register usage.
+ // The choice of the number of buffers will then come after choosing
+ // dimN_reg_block.
+ jcp.double_buffering = true;
+ if (jcp.double_buffering)
+ jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2);
+ else
+ jcp.zmm_start = 1;
+ jcp.nb_reg = 32 - jcp.zmm_start;
+
+ jcp.dimN = dimN;
+ jcp.dimK = dimK;
+ jcp.dimM = dimM;
+
+ jcp.sched_policy = WSCHED_INVALID;
+ set_wsched_DATA_W_S_G_D_avx512_common(jcp);
+
+ assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D);
+ return status::success;
+}
+
+bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok(
+ jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+ switch (p.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_relu(0) || is_sum(0); // relu or sum
+ case 2: return (is_sum(0) && is_relu(1)) ||
+ (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
+ case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
+ default: return false;
+ }
+
+ return false;
+}
+
+status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
+ jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
+ status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d);
+
+ if (st != status::success)
+ return st;
+
+ // Winograd specific initialization
+ jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
+ jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
+ jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
+
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+ jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
+
+ status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
+ jcp.ic_simd_block = jcp.dimK_reg_block;
+ jcp.ic_block = jcp.dimK_block;
+ jcp.nb_ic = jcp.dimK_nb_block;
+ jcp.oc_simd_block = jcp.dimM_simd_block;
+ jcp.oc_block = jcp.dimM_block;
+ jcp.nb_oc = jcp.dimM_nb_block;
+ jcp.tile_block_ur = jcp.dimN_reg_block;
+ jcp.nb_tile_block_ur = jcp.dimN_block;
+ jcp.tile_block = jcp.dimN_nb_block;
+ jcp.tile_4fma_padding = 0; // only relevant for backward weights
+
+ return res;
+}
+
+status_t jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf(
+ jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d)
+{
+ status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
+
+ if (st != status::success)
+ return st;
+
+ jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
+ jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
+ jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
+
+ status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
+ jcp.oc_simd_block = jcp.dimK_reg_block;
+ jcp.oc_block = jcp.dimK_block;
+ jcp.nb_oc = jcp.dimK_nb_block;
+ jcp.ic_simd_block = jcp.dimM_simd_block;
+ jcp.ic_block = jcp.dimM_block;
+ jcp.nb_ic = jcp.dimM_nb_block;
+ jcp.tile_block_ur = jcp.dimN_reg_block;
+ jcp.nb_tile_block_ur = jcp.dimN_block;
+ jcp.tile_block = jcp.dimN_nb_block;
+ jcp.tile_4fma_padding = 0; // only relevant for backward weights
+
+ return res;
+}
+
+void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate()
+{
+ auto load_B = [=](int reg_idx, int offset) {
+ for (int i = 0; i < 4; i++) {
+ vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]);
+ }
+ };
+
+ preamble();
+ int curr = 0;
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ int origB_offset = (j * alpha + i) * jcp.dimK_4fma;
+ size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block *
+ jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block *
+ jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float);
+ mov(reg_transB_idx, transB_offset);
+ for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) {
+ /*double buffering to hide load latencies*/
+ int next = (curr + 4) % 8;
+ if (i == 0 && tb == 0) {
+ load_B(0, origB_offset);
+ }
+ if (tb + 4 < (jcp.dimK_4fma -1)) {
+ load_B(next, origB_offset + 4);
+ } else if (i < alpha - 1) {
+ load_B(next, origB_offset + jcp.dimK_4fma);
+ }
+
+ vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1));
+ vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3));
+ vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1));
+ vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3));
+
+ vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9));
+ vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9));
+
+ vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1));
+ vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1));
+
+ vmovntps(zword[reg_transB + reg_transB_idx
+ + sizeof(float) * tb * jcp.dimN_reg_block],
+ Zmm(curr+2));
+ vmovntps(zword[reg_transB + reg_transB_idx
+ + sizeof(float) * (tb + 1) * jcp.dimN_reg_block],
+ Zmm(curr+3));
+ vmovntps(zword[reg_transB + reg_transB_idx
+ + sizeof(float) * (tb + 2) * jcp.dimN_reg_block],
+ Zmm(8));
+ vmovntps(zword[reg_transB + reg_transB_idx
+ + sizeof(float) * (tb + 3) * jcp.dimN_reg_block],
+ Zmm(9));
+ curr = next;
+
+ }
+ }
+ }
+ postamble();
+ ret();
+}
+void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate(
+ bool is_first_tile)
+{
+ // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++)
+ // for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++)
+ // for (int nb_tile_block_ur = 0; nb_tile_block_ur <
+ // jcp.nb_tile_block_ur; nb_tile_block_ur++)
+ // for (int tile_block_ur = 0; tile_block_ur <
+ // jcp.tile_block_ur; tile_block_ur++)
+ // for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3)
+ // U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] +=
+ // M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block]
+ // *
+ // broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur])
+ auto inner_loops = [=]() {
+ int inc_fma = jcp.ver == ver_4fma ? 4 : 1;
+ const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;
+ prefetcher_t<float> L1_pf(this, reg_srcB, L1,
+ jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
+ jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
+ / inc_fma,
+ fma_ipc);
+ prefetcher_t<float> L2_pf(this, reg_srcB, L2,
+ jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
+ jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
+ / inc_fma,
+ fma_ipc);
+
+ auto load_A = [=](int reg_idx, int offset) {
+ for (int i = 0; i < inc_fma; i++) {
+ vmovups(Zmm(reg_idx + i),
+ zword[reg_srcA +
+ sizeof(float) * jcp.dimM_simd_block * (offset + i)]);
+ }
+ };
+
+ Label dimM_block_loop, dimK_block_loop, dimN_block_loop;
+ if (jcp.dimM_block > 1) {
+ mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
+ L(dimM_block_loop);
+ }
+ { /************* OC_block (M) loop ***********/
+ if (jcp.dimN_block > 1) {
+ mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
+ L(dimN_block_loop);
+ }
+ { /*************** IC_block (N) loop *********/
+ for (int dimN_reg_block = 0;
+ dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
+ Zmm zmm(jcp.zmm_start + dimN_reg_block);
+ if (is_first_tile)
+ vpxord(zmm, zmm, zmm);
+ else
+ vmovups(zmm, zword[reg_dstC +
+ dimN_reg_block * jcp.dimM_simd_block *
+ sizeof(float)]);
+ }
+
+ if (jcp.dimK_block > 1) {
+ mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
+ L(dimK_block_loop);
+ }
+ { /************* nb_tile_ur(K) loop ********/
+ int next = 0;
+ if (jcp.double_buffering) {
+ load_A(next, 0);
+ }
+ for (int dimK_reg_block = 0;
+ dimK_reg_block < jcp.dimK_reg_block;
+ dimK_reg_block++) {
+ int srcB_offset = dimK_reg_block * jcp.dimK_4fma
+ * jcp.dimN_reg_block;
+ for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma;
+ dimK_4fma += inc_fma) {
+ int current = next;
+ if (jcp.double_buffering) {
+ next = (dimK_reg_block * jcp.dimK_4fma
+ + dimK_4fma + inc_fma)
+ % (2 * inc_fma);
+ load_A(next, dimK_reg_block * jcp.dimK_4fma
+ + dimK_4fma + inc_fma);
+ } else {
+ next = 0;
+ load_A(next, dimK_reg_block * jcp.dimK_4fma
+ + dimK_4fma);
+ }
+ for (int dimN_reg_block = 0;
+ dimN_reg_block < jcp.dimN_reg_block;
+ ++dimN_reg_block) {
+ L1_pf.prefetch(srcB_offset / inc_fma
+ + dimK_4fma / inc_fma
+ * jcp.dimN_reg_block
+ + dimN_reg_block);
+ L2_pf.prefetch(srcB_offset / inc_fma
+ + dimK_4fma / inc_fma
+ * jcp.dimN_reg_block
+ + dimN_reg_block);
+ if (jcp.ver == ver_4fma) {
+ int srcB_trans_offset = (dimK_4fma / 4) * 64
+ + dimK_4fma % 4;
+ v4fmaddps(
+ Zmm(jcp.zmm_start + dimN_reg_block),
+ Zmm(current),
+ EVEX_compress_addr(reg_srcB,
+ sizeof(float) * (
+ srcB_offset +
+ srcB_trans_offset +
+ (dimN_reg_block % 4) * 16 +
+ (dimN_reg_block / 4) * 4)));
+ } else {
+ vfmadd231ps(
+ Zmm(jcp.zmm_start + dimN_reg_block),
+ Zmm(current),
+ EVEX_compress_addr(reg_srcB,
+ sizeof(float) * (srcB_offset + dimN_reg_block),
+ true));
+ }
+ }
+ }
+ }
+ }
+
+ add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma
+ * jcp.dimM_simd_block * sizeof(float));
+ add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block
+ * jcp.dimK_4fma * sizeof(float));
+ if (jcp.dimK_block > 1) {
+ sub(reg_dimK_block_loop_cnt, 1);
+ jnz(dimK_block_loop);
+ }
+
+ /******** Write C back to memory *******/
+ for (int dimN_reg_block = 0;
+ dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
+ Zmm zmm(jcp.zmm_start + dimN_reg_block);
+ vmovups(zword[reg_dstC +
+ dimN_reg_block * jcp.dimM_simd_block * sizeof(float)],
+ zmm);
+ }
+
+ sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block *
+ jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
+ add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block
+ * sizeof(float));
+ if (jcp.dimN_block > 1) {
+ sub(reg_dimN_block_loop_cnt, 1);
+ jnz(dimN_block_loop);
+ }
+ }
+
+ if (jcp.dimM_block > 1) {
+ sub(reg_srcB, jcp.dimN_block * jcp.dimK_block
+ * jcp.dimK_reg_block * jcp.dimN_reg_block
+ * jcp.dimK_4fma * sizeof(float));
+ add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
+ * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
+ sub(reg_dimM_block_loop_cnt, 1);
+ jnz(dimM_block_loop);
+ }
+ }
+ };
+
+ /* Preamble */
+ // register used to handle long fma encoding
+ preamble();
+ mov(reg_srcA, reg_srcA_const);
+ inner_loops();
+
+ /* Postamble */
+ postamble();
+ ret();
+}
+
+namespace {
+bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block,
+ int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
+{
+ float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw;
+ lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw;
+ lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
+ lhs *= sizeof(float);
+ float rhs = C * L1_cache_size;
+ return (lhs <= rhs);
+}
+
+bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
+ int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
+{
+ float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma
+ * dimM_simdw;
+ lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
+ lhs *= sizeof(float);
+ float rhs = C * L1_cache_size;
+ return (lhs <= rhs);
+}
+
+bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
+ int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
+ float C)
+{
+ float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block
+ * dimK_4fma;
+ lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
+ * dimN_reg_block;
+ lhs *= sizeof(float);
+ float rhs = C * L2_cache_size;
+ return (lhs <= rhs);
+}
+
+bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block,
+ int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
+ float C)
+{
+ float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block;
+ lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma;
+ lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
+ * dimN_reg_block;
+ lhs *= sizeof(float);
+ float rhs = C * L2_cache_size;
+ return (lhs <= rhs);
+}
+} // namespace
+
+status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
+{
+ /*************** Choose dimN_reg_block (ic_simd_block)
+ * *******************************/
+ jcp.dimN = jcp.ic;
+ /* Hardcoded to 16 because N = ic for bwd weights and the innermost
+ dimension for ic is assumed to be 16 in the src transforms. This choice
+ covers load latencies while keeping the kernel simple for POR topologies.
+ FIXME: this will not work for future topologies where ic % 16 != 0. */
+ jcp.dimN_reg_block = jcp.ic_simd_block;
+
+ /****************************** Choose dimK_block
+ * **************************/
+ // No freedom for choosing dimM_simd_block because ic_simd_block
+ // is determined by input data format
+ jcp.dimM_simd_block = jcp.oc_simd_block;
+
+ auto test_cond1bis_dimK_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
+ return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
+ jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
+ && (dimK_block > current_best);
+ };
+
+ auto test_cond1_dimK_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
+ return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1,
+ jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
+ && (dimK_block > current_best);
+ };
+
+ auto test_cond2bis_dimK_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
+ return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
+ jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f)
+ && (dimK_block > current_best);
+ };
+
+ auto test_cond2_dimK_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
+ return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1,
+ jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f)
+ && (dimK_block > current_best);
+ };
+
+ jcp.dimK_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block);
+ if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma)
+ jcp.dimK_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block);
+
+ jcp.dimK_reg_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block);
+ if (jcp.dimK_reg_block < jcp.dimK_block) {
+ jcp.dimK_reg_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimK_block, 1, test_cond1_dimK_block);
+ }
+ jcp.dimK_block /= jcp.dimK_reg_block;
+ jcp.dimK_nb_block
+ = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block;
+ jcp.tile_block_ur = jcp.dimK_reg_block;
+ jcp.nb_tile_block_ur = jcp.dimK_block;
+ jcp.tile_block = jcp.dimK_nb_block;
+
+ /***************************** Choose dimN block
+ * ****************************/
+ auto test_cond2_dimN_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
+ return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block,
+ jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block,
+ jcp.dimN_reg_block, 0.5f)
+ && (dimN_block > current_best);
+ };
+
+ jcp.dimN_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
+ jcp.ic_block = jcp.dimN_block;
+ jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block;
+ jcp.nb_ic = jcp.dimN_nb_block;
+
+ /********************************* Choose dimM block
+ * ************************/
+ jcp.dimM = jcp.oc;
+
+ auto test_cond1_dimM_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
+ return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1,
+ jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block,
+ 1.0f)
+ && (dimM_block > current_best)
+ && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2;
+ };
+
+ jcp.dimM_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
+ jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
+
+ jcp.sched_policy = WSCHED_WEI_S_D_G_W;
+ return status::success;
+}
+
+status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
+ jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
+ const memory_desc_wrapper &diff_weights_d)
+{
+ jcp.nthr = mkldnn_get_max_threads();
+
+ const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
+
+ jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+ jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.ih = src_d.dims()[2];
+ jcp.iw = src_d.dims()[3];
+ jcp.oh = diff_dst_d.dims()[2];
+ jcp.ow = diff_dst_d.dims()[3];
+ jcp.kh = diff_weights_d.dims()[with_groups + 2];
+ jcp.kw = diff_weights_d.dims()[with_groups + 3];
+ jcp.t_pad = cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][1];
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+ jcp.r_pad = nstl::max(
+ 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
+ jcp.b_pad = nstl::max(
+ 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
+ jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
+ jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
+ jcp.ohp = jcp.oh;
+ jcp.owp = jcp.ow;
+ jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef);
+ jcp.dilate_h = cd.dilates[0];
+ jcp.dilate_w = cd.dilates[1];
+
+ bool ok_to_pad_channels = jcp.ngroups == 1;
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+ }
+
+ if (mayiuse(avx512_core))
+ return status::unimplemented;
+ if (!mayiuse(avx512_common))
+ return status::unimplemented;
+ else if (mayiuse(avx512_mic_4ops))
+ jcp.ver = ver_4fma;
+ else
+ jcp.ver = ver_fma;
+
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+ // Winograd specific initialization
+ jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
+ jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
+ jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
+
+ // Winograd kernel works only for 3x3 convolution with stride 1
+ if (jcp.ngroups != 1)
+ return status::unimplemented;
+ if ((jcp.kh != 3) || (jcp.kw != 3))
+ return status::unimplemented;
+ if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
+ return status::unimplemented;
+ if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
+ return status::unimplemented;
+ if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
+ return status::unimplemented;
+
+ format_tag_t dat_tag = nChw16c;
+ format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
+
+ if (jcp.src_tag != dat_tag) return status::unimplemented;
+ if (jcp.wei_tag != wei_tag) return status::unimplemented;
+ if (jcp.dst_tag != dat_tag) return status::unimplemented;
+
+ bool layout_consistency = true
+ && jcp.ic <= src_d.padded_dims()[1]
+ && jcp.oc <= diff_dst_d.padded_dims()[1]
+ && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
+ && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
+ if (!layout_consistency) return status::unimplemented;
+
+ /*************************** New Kernel Parameters
+ * *****************************/
+ jcp.ic_simd_block = simd_w;
+ jcp.oc_simd_block = simd_w;
+ jcp.dimK_4fma = 1;
+ jcp.tile_4fma_padding = 0;
+
+#define MAX_4FMA_UR 8
+ if (jcp.ver == ver_4fma) {
+ auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma,
+ int current_best) {
+ return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR)
+ && (dimK_4fma > current_best);
+ };
+ jcp.dimK_4fma = get_divisor_satisfying_cond(
+ jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma);
+ if (jcp.dimK_4fma == 1)
+ jcp.dimK_4fma = 4;
+ if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0)
+ jcp.tile_4fma_padding = jcp.dimK_4fma
+ - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma);
+ }
+
+ jcp.tile_4fma = jcp.dimK_4fma;
+ /* NOTE: when (itiles * jtiles) % dimK_4fma != 0, the transpose in the
+ * diff_src transform does not work correctly; this is solved by applying
+ * padding. */
+ jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
+ jcp.dimN = jcp.ic;
+ jcp.dimM = jcp.oc;
+
+ jcp.double_buffering = true;
+ if (jcp.double_buffering)
+ jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2;
+ else
+ jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1;
+ jcp.nb_reg = 32 - jcp.zmm_start;
+
+ jcp.sched_policy = WSCHED_INVALID;
+ status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp);
+ assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W);
+
+ jcp.tile_block_ur = jcp.dimK_reg_block;
+ jcp.nb_tile_block_ur = jcp.dimK_block;
+ jcp.tile_block = jcp.dimK_nb_block;
+
+ jcp.ic_block = jcp.dimN_block;
+ jcp.nb_ic = jcp.dimN_nb_block;
+
+ jcp.oc_block = jcp.dimM_block;
+ jcp.nb_oc = jcp.dimM_nb_block;
+
+ return res;
+
+}
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp
new file mode 100644
index 0000000000..6c117143f5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.hpp
@@ -0,0 +1,179 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
+#define JIT_AVX512_COMMON_CONV_WINOGRAD_KERNEL_F32_HPP
+
+#include "c_types_map.hpp"
+#include "cpu_memory.hpp"
+
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+//alpha determines the output tile_size
+constexpr int alpha = 6;
+constexpr int tile_size = 4;
+//simd length used for vectorization
+constexpr int simd_w = 16;
+
+struct _jit_avx512_common_conv_winograd_data_kernel_f32 : public jit_generator {
+ _jit_avx512_common_conv_winograd_data_kernel_f32(
+ jit_conv_winograd_conf_t ajcp)
+ : jcp(ajcp)
+ {
+ //******************* First iter kernel ********************//
+ this->gemm_loop_generate(true);
+ gemm_loop_ker_first_iter
+ = (decltype(gemm_loop_ker_first_iter)) this->getCode();
+
+ //************** Subsequent iterations kernel **************//
+ if (jcp.dimK_nb_block > 1) {
+ align();
+ const Xbyak::uint8 *addr = getCurr();
+ this->gemm_loop_generate(false);
+ gemm_loop_ker = (decltype(gemm_loop_ker))addr;
+ }
+ }
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_data_kernel_f32)
+
+ static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d);
+
+ static status_t init_conf_kernel(
+ jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
+
+ jit_conv_winograd_conf_t jcp;
+ void (*gemm_loop_ker)(float *, const float *, const float *);
+ void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
+
+protected:
+ using reg64_t = const Xbyak::Reg64;
+ enum { typesize = sizeof(float) };
+
+ void gemm_loop_generate(bool is_beta_zero);
+
+ /* registers used for GEMM */
+ reg64_t reg_dstC = abi_param1;
+ reg64_t reg_srcA = abi_param2;
+ reg64_t reg_srcB = abi_param3;
+
+ reg64_t reg_dimM_block_loop_cnt = r10;
+ reg64_t reg_dimK_block_loop_cnt = r11;
+};
+
+struct jit_avx512_common_conv_winograd_fwd_kernel_f32
+ : _jit_avx512_common_conv_winograd_data_kernel_f32 {
+ using _jit_avx512_common_conv_winograd_data_kernel_f32::
+ _jit_avx512_common_conv_winograd_data_kernel_f32;
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_conv_winograd_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
+};
+
+struct jit_avx512_common_conv_winograd_bwd_data_kernel_f32
+ : public _jit_avx512_common_conv_winograd_data_kernel_f32 {
+ using _jit_avx512_common_conv_winograd_data_kernel_f32::
+ _jit_avx512_common_conv_winograd_data_kernel_f32;
+
+ static status_t init_conf(jit_conv_winograd_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d);
+};
+
+struct jit_avx512_common_conv_winograd_bwd_weights_kernel_f32
+ : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_common_conv_winograd_bwd_weights_kernel_f32)
+
+ jit_avx512_common_conv_winograd_bwd_weights_kernel_f32(
+ jit_conv_winograd_conf_t ajcp)
+ : jcp(ajcp)
+ {
+
+ //******************* First iter kernel ********************//
+ {
+ align();
+ const Xbyak::uint8 *addr = getCurr();
+ this->gemm_loop_generate(true);
+ gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))addr;
+ }
+
+ if (jcp.tile_block > 1) {
+ align();
+ const Xbyak::uint8 *addr = getCurr();
+ this->gemm_loop_generate(false);
+ gemm_loop_ker = (decltype(gemm_loop_ker))addr;
+ }
+
+ if (jcp.ver == ver_4fma) {
+ align();
+ const Xbyak::uint8 *addr = getCurr();
+ this->transpose_ker_generate();
+ transpose_4fma_ker = (decltype(transpose_4fma_ker))addr;
+ }
+ }
+
+ static status_t init_conf(jit_conv_winograd_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &diff_dst_d,
+ const memory_desc_wrapper &diff_weights_d);
+
+ jit_conv_winograd_conf_t jcp;
+ void (*gemm_loop_ker)(float *, const float *, const float *);
+ void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
+ void (*transpose_4fma_ker)(float *, float *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ enum { typesize = sizeof(float) };
+
+ void gemm_loop_generate(bool is_first_tile);
+ void transpose_ker_generate();
+
+ reg64_t reg_origB = abi_param2;
+ reg64_t reg_transB = abi_param1;
+
+ reg64_t reg_dstC = abi_param1;
+ reg64_t reg_srcA_const = abi_param2;
+ reg64_t reg_srcB = abi_param3;
+
+ reg64_t reg_sp = rsp;
+ reg64_t reg_srcA = r9;
+ reg64_t reg_nb_ic = r10;
+ reg64_t reg_loop_cpt = r11;
+ reg64_t reg_transB_idx = r13;
+
+ /* Registers used by new kernel */
+ reg64_t reg_dimM_block_loop_cnt = r10;
+ reg64_t reg_dimK_block_loop_cnt = r12;
+ reg64_t reg_dimN_block_loop_cnt = r11;
+};
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp
new file mode 100644
index 0000000000..abddc19221
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp
@@ -0,0 +1,1526 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_avx512_common_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+using namespace nstl;
+
+using jit_conv_ker_t = void (*)(jit_conv_call_s *);
+
+#define PIPELINE(field) \
+ do { \
+ p.field = p.field ## _prf; \
+ p.field ## _prf = field; \
+ } while (0)
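+
+// PIPELINE implements a one-step software pipeline: on every call the
+// previously staged "*_prf" (prefetch) value becomes the current one and the
+// new argument is staged for the next call. The JIT kernel is only invoked
+// once a full set of current pointers is available (p.src != nullptr), so it
+// always sees both the data to compute on and the addresses to prefetch for
+// the following iteration.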
+
+inline void jit_conv_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
+ const void *src, const void *dst, const void *filt, const void *bias,
+ int channel, int kh_padding)
+{
+ PIPELINE(src);
+ PIPELINE(dst);
+ PIPELINE(filt);
+ PIPELINE(bias);
+ PIPELINE(channel);
+ PIPELINE(kh_padding);
+
+ if (p.src)
+ ker(&p);
+}
+// The special case for the driver with ow-parallelization (FWD)
+// TODO: implement it for BWD_D and BWD_W too
+inline void jit_conv_ker_pipeline_ow_thr(jit_conv_ker_t ker, jit_conv_call_s &p,
+ const void *src, const void *dst, const void *filt, const void *bias,
+ int channel, int kh_padding, int owb)
+{
+ PIPELINE(src);
+ PIPELINE(dst);
+ PIPELINE(filt);
+ PIPELINE(bias);
+ PIPELINE(channel);
+ PIPELINE(kh_padding);
+ PIPELINE(owb);
+
+ if (p.src)
+ ker(&p);
+}
+
+inline void jit_conv_3d_ker_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
+ const void *src, const void *dst, const void *filt, const void *bias,
+ int channel, int kh_padding, int kd_padding)
+{
+ PIPELINE(src);
+ PIPELINE(dst);
+ PIPELINE(filt);
+ PIPELINE(bias);
+ PIPELINE(channel);
+ PIPELINE(kh_padding);
+ PIPELINE(kd_padding);
+
+ if (p.src)
+ ker(&p);
+}
+// The special case for the driver with ow-parallelization (FWD)
+// TODO: implement it for BWD_D and BWD_W too
+inline void jit_conv_3d_ker_pipeline_ow_thr(jit_conv_ker_t ker,
+ jit_conv_call_s &p, const void *src, const void *dst, const void *filt,
+ const void *bias, int channel, int kh_padding, int kd_padding, int owb)
+{
+ PIPELINE(src);
+ PIPELINE(dst);
+ PIPELINE(filt);
+ PIPELINE(bias);
+ PIPELINE(channel);
+ PIPELINE(kh_padding);
+ PIPELINE(kd_padding);
+ PIPELINE(owb);
+
+ if (p.src)
+ ker(&p);
+}
+
+void jit_conv_3d_ker_bwd_w_pipeline(jit_conv_ker_t ker, jit_conv_call_s &p,
+ const void *src, const void *dst, const void *filt, const void *bias,
+ int channel, int d_index, int d_worksize,
+ int kd_padding /* kd_work_size */, size_t kd_offset) {
+ PIPELINE(src);
+ PIPELINE(dst);
+ PIPELINE(filt);
+ PIPELINE(bias);
+ PIPELINE(channel);
+ PIPELINE(kd_padding);
+ PIPELINE(d_worksize);
+ PIPELINE(d_index);
+ PIPELINE(kd_offset);
+
+ if (p.src)
+ ker(&p);
+}
+#define wht_blk_off(d, g, ...) \
+ (pd()->with_groups() \
+ ? (d).blk_off((g), __VA_ARGS__) \
+ : (d).blk_off(__VA_ARGS__))
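+
+// wht_blk_off computes a weights offset, prepending the group index only when
+// the weights tensor actually carries a groups dimension.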
+
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type>
+void jit_avx512_common_convolution_fwd_t<src_type, wei_type,
+ dst_type>::prepare_padded_bias(const dst_data_t *&bias,
+ const memory_tracking::grantor_t &scratchpad) const {
+ if (!pd()->wants_padded_bias()) return;
+
+ auto padded_bias = scratchpad.template get<dst_data_t>(
+ key_conv_padded_bias);
+ utils::array_copy(padded_bias, bias, pd()->jcp_.oc_without_padding);
+ utils::array_set(padded_bias + pd()->jcp_.oc_without_padding,
+ (dst_data_t)0, pd()->jcp_.oc - pd()->jcp_.oc_without_padding);
+ bias = padded_bias;
+}
+
+template <data_type_t src_type, data_type_t wei_type,
+ data_type_t dst_type>
+void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
+execute_forward_1d(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ prepare_padded_bias(bias, this->scratchpad(ctx));
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = pd()->jcp_;
+ assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
+
+ int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
+ int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.nb_ow;
+
+ int nthr;
+ if (jcp.aligned_threads)
+ nthr = jcp.aligned_threads;
+ else
+ nthr = mkldnn_get_max_threads();
+
+ parallel(nthr, [&](const int ithr, const int nthr) {
+ int start{0}, end{0}, start_copy;
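+        // balance211 splits the flattened work_amount as evenly as possible
+        // across nthr threads; this thread is given the half-open range
+        // [start, end).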
+ balance211(work_amount, nthr, ithr, start, end);
+ start_copy = start;
+
+ auto par_conv = jit_conv_call_s();
+ size_t src_c_stride = src_d.blk_off(0, 1);
+ size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
+
+ for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
+ start = start_copy;
+ int n{0}, g{0}, occ{0}, owb{0};
+
+ if (jcp.loop_order == loop_cwgn) {
+ int dummy{0};
+ nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
+ g, jcp.ngroups, n, jcp.mb, dummy, 1);
+ } else if (jcp.loop_order == loop_gncw) {
+ int dummy{0};
+ nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, occ,
+ oc_chunks, owb, jcp.nb_ow, dummy, 1);
+ } else {
+ assert(!"unsupported loop order");
+ }
+
+ while (start < end) {
+ int ocb = occ * jcp.nb_oc_blocking;
+ int g_ocb = g * jcp.nb_oc + ocb;
+ int g_oc = g_ocb * jcp.oc_block;
+ int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
+
+ int ow_s = owb * jcp.ow_block;
+ int iw_s = ow_s * jcp.stride_w;
+ auto bias_w = bias ? bias + g_oc : nullptr;
+ auto dst_w = dst + dst_d.blk_off(n, g_ocb, ow_s);
+ auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, iw_s);
+ auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
+
+ for (int icb = icb_l2;
+ icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
+ jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
+ src_w, dst_w, wht_w, bias_w, icb, 1, owb);
+
+ src_w += src_c_stride;
+ wht_w += wht_ic_stride;
+ }
+ if (jcp.loop_order == loop_cwgn) {
+ int dummy{0};
+ nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
+ g, jcp.ngroups, n, jcp.mb, dummy, 1);
+ } else if (jcp.loop_order == loop_gncw) {
+ int dummy{0};
+ nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb,
+ occ, oc_chunks, owb, jcp.nb_ow, dummy, 1);
+ } else {
+ assert(!"unsupported loop order");
+ }
+ }
+ }
+ jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
+ src, dst, weights, bias, 0, 0, 0);
+ });
+}
+
+template <data_type_t src_type, data_type_t wei_type,
+ data_type_t dst_type>
+void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
+execute_forward_2d(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ prepare_padded_bias(bias, this->scratchpad(ctx));
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = pd()->jcp_;
+ assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
+
+ int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
+ int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.oh * jcp.nb_ow;
+
+ int nthr;
+ if (jcp.aligned_threads)
+ nthr = jcp.aligned_threads;
+ else
+ nthr = mkldnn_get_max_threads();
+
+ parallel(nthr, [&](const int ithr, const int nthr) {
+ int start{0}, end{0}, start_copy;
+ balance211(work_amount, nthr, ithr, start, end);
+ start_copy = start;
+
+ auto par_conv = jit_conv_call_s();
+ size_t src_h_stride = src_d.blk_off(0, 0, 1);
+ size_t src_c_stride = src_d.blk_off(0, 1);
+ size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
+ size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
+ size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
+
+ for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
+ start = start_copy;
+ int n{0}, g{0}, occ{0}, oh_s{0}, owb{0};
+
+ if (jcp.loop_order == loop_cwgn)
+ nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow,
+ g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh);
+ else if (jcp.loop_order == loop_gncw)
+ nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb,
+ occ, oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
+ else
+ assert(!"unsupported loop order");
+
+ while (start < end) {
+ int ocb = occ * jcp.nb_oc_blocking;
+ int g_ocb = g * jcp.nb_oc + ocb;
+ int g_oc = g_ocb * jcp.oc_block;
+ int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
+
+ int work_rem = end - start;
+
+ int ow_s = owb * jcp.ow_block;
+ int iw_s = ow_s * jcp.stride_w;
+ int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
+ auto bias_w = bias ? bias + g_oc : nullptr;
+
+ for (int oh_b = oh_s; oh_b < oh_e; oh_b += jcp.h_blocking) {
+ int ih_b = -jcp.t_pad + oh_b * jcp.stride_h;
+
+ auto dst_w = dst + dst_d.blk_off(n, g_ocb, oh_b, ow_s);
+ auto src_w
+ = src + src_d.blk_off(n, g_icb + icb_l2, ih_b, iw_s);
+ auto wht_w
+ = weights + wht_blk_off(weights_d, g, ocb, icb_l2);
+
+ for (int icb = icb_l2;
+ icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2);
+ ++icb) {
+ auto src_c = src_w;
+ auto dst_c = dst_w;
+ for (int oj = oh_b, ij = ih_b;
+ oj < min(oh_e, oh_b + jcp.h_blocking);
+ ++oj, ij += jcp.stride_h) {
+ int dilate_h = jcp.dilate_h + 1;
+ int i_t_overflow = div_up(max(0, -ij), dilate_h);
+ int i_b_overflow = div_up(max(0, ij - jcp.ih
+ + (jcp.kh - 1) * dilate_h + 1), dilate_h);
+ int kh_padding = nstl::max(
+ 0, jcp.kh - i_t_overflow - i_b_overflow);
+
+ auto aux_src = src_c
+ + i_t_overflow * dilate_h * src_h_stride;
+ auto aux_wht = wht_w + i_t_overflow * wht_h_stride;
+
+ jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker,
+ par_conv, aux_src, dst_c, aux_wht, bias_w, icb,
+ kh_padding, owb);
+
+ src_c += src_h_stride * jcp.stride_h;
+ dst_c += dst_h_stride;
+ }
+ src_w += src_c_stride;
+ wht_w += wht_ic_stride;
+ }
+ }
+
+ if (jcp.loop_order == loop_cwgn)
+ nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow,
+ g, jcp.ngroups, n, jcp.mb, oh_s, jcp.oh);
+ else if (jcp.loop_order == loop_gncw)
+ nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, occ,
+ oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
+ else
+ assert(!"unsupported loop order");
+ }
+ }
+
+ jit_conv_ker_pipeline_ow_thr(kernel_->jit_ker, par_conv,
+ src, dst, weights, bias, 0, 0, 0);
+ });
+}
+
+template <data_type_t src_type, data_type_t wei_type,
+ data_type_t dst_type>
+void jit_avx512_common_convolution_fwd_t<src_type, wei_type, dst_type>::
+execute_forward_3d(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const dst_data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ prepare_padded_bias(bias, this->scratchpad(ctx));
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const auto &jcp = pd()->jcp_;
+ assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
+ int start{0}, end{0}, start_copy;
+ int work_amount = jcp.mb * jcp.ngroups * oc_chunks * jcp.od * jcp.oh
+ * jcp.nb_ow;
+ balance211(work_amount, nthr, ithr, start, end);
+ start_copy = start;
+
+ auto par_conv = jit_conv_call_s();
+ size_t src_d_stride = src_d.blk_off(0, 0, 1);
+ size_t src_h_stride = src_d.blk_off(0, 0, 0, 1);
+ size_t src_c_stride = src_d.blk_off(0, 1);
+ size_t dst_h_stride = dst_d.blk_off(0, 0, 0, 1);
+ size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
+ size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
+ size_t wht_ic_stride = wht_blk_off(weights_d, 0, 0, 1);
+
+ for (int icb_l2 = 0 ; icb_l2 < jcp.nb_ic; icb_l2 += jcp.nb_ic_L2) {
+ start = start_copy;
+ int n{0}, g{0}, occ{0}, oh_s{0}, od_s{0}, owb{0};
+
+ if (jcp.loop_order == loop_cwgn)
+ nd_iterator_init(start,
+ occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb,
+ od_s, jcp.od, oh_s, jcp.oh);
+ else if (jcp.loop_order == loop_gncw)
+ nd_iterator_init(start,
+ g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow,
+ od_s, jcp.od, oh_s, jcp.oh);
+ else
+ assert(!"unsupported loop order");
+
+ while (start < end) {
+ int ocb = occ * jcp.nb_oc_blocking;
+ int g_ocb = g * jcp.nb_oc + ocb;
+ int g_oc = g_ocb * jcp.oc_block;
+ int g_icb = g * jcp.nb_ic * jcp.nonblk_group_off;
+
+ int work_rem = end - start;
+ int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
+ int ow_s = owb * jcp.ow_block;
+ int iw_s = ow_s * jcp.stride_w;
+ int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
+
+ int id_s = -jcp.f_pad + od_s * jcp.stride_d;
+
+ int dilate_d = jcp.dilate_d + 1;
+ int d_t_overflow = div_up(max(0, -id_s), dilate_d);
+ int d_b_overflow = div_up(
+ max(0, id_s - jcp.id + (jcp.kd - 1) * dilate_d + 1),
+ dilate_d);
+ int kd_padding = nstl::max(0,
+ jcp.kd - d_t_overflow - d_b_overflow);
+
+                auto bias_w = bias ? bias + bias_d.blk_off(g_oc) : nullptr;
+ auto dst_w = dst + dst_d.blk_off(n, g_ocb, od_s, oh_s, ow_s);
+ auto src_w = src + src_d.blk_off(n, g_icb + icb_l2, id_s, ih_s,
+ iw_s) + d_t_overflow * dilate_d * src_d_stride;
+ auto wht_w = weights + wht_blk_off(weights_d, g, ocb, icb_l2)
+ + d_t_overflow * wht_d_stride;
+
+ for (int icb = icb_l2;
+ icb < min(jcp.nb_ic, icb_l2 + jcp.nb_ic_L2); ++icb) {
+ auto src_c = src_w;
+ auto dst_c = dst_w;
+ for (int oj = oh_s, ij = ih_s;
+ oj < oh_e; ++oj, ij += jcp.stride_h)
+ {
+ int dilate_h = jcp.dilate_h + 1;
+ int i_t_overflow = div_up(max(0, -ij), dilate_h);
+ int i_b_overflow = div_up(
+ max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h
+ + 1),
+ dilate_h);
+ int kh_padding = nstl::max(0,
+ jcp.kh - i_t_overflow - i_b_overflow);
+ jit_conv_3d_ker_pipeline_ow_thr(kernel_->jit_ker,
+ par_conv,
+ src_c + i_t_overflow * dilate_h * src_h_stride,
+ dst_c, wht_w + i_t_overflow * wht_h_stride,
+ bias_w, icb, kh_padding, kd_padding, owb);
+
+ src_c += src_h_stride * jcp.stride_h;
+ dst_c += dst_h_stride;
+ }
+ src_w += src_c_stride;
+ wht_w += wht_ic_stride;
+ }
+
+ if (jcp.loop_order == loop_cwgn)
+ nd_iterator_jump(start, end,
+ occ, oc_chunks, owb, jcp.nb_ow, g, jcp.ngroups, n, jcp.mb,
+ od_s, jcp.od, oh_s, jcp.oh);
+ else if (jcp.loop_order == loop_gncw)
+ nd_iterator_jump(start, end,
+ g, jcp.ngroups, n, jcp.mb, occ, oc_chunks, owb, jcp.nb_ow,
+ od_s, jcp.od, oh_s, jcp.oh);
+ else
+ assert(!"unsupported loop order");
+ }
+ }
+ jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
+ src, dst, weights, bias, 0, 0, 0);
+ });
+}
+
+template struct jit_avx512_common_convolution_fwd_t<data_type::f32>;
+
+template <data_type_t diff_dst_type, data_type_t wei_type,
+ data_type_t diff_src_type>
+void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
+ diff_src_type>::execute_backward_data_1d(const exec_ctx_t &ctx) const
+{
+ auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ int start{0}, end{0}, start_copy;
+ int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
+ int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
+ balance211(work_amount, nthr, ithr, start, end);
+ start_copy = start;
+
+ auto par_conv = jit_conv_call_s();
+ size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
+ size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
+
+ for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
+ start = start_copy;
+ int n{0}, g{0}, icc{0};
+ if (jcp.loop_order == loop_cgn) {
+ int dummy{0};
+ nd_iterator_init(start, icc, ic_chunks, g, jcp.ngroups, n,
+ jcp.mb, dummy, 1);
+ } else if (jcp.loop_order == loop_gnc) {
+ int dummy{0};
+ nd_iterator_init(start, g, jcp.ngroups, n, jcp.mb, icc,
+ ic_chunks, dummy, 1);
+ } else {
+ assert(!"unsupported loop order");
+ }
+
+ while (start < end) {
+ int icb = icc * jcp.nb_ic_blocking;
+ int g_icb = g * jcp.nb_ic + icb;
+ int g_ocb = g * jcp.nb_oc;
+
+ auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
+ auto diff_dst_w = diff_dst
+ + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
+ auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);
+
+ for (int ocb = ocb_l2;
+ ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
+ jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
+ diff_src_w, diff_dst_w, wht_w, 0, ocb, 1);
+ diff_dst_w += diff_dst_c_stride;
+ wht_w += wht_oc_stride;
+ }
+
+ if (jcp.loop_order == loop_cgn) {
+ int dummy{0};
+ nd_iterator_jump(start, end, icc, ic_chunks, g, jcp.ngroups,
+ n, jcp.mb, dummy, 1);
+ } else if (jcp.loop_order == loop_gnc) {
+ int dummy{0};
+ nd_iterator_jump(start, end, g, jcp.ngroups, n, jcp.mb, icc,
+ ic_chunks, dummy, 1);
+ } else {
+ assert(!"unsupported loop order");
+ }
+ }
+ }
+
+ jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
+ diff_src, diff_dst, weights, 0, 0, 1);
+ });
+}
+
+template <data_type_t diff_dst_type, data_type_t wei_type,
+ data_type_t diff_src_type>
+void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
+ diff_src_type>::execute_backward_data_2d(const exec_ctx_t &ctx) const
+{
+ auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ int start{0}, end{0}, start_copy;
+ int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
+ int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.ih;
+ balance211(work_amount, nthr, ithr, start, end);
+ start_copy = start;
+
+ auto par_conv = jit_conv_call_s();
+ size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 1);
+ size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 1);
+ size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
+ size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
+ size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
+
+ bool is_fast_path = jcp.dilate_h == 0 && jcp.stride_h == 1;
+
+ for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
+ start = start_copy;
+ int n{0}, g{0}, icc{0}, ih_s{0};
+ if (jcp.loop_order == loop_cgn)
+ nd_iterator_init(start,
+ icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih);
+ else if (jcp.loop_order == loop_gnc)
+ nd_iterator_init(start,
+ g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih);
+ else
+ assert(!"unsupported loop order");
+
+ while (start < end) {
+ int icb = icc * jcp.nb_ic_blocking;
+ int g_icb = g * jcp.nb_ic + icb;
+ int g_ocb = g * jcp.nb_oc;
+
+ int work_rem = end - start;
+ int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
+
+ auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb);
+ auto diff_dst_w = diff_dst
+ + diff_dst_d.blk_off(n, g_ocb + ocb_l2);
+ auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb);
+
+ for (int ocb = ocb_l2;
+ ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
+ for (int ij = ih_s; ij < ih_e; ++ij) {
+ int oj, k_len, k_lo;
+ if (is_fast_path) { // dilate == 0 && stride == 1
+ int i_t_overflow = max(0, jcp.kh - 1 - ij
+ - jcp.t_pad);
+ int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
+ - jcp.b_pad);
+ k_len = jcp.kh - i_t_overflow - i_b_overflow;
+ k_lo = i_b_overflow;
+ oj = ij + jcp.t_pad - i_b_overflow;
+ } else if (jcp.dilate_h != 0) { // stride == 1
+ int dilate_h = jcp.dilate_h + 1;
+ // Note: use div_up to account for "holes" in filter
+ int i_t_overflow
+ = div_up(max(0, (jcp.kh - 1) * dilate_h
+ - ij - jcp.t_pad), dilate_h);
+ int i_b_overflow
+ = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
+ - jcp.ih + ij - jcp.b_pad), dilate_h);
+ k_len = jcp.kh - i_t_overflow - i_b_overflow;
+ k_lo = i_b_overflow;
+ oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
+ } else { // dilate == 0
+ int i_t_overflow = max(0, (jcp.kh - 1 - ij
+ - jcp.t_pad) / jcp.stride_h);
+ int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
+ - jcp.b_pad) / jcp.stride_h);
+ int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
+ + jcp.b_pad - ij) % jcp.stride_h);
+ int overflow_kh_lo = (ij + jcp.t_pad)
+ % jcp.stride_h;
+
+ k_len = (overflow_kh_hi - overflow_kh_lo)
+ / jcp.stride_h + 1 - i_t_overflow
+ - i_b_overflow;
+ k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
+ oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
+ }
+ assert(k_len >= 0);
+
+ jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
+ diff_src_w + ij * diff_src_h_stride,
+ diff_dst_w + oj * diff_dst_h_stride,
+ wht_w + k_lo * wht_h_stride,
+ 0, ocb, k_len);
+ }
+ diff_dst_w += diff_dst_c_stride;
+ wht_w += wht_oc_stride;
+ }
+
+ if (jcp.loop_order == loop_cgn)
+ nd_iterator_jump(start, end,
+ icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, ih_s, jcp.ih);
+ else if (jcp.loop_order == loop_gnc)
+ nd_iterator_jump(start, end,
+ g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, ih_s, jcp.ih);
+ else
+ assert(!"unsupported loop order");
+ }
+ }
+
+ jit_conv_ker_pipeline(kernel_->jit_ker, par_conv,
+ diff_src, diff_dst, weights, 0, 0, 1);
+ });
+}
+
+template <data_type_t diff_dst_type, data_type_t wei_type,
+ data_type_t diff_src_type>
+void jit_avx512_common_convolution_bwd_data_t<diff_dst_type, wei_type,
+ diff_src_type>::execute_backward_data_3d(const exec_ctx_t &ctx) const
+{
+ auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ int start{0}, end{0}, start_copy;
+ int ic_chunks = jcp.nb_ic / jcp.nb_ic_blocking;
+ int work_amount = jcp.ngroups * jcp.mb * ic_chunks * jcp.id * jcp.ih;
+ balance211(work_amount, nthr, ithr, start, end);
+ start_copy = start;
+
+ auto par_conv = jit_conv_call_s();
+ size_t diff_src_h_stride = diff_src_d.blk_off(0, 0, 0, 1);
+ size_t diff_src_d_stride = diff_src_d.blk_off(0, 0, 1);
+ size_t diff_dst_h_stride = diff_dst_d.blk_off(0, 0, 0, 1);
+ size_t diff_dst_d_stride = diff_dst_d.blk_off(0, 0, 1);
+ size_t diff_dst_c_stride = diff_dst_d.blk_off(0, 1);
+ size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 0, 1);
+ size_t wht_d_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
+ size_t wht_oc_stride = wht_blk_off(weights_d, 0, 1);
+
+ bool is_fast_path_d = jcp.dilate_d == 0 && jcp.stride_d == 1;
+ bool is_fast_path_h = jcp.dilate_h == 0 && jcp.stride_h == 1;
+
+ for (int ocb_l2 = 0; ocb_l2 < jcp.nb_oc; ocb_l2 += jcp.nb_oc_L2) {
+ start = start_copy;
+ int n{0}, g{0}, icc{0}, ih_s{0}, id_s{0};
+ if (jcp.loop_order == loop_cgn)
+ nd_iterator_init(start,
+ icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id,
+ ih_s, jcp.ih);
+ else if (jcp.loop_order == loop_gnc)
+ nd_iterator_init(start,
+ g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id,
+ ih_s, jcp.ih);
+ else
+ assert(!"unsupported loop order");
+
+ while (start < end) {
+ int icb = icc * jcp.nb_ic_blocking;
+ int g_icb = g * jcp.nb_ic + icb;
+ int g_ocb = g * jcp.nb_oc;
+
+ int work_rem = end - start;
+ int ih_e = ih_s + work_rem > jcp.ih ? jcp.ih : ih_s + work_rem;
+ int d_len = 0, d_lo = 0, d_oj = 0;
+ if (is_fast_path_d) { // dilate == 0 && stride == 1
+ int d_t_overflow = max(0, jcp.kd - 1 - id_s
+ - jcp.f_pad);
+ int d_b_overflow = max(0, jcp.kd - jcp.id + id_s
+ - jcp.back_pad);
+ d_len = jcp.kd - d_t_overflow - d_b_overflow;
+ d_lo = d_b_overflow;
+ d_oj = id_s + jcp.f_pad - d_b_overflow;
+ } else if (jcp.dilate_d != 0) { // stride == 1
+ int dilate_d = jcp.dilate_d + 1;
+ // Note: use div_up to account for "holes" in filter
+ int d_t_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d
+ - id_s - jcp.f_pad), dilate_d);
+ int d_b_overflow = div_up(max(0, (jcp.kd - 1) * dilate_d + 1
+ - jcp.id + id_s - jcp.back_pad), dilate_d);
+ d_len = jcp.kd - d_t_overflow - d_b_overflow;
+ d_lo = d_b_overflow;
+ d_oj = id_s + jcp.f_pad - d_b_overflow * dilate_d;
+ } else { // dilate == 0
+ int d_t_overflow = max(0, (jcp.kd - 1 - id_s
+ - jcp.f_pad) / jcp.stride_d);
+ int d_b_overflow = max(0, (jcp.kd - jcp.id + id_s
+ - jcp.back_pad) / jcp.stride_d);
+ int overflow_kd_hi = jcp.kd - 1 - abs((jcp.id - 1
+ + jcp.back_pad - id_s) % jcp.stride_d);
+ int overflow_kd_lo = (id_s + jcp.f_pad)
+ % jcp.stride_d;
+
+ d_len = (overflow_kd_hi - overflow_kd_lo)
+ / jcp.stride_d + 1 - d_t_overflow
+ - d_b_overflow;
+ d_lo = overflow_kd_lo + d_b_overflow * jcp.stride_d;
+ d_oj = (id_s + jcp.f_pad - d_lo) / jcp.stride_d;
+ }
+ assert(d_len >= 0);
+
+ auto diff_src_w = diff_src + diff_src_d.blk_off(n, g_icb)
+ + id_s * diff_src_d_stride;
+ auto diff_dst_w = diff_dst
+ + diff_dst_d.blk_off(n, g_ocb + ocb_l2)
+ + d_oj * diff_dst_d_stride;
+ auto wht_w = weights + wht_blk_off(weights_d, g, ocb_l2, icb)
+ + d_lo * wht_d_stride;
+
+ for (int ocb = ocb_l2;
+ ocb < min(jcp.nb_oc, ocb_l2 + jcp.nb_oc_L2); ++ocb) {
+ for (int ij = ih_s; ij < ih_e; ++ij) {
+ int oj, k_len, k_lo;
+ if (is_fast_path_h) { // dilate == 0 && stride == 1
+ int i_t_overflow = max(0, jcp.kh - 1 - ij
+ - jcp.t_pad);
+ int i_b_overflow = max(0, jcp.kh - jcp.ih + ij
+ - jcp.b_pad);
+ k_len = jcp.kh - i_t_overflow - i_b_overflow;
+ k_lo = i_b_overflow;
+ oj = ij + jcp.t_pad - i_b_overflow;
+ } else if (jcp.dilate_h != 0) { // stride == 1
+ int dilate_h = jcp.dilate_h + 1;
+ // Note: use div_up to account for "holes" in filter
+ int i_t_overflow
+ = div_up(max(0, (jcp.kh - 1) * dilate_h
+ - ij - jcp.t_pad), dilate_h);
+ int i_b_overflow
+ = div_up(max(0, (jcp.kh - 1) * dilate_h + 1
+ - jcp.ih + ij - jcp.b_pad), dilate_h);
+ k_len = jcp.kh - i_t_overflow - i_b_overflow;
+ k_lo = i_b_overflow;
+ oj = ij + jcp.t_pad - i_b_overflow * dilate_h;
+ } else { // dilate == 0
+ int i_t_overflow = max(0, (jcp.kh - 1 - ij
+ - jcp.t_pad) / jcp.stride_h);
+ int i_b_overflow = max(0, (jcp.kh - jcp.ih + ij
+ - jcp.b_pad) / jcp.stride_h);
+ int overflow_kh_hi = jcp.kh - 1 - abs((jcp.ih - 1
+ + jcp.b_pad - ij) % jcp.stride_h);
+ int overflow_kh_lo = (ij + jcp.t_pad)
+ % jcp.stride_h;
+
+ k_len = (overflow_kh_hi - overflow_kh_lo)
+ / jcp.stride_h + 1 - i_t_overflow
+ - i_b_overflow;
+ k_lo = overflow_kh_lo + i_b_overflow * jcp.stride_h;
+ oj = (ij + jcp.t_pad - k_lo) / jcp.stride_h;
+ }
+ assert(k_len >= 0);
+
+ jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
+ diff_src_w + ij * diff_src_h_stride,
+ diff_dst_w + oj * diff_dst_h_stride,
+ wht_w + k_lo * wht_h_stride,
+ 0, ocb, k_len, d_len);
+ }
+ diff_dst_w += diff_dst_c_stride;
+ wht_w += wht_oc_stride;
+ }
+
+ if (jcp.loop_order == loop_cgn)
+ nd_iterator_jump(start, end,
+ icc, ic_chunks, g, jcp.ngroups, n, jcp.mb, id_s, jcp.id,
+ ih_s, jcp.ih);
+ else if (jcp.loop_order == loop_gnc)
+ nd_iterator_jump(start, end,
+ g, jcp.ngroups, n, jcp.mb, icc, ic_chunks, id_s, jcp.id,
+ ih_s, jcp.ih);
+ else
+ assert(!"unsupported loop order");
+ }
+ }
+
+ jit_conv_3d_ker_pipeline(kernel_->jit_ker, par_conv,
+ diff_src, diff_dst, weights, 0, 0, 1, 1);
+ });
+}
+
+template struct jit_avx512_common_convolution_bwd_data_t<data_type::f32>;
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::
+jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd)
+ : cpu_primitive_t(apd), kernel_(nullptr)
+ , trans_kernel_(nullptr), acc_ker_(nullptr), reducer_bias_(nullptr)
+{
+ const auto &j = pd()->jcp_;
+
+ nthr_ = j.nthr;
+ nthr_mb_ = j.nthr_mb;
+ nthr_g_ = j.nthr_g;
+ nthr_oc_b_ = j.nthr_oc_b;
+ nthr_ic_b_ = j.nthr_ic_b;
+
+ kernel_ = new jit_avx512_common_conv_bwd_weights_kernel_f32(j);
+
+ if (j.ver == ver_4fma)
+ trans_kernel_ = create_trans_src(&j);
+
+ if (nthr_mb_ > 1)
+ acc_ker_ = new cpu_accumulator_1d_t<diff_weights_type>();
+
+ reducer_bias_ =
+ new cpu_reducer_t<diff_weights_type>(pd()->reducer_bia_conf_);
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+struct jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::thread_info_t {
+ const src_data_t *src;
+ const diff_dst_data_t *diff_dst;
+ const diff_weights_data_t *diff_weights;
+ diff_weights_data_t *diff_bias;
+
+ const memory_tracking::grantor_t scratchpad;
+
+ src_data_t *tr_src;
+ simple_barrier::ctx_t *tr_src_bctx;
+
+ diff_dst_data_t *tr_diff_dst;
+ simple_barrier::ctx_t *tr_diff_dst_bctx;
+
+ diff_weights_data_t *wei_bia_reduction;
+ simple_barrier::ctx_t *wei_bia_reduction_bctx;
+
+ int ithr;
+ int ithr_ic_b, ithr_oc_b, ithr_g, ithr_mb;
+ int ithr_but_oc;
+ int ithr_but_ic;
+
+ int img_start = 0, img_end = 0, img_work;
+ int g_start = 0, g_end = 0, g_work;
+ int oc_b_start = 0, oc_b_end = 0, oc_b_work;
+ int ic_b_start = 0, ic_b_end = 0, ic_b_work;
+
+ thread_info_t(const jit_avx512_common_convolution_bwd_weights_t *self,
+ const exec_ctx_t &ctx, int ithr)
+ : scratchpad(self->scratchpad(ctx)), ithr(ithr)
+ {
+ diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
+ src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ diff_weights = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
+ diff_bias = self->pd()->wants_padded_bias()
+ ? scratchpad.template get<diff_weights_data_t>(
+ key_conv_padded_bias)
+ : CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
+ tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ key_conv_tr_src_bctx);
+
+ tr_diff_dst = scratchpad.template get<diff_dst_data_t>(
+ key_conv_tr_diff_dst);
+ tr_diff_dst_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ key_conv_tr_diff_dst_bctx);
+
+ wei_bia_reduction = scratchpad.template get<diff_weights_data_t>(
+ key_conv_wei_bia_reduction);
+ wei_bia_reduction_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ key_conv_wei_bia_reduction_bctx);
+
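+        // decompose the flat thread id into (mb, g, oc_b, ic_b) coordinates;
+        // ic_b varies fastest, mb slowest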
+ ithr_ic_b = ithr % self->nthr_ic_b_;
+ ithr_oc_b = ithr / self->nthr_ic_b_ % self->nthr_oc_b_;
+ ithr_g = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ % self->nthr_g_;
+ ithr_mb = ithr / self->nthr_ic_b_ / self->nthr_oc_b_ / self->nthr_g_;
+
+ ithr_but_oc = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_ic_b_
+ + ithr_ic_b;
+
+ ithr_but_ic = (ithr_mb * self->nthr_g_ + ithr_g) * self->nthr_oc_b_
+ + ithr_oc_b;
+
+ const auto &jcp = self->kernel_->jcp;
+
+ /* reduction dimension */
+ balance211(jcp.mb*jcp.od, self->nthr_mb_, ithr_mb, img_start, img_end);
+ img_work = img_end - img_start;
+
+ /* independent dimensions */
+ balance211(jcp.ngroups, self->nthr_g_, ithr_g, g_start, g_end);
+ g_work = g_end - g_start;
+
+ balance211(jcp.nb_oc, self->nthr_oc_b_, ithr_oc_b, oc_b_start,
+ oc_b_end);
+ oc_b_work = oc_b_end - oc_b_start;
+
+ balance211(jcp.nb_ic, self->nthr_ic_b_, ithr_ic_b, ic_b_start,
+ ic_b_end);
+ ic_b_work = ic_b_end - ic_b_start;
+ }
+};
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::compute_diff_weights(const thread_info_t *ti) const {
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+ const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh*jcp.kw*jcp.kd;
+
+ diff_weights_data_t *diff_wei = ti->ithr_mb == 0
+ ? (diff_weights_data_t*)ti->diff_weights
+ : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
+ diff_weights_data_t *diff_bia = ti->ithr_mb == 0
+ ? (diff_weights_data_t*)ti->diff_bias
+ : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
+ + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
+
+ // TODO: use memory descriptor with the same fmt as src (or use a macro :))
+ auto tr_src_off = [&](int ithr_mb, int ic, int ij) {
+ const size_t tr_row_size = jcp.tr_iw * jcp.ic_block;
+ const size_t tr_chn_size = tr_row_size * jcp.ih;
+ const size_t tr_img_size = tr_chn_size * jcp.nb_ic * jcp.ngroups;
+
+ return ti->ithr_mb * tr_img_size + ic * tr_chn_size + ij * tr_row_size;
+ };
+
+ auto uker_trans = [&](int img) {
+ const int work_amount = ti->g_work * ti->ic_b_work * jcp.ih;
+
+ int start{0}, end{0};
+ balance211(work_amount, nthr_oc_b_, ti->ithr_oc_b, start, end);
+ const int my_work = end - start;
+
+ int g{0}, ic_b{0}, j{0};
+ nd_iterator_init(start, g, ti->g_work, ic_b, ti->ic_b_work, j, jcp.ih);
+ g += ti->g_start;
+ ic_b += ti->ic_b_start;
+
+ const int _ic = g * jcp.nb_ic + ic_b;
+ src_data_t *src1 = (src_data_t*)&ti->src[src_d.blk_off(img, _ic, j)];
+ src_data_t *tr_src1 = &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, j)];
+
+ assert(jcp.ic_block == 16);
+ const int src_stride = jcp.iw * jcp.ic_block;
+ const int tr_src_stride = jcp.tr_iw * jcp.ic_block;
+
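+        // run the transpose kernel one row behind the walk so that the
+        // current src/tr_src pointers can be handed to it as prefetch hints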
+ const int pf_depth = 2;
+ struct { src_data_t *src, *tr_src; } pf_circ_buf[pf_depth];
+
+ for (int iwork = 0; iwork < my_work + pf_depth - 1; iwork++) {
+ pf_circ_buf[iwork % pf_depth] = {src1, tr_src1};
+
+ if (iwork >= pf_depth - 1) {
+ int old_idx = (iwork - pf_depth + 1) % pf_depth;
+ auto ctx = jit_trans_src_t::ctx_t();
+ ctx.src = pf_circ_buf[old_idx].src;
+ ctx.tr_src = pf_circ_buf[old_idx].tr_src;
+ ctx.src_prf = src1;
+ ctx.tr_src_prf = tr_src1;
+ (*trans_kernel_)(&ctx);
+ }
+ src1 += src_stride;
+ tr_src1 += tr_src_stride;
+ }
+#if 0
+ // reference transposition
+ const int l_pad = jcp.l_pad;
+ const int iwlp = l_pad + jcp.iw;
+ const int tr_iw = jcp.tr_iw;
+
+ for (size_t iwork = start; iwork < end; iwork++) {
+ PRAGMA_OMP_SIMD()
+# pragma unroll
+ for (int i = 0; i < l_pad; i++)
+ for (int j = 0; j < jcp.ic_block; j++)
+ tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;
+
+ PRAGMA_OMP_SIMD()
+# pragma unroll
+ for (int i = l_pad; i < iwlp; i++)
+ for (int j = 0; j < jcp.ic_block; j++)
+ tr_src1[j * jcp.tr_iw + i]
+ = (src_data_t)src1[(i - l_pad) * 16 + j];
+
+ PRAGMA_OMP_SIMD()
+# pragma unroll
+ for (int i = iwlp; i < tr_iw; i++)
+ for (int j = 0; j < jcp.ic_block; j++)
+ tr_src1[j * jcp.tr_iw + i] = (src_data_t)0.0;
+
+ src1 += src_stride;
+ tr_src1 += tr_src_stride;
+ }
+#endif
+ };
+
+ if (jcp.is_1stconv && jcp.ver == ver_4fma) {
+ /* prepare contexts */
+ auto tr_ctx = jit_trans_src_t::ctx_t();
+ tr_ctx.tr_src = ti->tr_src
+ + ti->ithr_but_oc * jcp.ih * jcp.stride_w * jcp.tr_ld;
+
+ assert(IMPLICATION(!mkldnn_thr_syncable(), nthr_oc_b_ == 1));
+ tr_ctx.nthr_oc_b = nthr_oc_b_;
+ int ih_start{0}, ih_end{0};
+ balance211(jcp.ih, nthr_oc_b_, ti->ithr_oc_b, ih_start, ih_end);
+ tr_ctx.tr_src_ih_start = ih_start;
+ tr_ctx.tr_src_ih_end = ih_end;
+ tr_ctx.tr_src_bctx = ti->tr_src_bctx + ti->ithr_but_oc;
+
+ auto p = jit_conv_call_s();
+ p.src = tr_ctx.tr_src;
+
+ /* zero diff_bias if applicable */
+ if (jcp.with_bias && ti->ithr_ic_b == 0) {
+ assert(jcp.oc_block == 16);
+            for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
+ diff_weights_data_t *db = &diff_bia[oc_b * 16];
+ for (int o = 0; o < 16; ++o)
+ db[o] = 0;
+ }
+ }
+
+ for (int img = ti->img_start; img < ti->img_end; ++img) {
+ p.flags = (img == ti->img_start) * FLAG_MB_FIRST;
+
+ for (int g = ti->g_start; g < ti->g_end; ++g) {
+ for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
+ const int _ic = g * jcp.nb_ic + ic_b;
+ tr_ctx.src = &ti->src[src_d.blk_off(img, _ic)];
+
+ (*trans_kernel_)(&tr_ctx);
+
+ if (ic_b == 0)
+ p.flags |= FLAG_IC_FIRST;
+ else
+ p.flags &= ~FLAG_IC_FIRST;
+
+ for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
+ const int _oc = g * jcp.nb_oc + oc_b;
+ p.dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
+
+ const size_t off =
+ wht_blk_off(diff_weights_d, g, oc_b, ic_b);
+ p.filt = diff_wei + off;
+ p.bias = diff_bia + _oc * jcp.oc_block;
+
+ kernel_->jit_ker(&p);
+ }
+ }
+ }
+ }
+ } else {
+ for (int img = ti->img_start; img < ti->img_end; ++img) {
+ auto p = jit_conv_call_s();
+
+ if (jcp.ver == ver_4fma) {
+ /* tr_src[nb_ic][ih][16][~iw~] <- src[nb_ic][ih][iw][16] */
+ using simple_barrier::barrier;
+ if (nthr_oc_b_ > 1)
+ barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
+ uker_trans(img);
+ if (nthr_oc_b_ > 1)
+ barrier(&ti->tr_src_bctx[ti->ithr_but_oc], nthr_oc_b_);
+ }
+
+ for (int g = ti->g_start; g < ti->g_end; ++g) {
+ for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
+ for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
+ const int _oc = g * jcp.nb_oc + oc_b;
+ const int _ic = g * jcp.nb_ic + ic_b;
+
+ jit_conv_ker_pipeline(kernel_->jit_ker, p,
+ jcp.ver == ver_4fma
+ ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
+ : &ti->src[src_d.blk_off(img, _ic)],
+ &ti->diff_dst[diff_dst_d.blk_off(img, _oc)],
+ diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
+ 0, (img == ti->img_start), 0);
+
+ }
+ }
+ }
+
+ const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
+ const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
+ jit_conv_ker_pipeline(kernel_->jit_ker, p,
+ jcp.ver == ver_4fma
+ ? &ti->tr_src[tr_src_off(ti->ithr_mb, _ic, 0)]
+ : &ti->src[src_d.blk_off(img + 1, _ic)],
+ &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
+ diff_wei + wht_blk_off(
+ diff_weights_d, ti->g_start,
+ ti->oc_b_start, ti->ic_b_start),
+ 0, 0, 0);
+ }
+ }
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::compute_diff_weights_3d(const thread_info_t *ti) const
+{
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+ const int wei_size
+ = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw * jcp.kd;
+
+ diff_weights_data_t *diff_wei = ti->ithr_mb == 0
+ ? (diff_weights_data_t*)ti->diff_weights
+ : ti->wei_bia_reduction + (ti->ithr_mb - 1) * wei_size;
+ diff_weights_data_t *diff_bia = ti->ithr_mb == 0
+ ? (diff_weights_data_t*)ti->diff_bias
+ : ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size
+ + (ti->ithr_mb - 1) * jcp.ngroups * jcp.oc;
+
+ const int inp_mult = jcp.is_1stconv ? 1 : jcp.ic_block;
+ const int input_step = jcp.ih * jcp.iw * inp_mult;
+ const int output_step = jcp.ow * jcp.oh * jcp.oc_block;
+ int img{0}, od_s{0};
+ int img_start = ti->img_start, img_end = ti->img_end;
+ nd_iterator_init(img_start, img, jcp.mb, od_s, jcp.od);
+ const int img_first = img;
+
+ while (img_start < img_end) {
+ auto p = jit_conv_call_s();
+
+ int work_rem = img_end - img_start;
+ const int od_e = od_s + work_rem > jcp.od ? jcp.od : od_s + work_rem;
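+        // clip the kernel's depth axis against the input: kd_front_pad and
+        // kd_back_pad count the kd taps falling into front/back padding,
+        // kd_pad_off is the corresponding byte offset into the weights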
+ const int id_s = od_s * jcp.stride_d;
+ const int ik_overlap = nstl::max(0, id_s - jcp.f_pad);
+ const int kd_front_pad = nstl::max(0, jcp.f_pad - id_s);
+ const int kd_back_pad
+ = nstl::max(0, id_s - jcp.f_pad - jcp.id + jcp.kd);
+ int kd_pad_off = nstl::min(jcp.kd - 1, kd_front_pad) * jcp.kh * jcp.kw
+ * jcp.ic_block * jcp.oc_block * jcp.typesize_out;
+
+ for (int g = ti->g_start; g < ti->g_end; ++g) {
+ for (int oc_b = ti->oc_b_start; oc_b < ti->oc_b_end; ++oc_b) {
+ for (int ic_b = ti->ic_b_start; ic_b < ti->ic_b_end; ++ic_b) {
+ const int _oc = g * jcp.nb_oc + oc_b;
+ const int _ic = g * jcp.nb_ic + ic_b;
+
+ auto src = &ti->src[src_d.blk_off(img, _ic)
+ + ik_overlap * input_step];
+ auto dst = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)
+ + od_s * output_step];
+
+ jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p, src, dst,
+ diff_wei + wht_blk_off(diff_weights_d, g, oc_b, ic_b),
+ diff_bia + _oc * 16, (img == img_first), od_s, od_e,
+ jcp.kd - kd_front_pad - kd_back_pad, kd_pad_off);
+
+ if (ic_b == 0) p.flags = 0;
+ else p.flags = 1;
+ }
+ }
+ }
+
+ const int _oc = ti->g_start * jcp.nb_oc + ti->oc_b_start;
+ const int _ic = ti->g_start * jcp.nb_ic + ti->ic_b_start;
+ jit_conv_3d_ker_bwd_w_pipeline(kernel_->jit_ker, p,
+ &ti->src[src_d.blk_off(img + 1, _ic)],
+ &ti->diff_dst[diff_dst_d.blk_off(img + 1, _oc)],
+ diff_wei + wht_blk_off(diff_weights_d, ti->g_start,
+ ti->oc_b_start, ti->ic_b_start),
+ diff_bia, 0, 0, 0, 0, 0);
+ nd_iterator_jump(img_start, img_end, img, jcp.mb, od_s, jcp.od);
+ }
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::reduce_diff_weights(const thread_info_t *ti) const {
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+ const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
+ const int bia_size = jcp.ngroups * jcp.oc;
+ const diff_weights_data_t *diff_bias_ws
+ = ti->wei_bia_reduction + (nthr_mb_ - 1) * wei_size;
+
+ /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
+ simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
+
+ const int ic_b_kh_work = ti->ic_b_work * jcp.kh;
+ const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
+
+ int start{0}, end{0};
+ balance211(work, nthr_mb_, ti->ithr_mb, start, end);
+ if (start == end) return;
+
+ for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
+ int w = start;
+ int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
+ nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
+ ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
+ while (w < end) {
+ const int g = ti->g_start + sub_g_start;
+ const int oc_b = ti->oc_b_start + sub_oc_b_start;
+ const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kh;
+ const int kh = sub_ic_b_kh_start % jcp.kh;
+
+ const int acc_size
+ = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
+ * jcp.kw * jcp.ic_block * jcp.oc_block;
+
+ const size_t off
+ = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kh);
+
+ diff_weights_data_t *d
+ = (diff_weights_data_t *)ti->diff_weights + off;
+ diff_weights_data_t *s
+ = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
+
+ acc_ker_->accumulate(d, s, acc_size);
+
+ nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
+ ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
+ }
+
+ if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) {
+ if (ti->ithr == 0)
+ acc_ker_->accumulate((diff_weights_data_t *)ti->diff_bias,
+ diff_bias_ws, bia_size);
+ diff_bias_ws += bia_size;
+ }
+ }
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::reduce_diff_weights_3d(const thread_info_t *ti) const {
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+ const int wei_size = jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw
+ * jcp.kd;
+
+ /* diff_weights[:] += sum(wei_reduction_[thr_mb][:]) */
+ simple_barrier::barrier(ti->wei_bia_reduction_bctx, nthr_);
+
+ const int ic_b_kh_work = ti->ic_b_work * jcp.kd;
+ const int work = ti->g_work * ti->oc_b_work * ic_b_kh_work;
+
+ int start{0}, end{0};
+ balance211(work, nthr_mb_, ti->ithr_mb, start, end);
+ if (start == end) return;
+
+ for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
+ int w = start;
+ int sub_g_start{0}, sub_oc_b_start{0}, sub_ic_b_kh_start{0};
+ nd_iterator_init(w, sub_g_start, ti->g_work, sub_oc_b_start,
+ ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
+ while (w < end) {
+ const int g = ti->g_start + sub_g_start;
+ const int oc_b = ti->oc_b_start + sub_oc_b_start;
+ const int ic_b = ti->ic_b_start + sub_ic_b_kh_start / jcp.kd;
+ const int kd = sub_ic_b_kh_start % jcp.kd;
+
+ const int acc_size
+ = nstl::min(end - w, ic_b_kh_work - sub_ic_b_kh_start)
+ * jcp.kw * jcp.ic_block * jcp.oc_block * jcp.kh;
+
+ const size_t off
+ = wht_blk_off(diff_weights_d, g, oc_b, ic_b, kd);
+ diff_weights_data_t *d
+ = (diff_weights_data_t *)ti->diff_weights + off;
+ diff_weights_data_t *s
+ = ti->wei_bia_reduction + (thr_mb - 1) * wei_size + off;
+ acc_ker_->accumulate(d, s, acc_size);
+
+ nd_iterator_jump(w, end, sub_g_start, ti->g_work, sub_oc_b_start,
+ ti->oc_b_work, sub_ic_b_kh_start, ic_b_kh_work);
+ }
+ }
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::compute_diff_bias(const thread_info_t *ti) const {
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+
+ auto rb = this->reducer_bias_;
+ assert(nthr_ == rb->balancer().nthr_);
+
+ const auto reducer_bia_scratchpad = memory_tracking::grantor_t(
+ ti->scratchpad, prefix_reducer_bia);
+
+ const auto &jcp = kernel_->jcp;
+
+ if (jcp.with_bias && jcp.is_1stconv && jcp.ver == ver_4fma) return;
+
+ const int b_job_start = rb->balancer().ithr_job_off(ti->ithr);
+ const int b_njobs = rb->balancer().ithr_njobs(ti->ithr);
+
+ if (b_njobs == 0) return;
+
+ /* reduction dimension */
+ int img_start{0}, img_end{0};
+ balance211(jcp.mb, rb->balancer().nthr_per_group_,
+ rb->balancer().id_in_group(ti->ithr), img_start, img_end);
+
+ /* jobs */
+ int g_start{0}, ocb_start{0};
+ nd_iterator_init(b_job_start, g_start, jcp.ngroups, ocb_start, jcp.nb_oc);
+ for (int img = img_start; img < img_end; ++img) {
+ int g = g_start, ocb = ocb_start;
+ for (int b_job_loc = 0; b_job_loc < b_njobs; ++b_job_loc) {
+ const size_t _oc = g * jcp.nb_oc + ocb;
+
+ const diff_dst_data_t *d_dst
+ = &ti->diff_dst[diff_dst_d.blk_off(img, _oc)];
+ diff_weights_data_t *d_bias = rb->get_local_ptr(ti->ithr,
+ ti->diff_bias, reducer_bia_scratchpad)
+ + b_job_loc * rb->balancer().job_size_;
+
+ if (img == img_start)
+ for (int o = 0; o < 16; ++o)
+ d_bias[o] = 0;
+ for (int hw = 0; hw < jcp.oh * jcp.ow * jcp.od; ++hw) {
+ PRAGMA_OMP_SIMD()
+ for (int o = 0; o < 16; ++o)
+ d_bias[o] += d_dst[o];
+ d_dst += 16;
+ }
+
+ nd_iterator_step(g, jcp.ngroups, ocb, jcp.nb_oc);
+ }
+ }
+
+ rb->reduce(ti->ithr, ti->diff_bias, reducer_bia_scratchpad);
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::compute_diff_bias_3d(const thread_info_t *ti) const {
+
+ const auto &jcp = kernel_->jcp;
+
+ const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh
+ * jcp.kw * jcp.kd;
+ const int bia_size = jcp.ngroups * jcp.oc;
+ const diff_weights_data_t *diff_bias_ws
+ = ti->wei_bia_reduction + (size_t)(nthr_mb_ - 1) * wei_size;
+
+ if (nthr_mb_ > 1) mkldnn_thr_barrier();
+
+ if (ti->ithr == 0)
+ {
+ for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
+ acc_ker_->accumulate(ti->diff_bias, diff_bias_ws, bia_size);
+ diff_bias_ws += bia_size;
+ }
+ }
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::prepare_scratchpad_data(const exec_ctx_t &ctx) const
+{
+ const auto &j = pd()->jcp_;
+ auto scratchpad = this->scratchpad(ctx);
+
+ if (j.ver == ver_4fma) {
+ if (!j.is_1stconv) {
+ // XXX: See the comment about tr_iw and guarding elements in
+ // jit_avx512_common_conv_bwd_weights_kernel_f32::init_conf()
+ const int max_nthr = j.nthr_mb * j.ngroups * j.nb_ic;
+ const int min_tr_src_size_per_thr = j.ih * j.ic_block * j.tr_iw;
+
+ auto tr_src = scratchpad.template get<src_data_t>(key_conv_tr_src);
+            /* to avoid NaNs in computations we zero the trailing
+             * num_guard_elems elements for each possible thread group */
+
+ for (int ithr = 1; ithr <= max_nthr; ++ithr) {
+ src_data_t *ts = &tr_src[ithr * min_tr_src_size_per_thr];
+ for (int i = 0; i < j.tr_src_num_guard_elems; ++i)
+ ts[i] = 0;
+ }
+ }
+
+ if (j.nthr_oc_b > 1) {
+ const int tr_src_bctx_size = j.nthr / j.nthr_oc_b;
+ auto tr_src_bctx = scratchpad.template get<simple_barrier::ctx_t>(
+ key_conv_tr_src_bctx);
+ for (int i = 0; i < tr_src_bctx_size; ++i)
+ simple_barrier::ctx_init(&tr_src_bctx[i]);
+ }
+ }
+
+ if (nthr_mb_ > 1) {
+ simple_barrier::ctx_init(scratchpad.template get<simple_barrier::ctx_t>(
+ key_conv_wei_bia_reduction_bctx));
+ }
+
+ const auto reducer_bia_scratchpad = memory_tracking::grantor_t(scratchpad,
+ prefix_reducer_bia);
+ auto rb = this->reducer_bias_;
+ rb->init(reducer_bia_scratchpad);
+}
+
+template <data_type_t src_type, data_type_t diff_dst_type,
+ data_type_t diff_weights_type>
+void jit_avx512_common_convolution_bwd_weights_t<src_type, diff_dst_type,
+ diff_weights_type>::execute_backward_weights(const exec_ctx_t &ctx) const {
+ prepare_scratchpad_data(ctx);
+
+ parallel(nthr_, [&](const int ithr, const int nthr) {
+ assert(nthr_ == nthr);
+
+ thread_info_t thread_info(this, ctx, ithr);
+
+ if (utils::one_of(pd()->ndims(), 3, 4)) {
+ compute_diff_weights(&thread_info);
+ if (nthr_mb_ > 1) reduce_diff_weights(&thread_info);
+ if (pd()->with_bias()) compute_diff_bias(&thread_info);
+ } else if (pd()->ndims() == 5) {
+ compute_diff_weights_3d(&thread_info);
+ if (nthr_mb_ > 1) reduce_diff_weights_3d(&thread_info);
+ if (pd()->with_bias()) compute_diff_bias_3d(&thread_info);
+ } else {
+ assert(false);
+ }
+ });
+
+ /* TODO: put that into compute_diff_bias() */
+ if (pd()->wants_padded_bias()) {
+ auto diff_bias = scratchpad(ctx).template get<const diff_weights_data_t>(
+ key_conv_padded_bias);
+ auto diff_bias_in = CTX_OUT_MEM(diff_weights_data_t *, MKLDNN_ARG_DIFF_BIAS);
+ for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
+ diff_bias_in[oc] = diff_bias[oc];
+ }
+}
+
+template struct jit_avx512_common_convolution_bwd_weights_t<data_type::f32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp
new file mode 100644
index 0000000000..3341c3ebe0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution.hpp
@@ -0,0 +1,302 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP
+#define CPU_JIT_AVX512_COMMON_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "cpu_barrier.hpp"
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+#include "cpu_reducer.hpp"
+
+#include "jit_transpose_src_utils.hpp"
+#include "jit_avx512_common_conv_kernel.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t src_type,
+ impl::data_type_t wei_type = src_type,
+ impl::data_type_t dst_type = src_type>
+struct jit_avx512_common_convolution_fwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_()
+ {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
+ jit_avx512_common_convolution_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(src_type, wei_type, dst_type, dst_type,
+ data_type::undef)
+ && !has_zero_dim_memory();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx512_common_conv_fwd_kernel::init_conf(
+ jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_,
+ *attr(), mkldnn_get_max_threads());
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx512_common_conv_fwd_kernel::init_scratchpad(scratchpad,
+ jcp_);
+
+ return status;
+ }
+
+ jit_conv_conf_t jcp_;
+ };
+
+ jit_avx512_common_convolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ {
+ kernel_ = new jit_avx512_common_conv_fwd_kernel(pd()->jcp_,
+ *pd()->attr());
+ }
+ ~jit_avx512_common_convolution_fwd_t() { delete kernel_; }
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<wei_type>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ if (pd()->ndims() == 3)
+ execute_forward_1d(ctx);
+ else if (pd()->ndims() == 4)
+ execute_forward_2d(ctx);
+ else if (pd()->ndims() == 5)
+ execute_forward_3d(ctx);
+ else
+ assert(false);
+
+ if (pd()->wants_zero_pad_dst())
+ ctx.memory(MKLDNN_ARG_DST)->zero_pad();
+
+ return status::success;
+ }
+
+private:
+ void prepare_padded_bias(const dst_data_t *&bias,
+ const memory_tracking::grantor_t &scratchpad) const;
+ void execute_forward_1d(const exec_ctx_t &ctx) const;
+ void execute_forward_2d(const exec_ctx_t &ctx) const;
+ void execute_forward_3d(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx512_common_conv_fwd_kernel *kernel_;
+};
+
+template <impl::data_type_t diff_dst_type,
+ impl::data_type_t wei_type = diff_dst_type,
+ impl::data_type_t diff_src_type = diff_dst_type>
+struct jit_avx512_common_convolution_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_data_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_()
+ {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
+ jit_avx512_common_convolution_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(diff_src_type, wei_type,
+ data_type::undef, diff_dst_type, data_type::undef)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ status_t status =
+ jit_avx512_common_conv_bwd_data_kernel_f32::init_conf(jcp_,
+ *desc(), *diff_src_md(), *weights_md(), *diff_dst_md());
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx512_common_conv_bwd_data_kernel_f32::init_scratchpad(
+ scratchpad, jcp_);
+
+ return status::success;
+ }
+
+ jit_conv_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
+ auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
+ OIw16o16i, gOIw16o16i, OIhw16o16i, gOIhw16o16i,
+ OIdhw16o16i, gOIdhw16o16i);
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ jit_avx512_common_convolution_bwd_data_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ { kernel_ = new jit_avx512_common_conv_bwd_data_kernel_f32(pd()->jcp_); }
+ ~jit_avx512_common_convolution_bwd_data_t() { delete kernel_; };
+
+ typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
+ typedef typename prec_traits<wei_type>::type wei_data_t;
+ typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ if (pd()->ndims() == 3)
+ execute_backward_data_1d(ctx);
+ else if (pd()->ndims() == 4)
+ execute_backward_data_2d(ctx);
+ else if (pd()->ndims() == 5)
+ execute_backward_data_3d(ctx);
+ else
+ assert(false);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data_1d(const exec_ctx_t &ctx) const;
+ void execute_backward_data_2d(const exec_ctx_t &ctx) const;
+ void execute_backward_data_3d(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx512_common_conv_bwd_data_kernel_f32 *kernel_;
+};
+
+template <impl::data_type_t src_type,
+ impl::data_type_t diff_dst_type = src_type,
+ impl::data_type_t diff_weights_type = src_type>
+struct jit_avx512_common_convolution_bwd_weights_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_weights_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
+ jit_avx512_common_convolution_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_weights
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(src_type, diff_weights_type,
+ diff_weights_type, diff_dst_type, data_type::undef)
+ && !has_zero_dim_memory();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx512_common_conv_bwd_weights_kernel_f32::
+ init_conf(jcp_, *desc(), src_md_, diff_weights_md_,
+ diff_bias_md_, diff_dst_md_);
+ if (status != status::success) return status;
+
+ init_balancers();
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx512_common_conv_bwd_weights_kernel_f32::init_scratchpad(
+ scratchpad, jcp_);
+
+ auto reducer_bia_scratchpad = memory_tracking::registrar_t(
+ scratchpad, memory_tracking::names::prefix_reducer_bia);
+ reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);
+
+ return status;
+ }
+
+ jit_conv_conf_t jcp_;
+ typename cpu_reducer_t<diff_weights_type>::conf_t reducer_bia_conf_;
+
+ private:
+ void init_balancers() {
+ const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
+ if (with_bias()) {
+ reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
+ jcp_.oc_block, jcp_.ngroups * jcp_.nb_oc, jcp_.mb,
+ max_buffer_size));
+ }
+ }
+ };
+
+ jit_avx512_common_convolution_bwd_weights_t(const pd_t *apd);
+ ~jit_avx512_common_convolution_bwd_weights_t() {
+ delete kernel_;
+ if (trans_kernel_)
+ delete trans_kernel_;
+ if (acc_ker_)
+ delete acc_ker_;
+ delete reducer_bias_;
+ }
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
+ typedef typename prec_traits<diff_weights_type>::type diff_weights_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_weights(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_weights(const exec_ctx_t &ctx) const;
+ void prepare_scratchpad_data(const exec_ctx_t &ctx) const;
+ struct thread_info_t;
+ void compute_diff_weights(const thread_info_t *) const;
+ void compute_diff_weights_3d(const thread_info_t *) const;
+ void reduce_diff_weights(const thread_info_t *) const;
+ void reduce_diff_weights_3d(const thread_info_t *) const;
+ void compute_diff_bias(const thread_info_t *) const;
+ void compute_diff_bias_3d(const thread_info_t *) const;
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ int nthr_, nthr_mb_, nthr_g_, nthr_oc_b_, nthr_ic_b_;
+
+ jit_avx512_common_conv_bwd_weights_kernel_f32 *kernel_;
+ jit_trans_src_t *trans_kernel_;
+ cpu_accumulator_1d_t<diff_weights_type> *acc_ker_;
+ cpu_reducer_t<diff_weights_type> *reducer_bias_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp
new file mode 100644
index 0000000000..62247c0264
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp
@@ -0,0 +1,1215 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifdef __INTEL_COMPILER
+#include <immintrin.h>
+#endif
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_avx512_common_convolution_winograd.hpp"
+
+#ifndef _MSC_VER
+#define pragma_unroll _Pragma("unroll")
+#else
+#define pragma_unroll
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace memory_tracking::names;
+
+namespace {
+
+unsigned int LLC_cache_size = get_cache_size(3, false);
+
+void inline load_ps(float *dest, const float *src_mem) {
+#ifdef __INTEL_COMPILER
+ __m512 *Iv512 = (__m512 *)dest;
+ Iv512[0] = _mm512_load_ps(src_mem);
+#else
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) dest[v] = src_mem[v];
+#endif
+}
+
+void inline store_output(float *dest, const float *data, bool streamout) {
+#ifdef __INTEL_COMPILER
+ if (streamout)
+ _mm512_stream_ps(dest, *((__m512 *)data));
+ else
+ _mm512_store_ps(dest, *((__m512 *)data));
+#else
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++)
+ dest[v] = data[v];
+#endif
+}
+
+void inline accum_output(
+ float *dest, float *data, bool streamout, bool with_relu_postsum) {
+#ifdef __INTEL_COMPILER
+ __m512 _data = _mm512_loadu_ps(data);
+ __m512 _dest = _mm512_loadu_ps(dest);
+ _data = _mm512_add_ps(_data, _dest);
+ if (with_relu_postsum)
+ _data = _mm512_max_ps(_data, _mm512_setzero_ps());
+ if (streamout)
+ _mm512_stream_ps(dest, _data);
+ else
+ _mm512_store_ps(dest, _data);
+#else
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++)
+ data[v] += dest[v];
+
+ if (with_relu_postsum) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++)
+ if (data[v] < 0.f)
+ data[v] = 0.f;
+ }
+
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++)
+ dest[v] = data[v];
+#endif
+}
+}
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+
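+// Winograd F(4x4, 3x3) filter transform: converts a 3x3 filter tile
+// (stored as 16x16 simd blocks) into the 6x6 transform domain.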
+void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]) {
+ float Fw[6][16];
+ float T[6][3][16];
+ float t0[16];
+ float t1[16];
+ float t2[16];
+
+ for (int j = 0; j < 16; j++) {
+#pragma unroll
+ for (int i = 0; i < 3; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int k = 0; k < 16; k++) {
+ t0[k] = 0.26890756302521f * F[2][i][j][k];
+ t1[k] = -t0[k] - 0.688403361344538f * F[0][i][j][k];
+ t2[k] = t0[k] + 0.119514472455649f * F[0][i][j][k];
+
+ T[0][i][k] = 1.13777777777778f * F[0][i][j][k];
+ T[1][i][k] = t1[k] - 0.430252100840336f * F[1][i][j][k];
+ T[2][i][k] = t1[k] + 0.430252100840336f * F[1][i][j][k];
+ T[3][i][k] = t2[k] + 0.179271708683473f * F[1][i][j][k];
+ T[4][i][k] = t2[k] - 0.179271708683473f * F[1][i][j][k];
+ T[5][i][k] = F[2][i][j][k];
+ }
+ }
+#pragma unroll
+ for (int i = 0; i < 6; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int k = 0; k < 16; k++) {
+ t0[k] = 0.26890756302521f * T[i][2][k];
+ t1[k] = -t0[k] - 0.688403361344538f * T[i][0][k];
+ t2[k] = t0[k] + 0.119514472455649f * T[i][0][k];
+
+ Fw[0][k] = 1.13777777777778f * T[i][0][k];
+ Fw[1][k] = t1[k] - 0.430252100840336f * T[i][1][k];
+ Fw[2][k] = t1[k] + 0.430252100840336f * T[i][1][k];
+ Fw[3][k] = t2[k] + 0.179271708683473f * T[i][1][k];
+ Fw[4][k] = t2[k] - 0.179271708683473f * T[i][1][k];
+ Fw[5][k] = T[i][2][k];
+#pragma unroll
+ for (int l = 0; l < 6; l++) {
+ Fw_[i][l][j][k] = Fw[l][k];
+ }
+ }
+ }
+ }
+}
+
+void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]) {
+ float T[4][6][16];
+ float t0[16];
+ float t1[16];
+ float t2[16];
+ float t3[16];
+
+#pragma unroll
+ for (int i = 0; i < 6; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < 16; v++) {
+ t0[v] = Mw[1][i][v] + Mw[2][i][v];
+ t1[v] = Mw[3][i][v] + Mw[4][i][v];
+ t2[v] = Mw[1][i][v] - Mw[2][i][v];
+ t3[v] = Mw[3][i][v] - Mw[4][i][v];
+
+ T[0][i][v] = t0[v] + t1[v] + Mw[0][i][v];
+ T[1][i][v] = t2[v] * 0.625f + t3[v] * 1.5f;
+ T[2][i][v] = t0[v] * 0.390625f + t1[v] * 2.25f;
+ T[3][i][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + Mw[5][i][v];
+ }
+ }
+#pragma unroll
+ for (int i = 0; i < 4; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < 16; v++) {
+ t0[v] = T[i][1][v] + T[i][2][v];
+ t1[v] = T[i][3][v] + T[i][4][v];
+ t2[v] = T[i][1][v] - T[i][2][v];
+ t3[v] = T[i][3][v] - T[i][4][v];
+
+ O[i][0][v] = t0[v] + t1[v] + T[i][0][v];
+ O[i][1][v] = t2[v] * 0.625f + t3[v] * 1.5f;
+ O[i][2][v] = t0[v] * 0.390625f + t1[v] * 2.25f;
+ O[i][3][v] = t2[v] * 0.244140625f + t3[v] * 3.375f + T[i][5][v];
+ }
+ }
+}
+
+
+void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16])
+{
+ const float rcp3 = 1.0f / 3.0f;
+ const float rcp4 = 1.0f / 4.0f;
+ const float rcp6 = 1.0f / 6.0f;
+ const float rcp12 = 1.0f / 12.0f;
+ const float rcp24 = 1.0f / 24.0f;
+ float t0[16];
+ float t1[16];
+ float t2[16];
+ float t3[16];
+ float t4[16];
+ float T[6][4][16];
+
+pragma_unroll
+ for (int i = 0; i < 4; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < 16; j++) {
+ t0[j] = F[2][i][j] * rcp6;
+ t1[j] = F[0][i][j] * -rcp6 - t0[j];
+ t2[j] = F[0][i][j] * rcp24 + t0[j];
+ t3[j] = (F[1][i][j] + F[3][i][j]) * rcp6;
+ t4[j] = F[1][i][j] * rcp12 + F[3][i][j] * rcp3;
+
+ T[0][i][j] = F[0][i][j] * rcp4;
+ T[1][i][j] = t1[j] - t3[j];
+ T[2][i][j] = t1[j] + t3[j];
+ T[3][i][j] = t2[j] + t4[j];
+ T[4][i][j] = t2[j] - t4[j];
+ T[5][i][j] = F[3][i][j];
+ }
+ }
+pragma_unroll
+ for (int i = 0; i < 6; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < 16; j++) {
+ t0[j] = T[i][2][j] * rcp6;
+ t1[j] = T[i][0][j] * -rcp6 - t0[j];
+ t2[j] = T[i][0][j] * rcp24 + t0[j];
+ t3[j] = (T[i][1][j] + T[i][3][j]) * rcp6;
+ t4[j] = T[i][1][j] * rcp12 + T[i][3][j] * rcp3;
+
+ Fw[i][0][j] = T[i][0][j] * rcp4;
+ Fw[i][1][j] = t1[j] - t3[j];
+ Fw[i][2][j] = t1[j] + t3[j];
+ Fw[i][3][j] = t2[j] + t4[j];
+ Fw[i][4][j] = t2[j] - t4[j];
+ Fw[i][5][j] = T[i][3][j];
+ }
+ }
+}
+
+void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16])
+{
+ float T[4][6][16];
+ float M_[3][16];
+ float t0[16];
+ float t1[16];
+ float t2[16];
+
+ for (int j = 0; j < 16; j++) {
+pragma_unroll
+ for (int i = 0; i < 6; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int l = 0; l < 16; l++) {
+ t0[l] = Mw[1][i][j][l] + Mw[2][i][j][l];
+ t1[l] = Mw[3][i][j][l] + Mw[4][i][j][l];
+ t2[l] = t1[l] * 4.0f + Mw[5][i][j][l];
+
+ T[0][i][l] = Mw[0][i][j][l] + t0[l] + t1[l];
+ T[1][i][l] = (Mw[1][i][j][l] - Mw[2][i][j][l]) +
+ 2.0f * (Mw[3][i][j][l] - Mw[4][i][j][l]);
+ T[2][i][l] = t0[l] + t2[l];
+ }
+ }
+pragma_unroll
+ for (int i = 0; i < 3; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int l = 0; l < 16; l++) {
+ t0[l] = T[i][1][l] + T[i][2][l];
+ t1[l] = T[i][3][l] + T[i][4][l];
+ t2[l] = t1[l] * 4.0f + T[i][5][l];
+
+ M_[0][l] = T[i][0][l] + t0[l] + t1[l];
+ M_[1][l] = (T[i][1][l] - T[i][2][l]) +
+ 2.0f * (T[i][3][l] - T[i][4][l]);
+ M_[2][l] = t0[l] + t2[l];
+
+ for (int k = 0; k < 3; k++) {
+ M[i][k][j][l] = M_[k][l];
+ }
+ }
+ }
+ }
+}
+
+void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16])
+{
+ float T[6][6][16];
+ float t0[16];
+ float t1[16];
+ float t2[16];
+ float t3[16];
+ float t4[16];
+ float t5[16];
+
+pragma_unroll
+ for (int i = 0; i < 6; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < 16; v++) {
+ t0[v] = I[2][i][v] * -2.25f + I[4][i][v];
+ t1[v] = I[1][i][v] * -2.25f + I[3][i][v];
+ t2[v] = I[2][i][v] * -0.390625f + I[4][i][v];
+ t3[v] = I[1][i][v] * -0.390625f + I[3][i][v];
+ t4[v] = I[0][i][v] * 0.87890625f + I[4][i][v];
+ t5[v] = I[1][i][v] * 0.87890625f + I[5][i][v];
+
+ T[0][i][v] = I[2][i][v] * -2.640625f + t4[v];
+ T[1][i][v] = t1[v] * 0.625f + t0[v];
+ T[2][i][v] = t1[v] * -0.625f + t0[v];
+ T[3][i][v] = t3[v] * 1.5f + t2[v];
+ T[4][i][v] = t3[v] * -1.5f + t2[v];
+ T[5][i][v] = I[3][i][v] * -2.640625f + t5[v];
+ }
+ }
+
+pragma_unroll
+ for (int i = 0; i < 6; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < 16; v++) {
+ t0[v] = T[i][2][v] * -2.25f + T[i][4][v];
+ t1[v] = T[i][1][v] * -2.25f + T[i][3][v];
+ t2[v] = T[i][2][v] * -0.390625f + T[i][4][v];
+ t3[v] = T[i][1][v] * -0.390625f + T[i][3][v];
+ t4[v] = T[i][0][v] * 0.87890625f + T[i][4][v];
+ t5[v] = T[i][1][v] * 0.87890625f + T[i][5][v];
+
+ Iw[i][0][v] = T[i][2][v] * -2.640625f + t4[v];
+ Iw[i][1][v] = t1[v] * 0.625f + t0[v];
+ Iw[i][2][v] = t1[v] * -0.625f + t0[v];
+ Iw[i][3][v] = t3[v] * 1.5f + t2[v];
+ Iw[i][4][v] = t3[v] * -1.5f + t2[v];
+ Iw[i][5][v] = T[i][3][v] * -2.640625f + t5[v];
+ }
+ }
+}
+
+void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16])
+{
+ float T[6][4][16];
+ float t0[16];
+ float t1[16];
+ float t2[16];
+ float t3[16];
+ float t4[16];
+
+pragma_unroll
+ for (int i = 0; i < 4; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < 16; v++) {
+ t0[v] = F[2][i][v] * 0.26890756302521f;
+ t1[v] = F[0][i][v] * -0.688403361344538f - t0[v];
+ t2[v] = F[0][i][v] * 0.119514472455649f + t0[v];
+ t3[v] = F[1][i][v] * 0.430252100840336f +
+ F[3][i][v] * 0.168067226890756f;
+ t4[v] = F[1][i][v] * 0.179271708683473f +
+ F[3][i][v] * 0.403361344537815f;
+
+ T[0][i][v] = F[0][i][v] * 1.13777777777778f;
+ T[1][i][v] = t1[v] - t3[v];
+ T[2][i][v] = t1[v] + t3[v];
+ T[3][i][v] = t2[v] + t4[v];
+ T[4][i][v] = t2[v] - t4[v];
+ T[5][i][v] = F[3][i][v];
+ }
+ }
+pragma_unroll
+ for (int i = 0; i < 6; i++) {
+ for (int v = 0; v < 16; v++) {
+ t0[v] = T[i][2][v] * 0.26890756302521f;
+ t1[v] = T[i][0][v] * -0.688403361344538f - t0[v];
+ t2[v] = T[i][0][v] * 0.119514472455649f + t0[v];
+ t3[v] = T[i][1][v] * 0.430252100840336f +
+ T[i][3][v] * 0.168067226890756f;
+ t4[v] = T[i][1][v] * 0.179271708683473f +
+ T[i][3][v] * 0.403361344537815f;
+
+ Fw[i][0][v] = T[i][0][v] * 1.13777777777778f;
+ Fw[i][1][v] = t1[v] - t3[v];
+ Fw[i][2][v] = t1[v] + t3[v];
+ Fw[i][3][v] = t2[v] + t4[v];
+ Fw[i][4][v] = t2[v] - t4[v];
+ Fw[i][5][v] = T[i][3][v];
+ }
+ }
+}
+
+void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16])
+{
+ float T[3][6][16];
+ float t0[16];
+ float t1[16];
+ float t2[16];
+ float M_[3][16];
+
+ for (int j = 0; j < 16; j++) {
+pragma_unroll
+ for (int i = 0; i < 6; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < 16; v++) {
+ t0[v] = Mw[1][i][j][v] + Mw[2][i][j][v];
+ t1[v] = Mw[3][i][j][v] + Mw[4][i][j][v];
+ t2[v] = t1[v] * 2.25f + Mw[5][i][j][v];
+
+ T[0][i][v] = Mw[0][i][j][v] + t0[v] + t1[v];
+ T[1][i][v] = 0.625f * (Mw[1][i][j][v] - Mw[2][i][j][v]) +
+ 1.5f * (Mw[3][i][j][v] - Mw[4][i][j][v]);
+ T[2][i][v] = t0[v] * 0.390625f + t2[v];
+ }
+ }
+pragma_unroll
+ for (int i = 0; i < 3; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < 16; v++) {
+ t0[v] = T[i][1][v] + T[i][2][v];
+ t1[v] = T[i][3][v] + T[i][4][v];
+ t2[v] = t1[v] * 2.25f + T[i][5][v];
+
+ M_[0][v] = T[i][0][v] + t0[v] + t1[v];
+ M_[1][v] = 0.625f * (T[i][1][v] - T[i][2][v]) +
+ 1.5f * (T[i][3][v] - T[i][4][v]);
+ M_[2][v] = t0[v] * 0.390625f + t2[v];
+ }
+
+pragma_unroll
+ for (int k = 0; k < 3; k++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < 16; v++) {
+ M[i][k][j][v] = M_[k][v];
+ }
+ }
+ }
+ }
+}
+
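+/* Gathers one alpha x alpha source tile per output tile, zero-filling the
+ * positions that fall into the padding area, applies trans_I_4x4_3x3 and
+ * scatters the result into the transformed-input scratch buffer using the
+ * (tile_block, nb_tile_block_ur, tile_block_ur) tile indexing. */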
+template <bool is_fwd>
+void input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
+ float *inp, float *tinp, bool streamout = true)
+{
+ const int inpw = is_fwd ? jcp.iw : jcp.ow;
+ const int inph = is_fwd ? jcp.ih : jcp.oh;
+ const int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
+ const int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
+ const int wp_max = inpw + l_pad;
+ const int hp_max = inph + t_pad;
+ float Iw[alpha][alpha][simd_w];
+ float I[alpha][alpha][simd_w];
+
+ array_offset_calculator<float, 5> input(inp,
+ jcp.mb, jcp.dimK/simd_w, inph, inpw,
+ simd_w);
+ array_offset_calculator<float, 8> output(tinp,
+ jcp.dimN_nb_block, alpha, alpha,
+ jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
+ jcp.dimN_reg_block, jcp.dimK_reg_block);
+
+ int tile_base_index = image * jcp.itiles * jcp.jtiles;
+ int tile_block_ur = tile_base_index % jcp.tile_block_ur;
+ int nb_tile_block_ur =
+ (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
+ int tile_block =
+ (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
+
+ for (int tj = 0; tj < jcp.jtiles; tj++) {
+ for (int ti = 0; ti < jcp.itiles; ti++) {
+ for (int j = 0; j < alpha; j++) {
+ int ydim = tj * tile_size + j;
+ if ((t_pad <= ydim) && (ydim < hp_max)) {
+ float *pinp_j = inp + (ydim - t_pad) * inpw * 16;
+ for (int i = 0; i < alpha; i++) {
+ int xdim = ti * tile_size + i;
+ if ((l_pad <= xdim) && (xdim < wp_max)) {
+ float *pinp_i = pinp_j + (xdim - l_pad) * 16;
+ load_ps(I[j][i], pinp_i);
+ } else {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ I[j][i][v] = 0.0f;
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < alpha; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ I[j][i][v] = 0.0f;
+ }
+ }
+ }
+ }
+
+ trans_I_4x4_3x3(Iw, I);
+
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ store_output(&(output(tile_block, j, i,
+ nb_tile_block_ur, 0, 0,
+ tile_block_ur, 0)),
+ Iw[j][i], streamout);
+ }
+ }
+ tile_block_ur++;
+ if (tile_block_ur >= jcp.tile_block_ur) {
+ tile_block_ur = 0;
+ nb_tile_block_ur++;
+ }
+ if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
+ nb_tile_block_ur = 0;
+ tile_block++;
+ }
+ }
+ }
+}
+
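+/* Transforms one 16x16 (ic_simd_block x oc_simd_block) block of 3x3 weights
+ * into the alpha x alpha Winograd domain. For the backward-data pass the
+ * kernel is read rotated by 180 degrees (indices 2-j, 2-i) and with the
+ * ic/oc lanes swapped. */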
+template <bool is_fwd>
+void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
+ float *wp, float *twp)
+{
+ const int kh = 3;
+ const int kw = 3;
+ array_offset_calculator<float, 6> input(wp,
+ jcp.oc/jcp.oc_simd_block,
+ jcp.ic/jcp.ic_simd_block,
+ jcp.kh, jcp.kw,
+ simd_w, simd_w);
+ array_offset_calculator<float, 8> output(twp,
+ jcp.dimM_nb_block,
+ alpha, alpha,
+ jcp.dimK_nb_block,
+ jcp.dimM_block, jcp.dimK_block,
+ simd_w, simd_w);
+ float Fw[alpha][alpha][simd_w][simd_w];
+ float F[kh][kw][simd_w][simd_w];
+
+ for (int j = 0; j < kh; j++) {
+ for (int i = 0; i < kw; i++) {
+ for (int v1 = 0; v1 < simd_w; v1++) {
+ float *base_inp = is_fwd
+ ? &(input(0, 0, j, i, v1, 0))
+ : &(input(0, 0, 2 - j, 2 - i, v1, 0));
+ PRAGMA_OMP_SIMD()
+ for (int v2 = 0; v2 < simd_w; v2++) {
+ if (is_fwd)
+ F[j][i][v1][v2] = *(base_inp + v2);
+ else
+ F[j][i][v2][v1] = *(base_inp + v2);
+ }
+ }
+ }
+ }
+
+ trans_W_4x4_3x3(Fw, F);
+
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ for (int v1 = 0; v1 < simd_w; v1++) {
+ PRAGMA_OMP_SIMD()
+ for (int v2 = 0; v2 < simd_w; v2++) {
+ output(0, j, i, 0, 0, 0, v1, v2) = Fw[j][i][v1][v2];
+ }
+ }
+ }
+ }
+}
+
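+/* Collects the alpha x alpha GEMM results for one tile, applies
+ * trans_O_4x4_3x3 back to the tile_size x tile_size spatial domain and
+ * stores it, optionally adding bias, scaling negative values by
+ * eltwise.alpha (pre-sum relu) and accumulating into the destination when a
+ * sum post-op is present. */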
+template <bool is_fwd, bool with_bias, bool with_relu_presum, bool with_sum>
+void output_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
+ const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias,
+ bool streamout = true) {
+ float Ow[alpha][alpha][simd_w];
+ float O[tile_size][tile_size][simd_w];
+ int outw = is_fwd ? jcp.ow : jcp.iw;
+ int outh = is_fwd ? jcp.oh : jcp.ih;
+
+ /* Prepare for PostOps */
+ bool with_relu_postsum = p_ops.find(primitive_kind::eltwise, 1) != -1;
+
+ array_offset_calculator<float, 8> input(toutp,
+ jcp.dimN_nb_block, jcp.dimM_nb_block,
+ alpha, alpha,
+ jcp.dimN_block, jcp.dimM_block,
+ jcp.dimN_reg_block, jcp.dimM_simd_block);
+
+ int tile_base_index = image * jcp.itiles * jcp.jtiles;
+ int tile_block_ur = tile_base_index % jcp.tile_block_ur;
+ int nb_tile_block_ur =
+ (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
+ int tile_block =
+ (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
+
+ for (int tj = 0; tj < jcp.jtiles; tj++) {
+ for (int ti = 0; ti < jcp.itiles; ti++) {
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ Ow[j][i][v] = input(tile_block, 0,
+ j, i,
+ nb_tile_block_ur, 0,
+ tile_block_ur, v);
+ }
+ }
+ }
+
+ trans_O_4x4_3x3(Ow, O);
+
+ for (int j = 0; j < tile_size; j++) {
+ int ydim = tj * tile_size + j;
+ if (ydim < outh) {
+ float *pout_j = pout_b + ydim * outw * simd_w;
+ for (int i = 0; i < tile_size; i++) {
+ int xdim = ti * tile_size + i;
+ if (xdim < outw) {
+ float *pout_i = pout_j + xdim * simd_w;
+ if (is_fwd) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ O[j][i][v] += with_bias ? bias[v] : 0.f;
+ O[j][i][v] = true
+ && with_relu_presum && O[j][i][v] < 0.f
+ ? O[j][i][v]
+ * jcp.eltwise.alpha
+ : O[j][i][v];
+ }
+ }
+ if (with_sum)
+ accum_output(pout_i, O[j][i], streamout,
+ with_relu_postsum);
+ else
+ store_output(pout_i, O[j][i], streamout);
+ }
+ }
+ }
+ }
+ tile_block_ur++;
+ if (tile_block_ur >= jcp.tile_block_ur) {
+ tile_block_ur = 0;
+ nb_tile_block_ur++;
+ }
+ if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
+ nb_tile_block_ur = 0;
+ tile_block++;
+ }
+ }
+ }
+}
+
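+/* Source-side transform for backward weights: gathers padded alpha x alpha
+ * tiles and applies trans_I_4x4_3x3. With the 4fma kernel, groups of
+ * tile_4fma transformed tiles are staged in Iw_temp and repacked through
+ * transpose_4fma_ker before being stored. */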
+template <bool ver_4fma>
+void diff_src_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
+ float *inp, float *tinp, float *Iw_temp,
+ void (*transpose_4fma_ker)(float *, float *))
+{
+
+ const int ifwp = conv.iw + conv.l_pad;
+ const int ifhp = conv.ih + conv.t_pad;
+ float I[alpha][alpha][simd_w];
+ float Iw[alpha][alpha][simd_w];
+
+ array_offset_calculator<float, 4> Iw_trans_temp(Iw_temp,
+ alpha, alpha, conv.tile_4fma, simd_w);
+ array_offset_calculator<float, 5> input(inp,
+ conv.mb, conv.ic/simd_w, conv.ih, conv.iw, simd_w);
+ array_offset_calculator<float, 8> output(tinp,
+ conv.nb_ic, alpha, alpha,
+ conv.tile_block, conv.ic_block,
+ conv.nb_tile_block_ur, conv.tile_block_ur,
+ conv.ic_simd_block * conv.tile_4fma);
+
+ int tile_base_index =
+ image * (conv.itiles * conv.jtiles + conv.tile_4fma_padding);
+ int tile_4fma = 0;
+ int tile_block_ur = (tile_base_index / conv.tile_4fma) % conv.tile_block_ur;
+ int nb_tile_block_ur =
+ (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
+ % conv.nb_tile_block_ur;
+ int tile_block = (tile_base_index / conv.tile_4fma / conv.tile_block_ur)
+ / conv.nb_tile_block_ur;
+
+ for (int tj = 0; tj < conv.jtiles; tj++) {
+ for (int ti = 0; ti < conv.itiles; ti++) {
+ for (int j = 0; j < alpha; j++) {
+ int ydim = tj * tile_size + j;
+ if ((conv.t_pad <= ydim) && ydim < ifhp) {
+ for (int i = 0; i < alpha; i++) {
+ int xdim = ti * tile_size + i;
+ if ((conv.l_pad <= xdim) && xdim < ifwp) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ I[j][i][v] = input(0, 0,
+ ydim - conv.t_pad,
+ xdim - conv.l_pad, v);
+ }
+ } else {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ I[j][i][v] = 0.0f;
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < alpha; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ I[j][i][v] = 0.0f;
+ }
+ }
+ }
+ }
+ trans_I_4x4_3x3(Iw, I);
+
+ if (ver_4fma) {
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ float *Iw_temp_base = &(Iw_trans_temp(j, i,
+ tile_4fma, 0));
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ Iw_temp_base[v] = Iw[j][i][v];
+ }
+ }
+ }
+ tile_4fma++;
+ if (tile_4fma == conv.tile_4fma) {
+ float *outp = &(output(0, 0, 0,
+ tile_block, 0,
+ nb_tile_block_ur, tile_block_ur, 0));
+ transpose_4fma_ker(outp, (float *)Iw_temp);
+ tile_4fma = 0;
+ tile_block_ur++;
+ }
+ } else {
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ store_output(&(output(0, j, i,
+ tile_block, 0,
+ nb_tile_block_ur, tile_block_ur, 0)),
+ Iw[j][i], true);
+ }
+ }
+ tile_block_ur++;
+ }
+
+ if (tile_block_ur == conv.tile_block_ur) {
+ tile_block_ur = 0;
+ ++nb_tile_block_ur;
+ }
+ if (nb_tile_block_ur == conv.nb_tile_block_ur) {
+ nb_tile_block_ur = 0;
+ tile_block++;
+ }
+ }
+ }
+
+ if (ver_4fma && tile_4fma < conv.tile_4fma && conv.tile_4fma_padding != 0) {
+
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ for (int tb = tile_4fma; tb < conv.tile_4fma; tb++) {
+ float *Iw_temp_base = &(Iw_trans_temp(j, i, tb, 0));
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ Iw_temp_base[v] = 0;
+ }
+ }
+ }
+ }
+ float *outp = &(output(0, 0, 0,
+ tile_block, 0,
+ nb_tile_block_ur, tile_block_ur, 0));
+ transpose_4fma_ker(outp, (float *)Iw_temp);
+ }
+}
+
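+/* diff_dst-side transform for backward weights. When with_bias is set, the
+ * bias gradient is accumulated per SIMD lane from the untransformed values
+ * of the tile interior before trans_W_3x3_4x4_wu is applied. */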
+template <bool with_bias>
+void diff_dst_transform_bwd_weights(int image, jit_conv_winograd_conf_t conv,
+ float *inp, float *tinp, float *dbias)
+{
+
+ const int total_tiles = conv.itiles * conv.jtiles + conv.tile_4fma_padding;
+ float I[alpha][alpha][simd_w];
+ float Iw[alpha][alpha][simd_w];
+
+ array_offset_calculator<float, 5> input(inp,
+ conv.mb, conv.oc/simd_w, conv.oh, conv.ow, conv.oc_simd_block);
+ array_offset_calculator<float, 8> output(tinp,
+ conv.nb_oc, alpha, alpha,
+ conv.tile_block, conv.oc_block,
+ conv.nb_tile_block_ur,
+ conv.tile_block_ur * conv.tile_4fma, conv.oc_simd_block);
+
+ int tile_base_index = image * total_tiles;
+ int tile_block_ur = tile_base_index % (conv.tile_block_ur * conv.tile_4fma);
+ int nb_tile_block_ur =
+ (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
+ % conv.nb_tile_block_ur;
+ int tile_block = (tile_base_index / conv.tile_block_ur / conv.tile_4fma)
+ / conv.nb_tile_block_ur;
+
+ for (int tj = 0; tj < conv.jtiles; tj++) {
+ for (int ti = 0; ti < conv.itiles; ti++) {
+ for (int j = 0; j < alpha; j++) {
+ int ydim = tj * tile_size + j;
+ if (ydim < conv.oh) {
+ for (int i = 0; i < alpha; i++) {
+ int xdim = ti * tile_size + i;
+ if (xdim < conv.ow) {
+ float *input_base = &(input(0, 0, ydim, xdim, 0));
+
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ I[j][i][v] = input_base[v];
+ }
+ if (with_bias && j < tile_size && i < tile_size) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ dbias[v] += input_base[v];
+ }
+ }
+ } else {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ I[j][i][v] = 0.0f;
+ }
+ }
+ }
+ } else {
+ for (int i = 0; i < alpha; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ I[j][i][v] = 0.0f;
+ }
+ }
+ }
+ }
+
+ trans_W_3x3_4x4_wu(Iw, I);
+
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ store_output(&(output(0, j, i,
+ tile_block, 0,
+ nb_tile_block_ur,
+ tile_block_ur, 0)),
+ Iw[j][i], true);
+ }
+ }
+ tile_block_ur++;
+ if (tile_block_ur >= conv.tile_block_ur * conv.tile_4fma) {
+ tile_block_ur = 0;
+ nb_tile_block_ur++;
+ }
+ if (nb_tile_block_ur >= conv.nb_tile_block_ur) {
+ nb_tile_block_ur = 0;
+ tile_block++;
+ }
+ }
+ }
+}
+
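+/* Transforms one accumulated 16x16 weight block from the alpha x alpha
+ * Winograd domain back to the 3x3 spatial domain (trans_O_3x3_4x4_wu) and
+ * stores it into the diff_weights tensor. */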
+void diff_weights_transform_bwd_weights(jit_conv_winograd_conf_t conv,
+ float *wp, float *twp)
+{
+ const int kh = 3;
+ const int kw = 3;
+ float Fw[alpha][alpha][simd_w][simd_w];
+ float F[kh][kw][simd_w][simd_w];
+
+ array_offset_calculator<float, 8> input(twp,
+ conv.nb_ic, conv.nb_oc,
+ alpha, alpha,
+ conv.oc_block, conv.ic_block,
+ conv.ic_simd_block, conv.oc_simd_block);
+ array_offset_calculator<float, 6> output(wp,
+ conv.oc/simd_w, conv.ic/simd_w,
+ conv.kh, conv.kw,
+ conv.ic_simd_block, conv.oc_simd_block);
+
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ for (int v = 0; v < conv.ic_simd_block; v++) {
+ PRAGMA_OMP_SIMD()
+ for (int k = 0; k < conv.oc_simd_block; k++) {
+ Fw[j][i][v][k] = input(0, 0, j, i, 0, 0, v, k);
+ }
+ }
+ }
+ }
+
+ trans_O_3x3_4x4_wu(Fw, F);
+
+ for (int j = 0; j < kh; j++) {
+ for (int i = 0; i < kw; i++) {
+ for (int v = 0; v < conv.ic_simd_block; v++) {
+ store_output(&(output(0, 0, j, i, v, 0)),
+ F[j][i][v], true);
+ }
+ }
+ }
+}
+
+template <bool is_fwd>
+void _jit_avx512_common_convolution_winograd_t<is_fwd>::_execute_data_W_S_G_D(
+ float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const auto &jcp = kernel_->jcp;
+ const auto &p_ops = attr_->post_ops_;
+
+ const int inph = is_fwd ? jcp.ih : jcp.oh;
+ const int inpw = is_fwd ? jcp.iw : jcp.ow;
+ const int outh = is_fwd ? jcp.oh : jcp.ih;
+ const int outw = is_fwd ? jcp.ow : jcp.iw;
+
+ /* Note that jcp.with_eltwise is true both for the fused conv+relu primitive
+ * and for a conv primitive whose post-ops contain a relu before the sum
+ * (a relu post-op placed after the sum is handled later). */
+ auto output_transform = jcp.with_bias
+ ? (jcp.with_eltwise
+ ? (jcp.with_sum
+ ? output_transform_data<is_fwd, true, true, true>
+ : output_transform_data<is_fwd, true, true, false>)
+ : (jcp.with_sum
+ ? output_transform_data<is_fwd, true, false, true>
+ : output_transform_data<is_fwd, true, false, false>))
+ : (jcp.with_eltwise
+ ? (jcp.with_sum
+ ? output_transform_data<is_fwd, false, true, true>
+ : output_transform_data<is_fwd, false, true, false>)
+ : (jcp.with_sum
+ ? output_transform_data<is_fwd, false, false, true>
+ : output_transform_data<is_fwd, false, false, false>));
+
+ /* Notation:
+ FWD: dimM:oc, dimN:ntiles, dimK:ic,
+ BWD: dimM:ic, dimN:ntiles, dimK:oc,
+ FWD/BWD: V: src/diff_dst transform, U:weight transform,
+ M:dst/diff_src transform */
+ array_offset_calculator<float, 5> input(inp_ptr,
+ jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
+ jcp.dimK_reg_block);
+ array_offset_calculator<float, 5> output(out_ptr,
+ jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw,
+ jcp.dimM_simd_block);
+ array_offset_calculator<float, 6> weights(wei_ptr,
+ jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
+ jcp.ic_simd_block, jcp.oc_simd_block);
+ array_offset_calculator<float, 2> bias(bias_ptr,
+ jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
+
+ array_offset_calculator<float, 8> M(is_fwd
+ ? scratchpad.template get<float>(key_wino_M)
+ : scratchpad.template get<float>(key_wino_V),
+ jcp.dimN_nb_block, jcp.dimM_nb_block,
+ alpha, alpha,
+ jcp.dimN_block, jcp.dimM_block,
+ jcp.dimN_reg_block, jcp.dimM_simd_block);
+ array_offset_calculator<float, 8> U(
+ scratchpad.template get<float>(key_wino_U),
+ jcp.dimM_nb_block,
+ alpha, alpha,
+ jcp.dimK_nb_block,
+ jcp.dimM_block, jcp.dimK_block,
+ jcp.dimK_reg_block, jcp.dimM_simd_block);
+ array_offset_calculator<float, 8> V(is_fwd
+ ? scratchpad.template get<float>(key_wino_V)
+ : scratchpad.template get<float>(key_wino_M),
+ jcp.dimN_nb_block, alpha, alpha,
+ jcp.dimN_block, jcp.dimK_nb_block,
+ jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
+
+ bool V_streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
+ > 2 * LLC_cache_size ? true : false;
+
+ const bool output_is_aligned = ((size_t)out_ptr & (64 - 1)) == 0;
+
+ const bool wants_padded_bias = jcp.with_bias
+ && jcp.oc_without_padding != jcp.oc;
+ float last_slice_bias[simd_w] = {0};
+ if (wants_padded_bias) {
+ for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
+ last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
+ }
+
+ {
+ parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block,
+ [&](int img, int K_blk1, int K_blk2) {
+ input_transform_data<is_fwd>(img, jcp,
+ &(input(img, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
+ &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)), V_streamout);
+ });
+
+ parallel_nd(jcp.nb_oc, jcp.nb_ic, jcp.oc_block, jcp.ic_block,
+ [&](int ofm1, int ifm1, int ofm2, int ifm2) {
+ float *U_base_ptr = is_fwd
+ ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
+ : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
+ weight_transform_data<is_fwd>(jcp,
+ &(weights(ofm1 * jcp.oc_block + ofm2,
+ ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)), U_base_ptr);
+ });
+
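+ /* GEMM stage: for every tile block and transform coordinate (oj, oi),
+ * accumulate M = U x V over all dimK blocks; the first K block uses the
+ * non-accumulating kernel to initialize M. */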
+ parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block, jcp.dimN_block,
+ [&](int N_blk1, int oj, int oi, int M_blk1, int N_blk2) {
+
+ kernel_->gemm_loop_ker_first_iter(
+ (float *)&(M(N_blk1, M_blk1, oj, oi,
+ N_blk2, 0, 0, 0)),
+ (const float *)&(U(M_blk1, oj, oi,
+ 0, 0, 0, 0, 0)),
+ (const float *)&(V(N_blk1, oj, oi,
+ N_blk2, 0, 0, 0, 0)));
+ for (int K_blk1 = 1; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
+ kernel_->gemm_loop_ker(
+ (float *)&(M(N_blk1, M_blk1, oj, oi,
+ N_blk2, 0, 0, 0)),
+ (const float *)&(U(M_blk1, oj, oi,
+ K_blk1, 0, 0, 0, 0)),
+ (const float *)&(V(N_blk1, oj, oi,
+ N_blk2, K_blk1,
+ 0, 0, 0)));
+ }
+
+ });
+
+ parallel_nd(jcp.mb, jcp.dimM_nb_block, jcp.dimM_block,
+ [&](int img, int M_blk1, int M_blk2) {
+
+ const int M_blk = M_blk1 * jcp.dimM_block + M_blk2;
+
+ float *bias_ptr = wants_padded_bias
+ && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
+ ? last_slice_bias : &bias(M_blk, 0);
+
+ output_transform(img, jcp, p_ops,
+ &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
+ &(output(img, M_blk, 0, 0, 0)),
+ bias_ptr, output_is_aligned);
+
+ });
+
+ }
+}
+
+template struct _jit_avx512_common_convolution_winograd_t<true>;
+template struct _jit_avx512_common_convolution_winograd_t<false>;
+
+void jit_avx512_common_convolution_winograd_bwd_weights_t::
+_maybe_execute_diff_bias_copy(float *diff_bias,
+ const memory_tracking::grantor_t &scratchpad) const {
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+ for (int oc = 0; oc < pd()->jcp_.oc_without_padding; ++oc)
+ diff_bias[oc] = padded_bias[oc];
+ }
+}
+
+void jit_avx512_common_convolution_winograd_bwd_weights_t::
+_execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx,
+ const memory_tracking::grantor_t &scratchpad) const {
+ auto ptr_diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
+ auto ptr_src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
+ auto ptr_diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto ptr_diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);
+
+ const auto &jcp = kernel_->jcp;
+ const int nthreads = jcp.nthr;
+
+ auto diff_src_transform_bwd_weights_ver = jcp.ver == ver_4fma ?
+ diff_src_transform_bwd_weights<true> :
+ diff_src_transform_bwd_weights<false>;
+ auto diff_dst_transform_bwd_weights_ver = jcp.with_bias
+ ? diff_dst_transform_bwd_weights<true>
+ : diff_dst_transform_bwd_weights<false>;
+
+ array_offset_calculator<float, 5> src((float *)ptr_src,
+ jcp.mb, jcp.ic/simd_w, jcp.ih, jcp.iw, simd_w);
+ array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst,
+ jcp.mb, jcp.oc/simd_w, jcp.oh, jcp.ow, simd_w);
+ array_offset_calculator<float, 6> diff_weights(ptr_diff_weights,
+ jcp.oc/simd_w, jcp.ic/simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
+ array_offset_calculator<float, 2> diff_bias(pd()->wants_padded_bias()
+ ? scratchpad.get<float>(key_conv_padded_bias) : ptr_diff_bias,
+ jcp.oc/simd_w, simd_w);
+
+ array_offset_calculator<float, 8> U(
+ scratchpad.get<float>(key_wino_U),
+ jcp.nb_ic, jcp.nb_oc,
+ alpha, alpha,
+ jcp.oc_block, jcp.ic_block,
+ jcp.ic_simd_block, jcp.oc_simd_block);
+
+ array_offset_calculator<float, 8> M(
+ scratchpad.get<float>(key_wino_M),
+ jcp.nb_oc, alpha, alpha,
+ jcp.tile_block, jcp.oc_block,
+ jcp.nb_tile_block_ur, jcp.tile_block_ur * jcp.tile_4fma,
+ jcp.oc_simd_block);
+ array_offset_calculator<float, 8> V(
+ scratchpad.get<float>(key_wino_V),
+ jcp.nb_ic, alpha, alpha,
+ jcp.tile_block, jcp.ic_block,
+ jcp.nb_tile_block_ur, jcp.tile_block_ur,
+ jcp.ic_simd_block * jcp.tile_4fma);
+
+ const int trans_buffer_size = alpha * alpha * jcp.tile_4fma
+ * jcp.ic_simd_block;
+ array_offset_calculator<float, 2> trans_buffer(
+ scratchpad.get<float>(key_conv_tr_src),
+ nthreads,
+ trans_buffer_size);
+
+ array_offset_calculator<float, 2> diff_bias_prv(
+ scratchpad.get<float>(key_conv_bia_reduction),
+ nthreads,
+ jcp.oc);
+
+PRAGMA_OMP(parallel num_threads(nthreads))
+ {
+ if (jcp.with_bias) {
+ parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
+ diff_bias_prv(ithr, ofm) = 0.0f;
+ });
+
+PRAGMA_OMP(for nowait)
+ for (int bofm = 0; bofm < jcp.oc / simd_w; bofm++) {
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++)
+ diff_bias(bofm, v) = 0.0f;
+ }
+ }
+
+ const int ithread = mkldnn_get_thread_num();
+
+ parallel_nd_in_omp(jcp.mb, jcp.nb_ic, jcp.ic_block,
+ [&](int img, int ifm1, int ifm2) {
+ float *transb = jcp.ver == ver_4fma
+ ? &(trans_buffer(ithread, 0))
+ : NULL;
+ diff_src_transform_bwd_weights_ver(img, jcp,
+ &(src(img, ifm1 * jcp.ic_block + ifm2,
+ 0, 0, 0)),
+ &(V(ifm1, 0, 0, 0, ifm2, 0, 0, 0)),
+ transb,
+ kernel_->transpose_4fma_ker);
+ });
+
+ parallel_nd_in_omp(jcp.mb, jcp.nb_oc, jcp.oc_block,
+ [&](int img, int ofm1, int ofm2) {
+ float *dbias = jcp.with_bias
+ ? &(diff_bias_prv(ithread,
+ simd_w * (ofm1 * jcp.oc_block + ofm2)))
+ : NULL;
+ diff_dst_transform_bwd_weights_ver(img, jcp,
+ &(diff_dst(img, ofm1 * jcp.oc_block + ofm2,
+ 0, 0, 0)),
+ &(M(ofm1, 0, 0, 0, ofm2, 0, 0, 0)),
+ dbias);
+ });
+
+PRAGMA_OMP(barrier)
+
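+ /* GEMM stage: U = M x V accumulated over all tile blocks for each
+ * (ic block, oc block, oj, oi); the first tile block initializes U via
+ * gemm_loop_ker_first_iter. */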
+ for (int ifm1 = 0; ifm1 < jcp.nb_ic; ifm1++) {
+ parallel_nd_in_omp(alpha, alpha, jcp.nb_oc,
+ [&](int oj, int oi, int ofm1) {
+ kernel_->gemm_loop_ker_first_iter(
+ (float *)&(U(ifm1, ofm1, oj, oi,
+ 0, 0, 0, 0)),
+ (const float *)&(M(ofm1, oj, oi,
+ 0, 0, 0, 0, 0)),
+ (const float *)&(V(ifm1, oj, oi,
+ 0, 0, 0, 0, 0)));
+ for (int tile_block = 1; tile_block < jcp.tile_block;
+ tile_block++) {
+ kernel_->gemm_loop_ker((float *)&(U(ifm1, ofm1,
+ oj, oi,
+ 0, 0, 0, 0)),
+ (const float *)&(M(ofm1, oj, oi, tile_block,
+ 0, 0, 0, 0)),
+ (const float *)&(V(ifm1, oj, oi, tile_block,
+ 0, 0, 0, 0)));
+ }
+ });
+ }
+
+PRAGMA_OMP(barrier)
+
+ parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
+ [&](int ifm1, int ofm1, int ofm2, int ifm2) {
+ diff_weights_transform_bwd_weights(jcp,
+ &(diff_weights(ofm1 * jcp.oc_block + ofm2,
+ ifm1 * jcp.ic_block + ifm2, 0, 0, 0, 0)),
+ &(U(ifm1, ofm1, 0, 0, ofm2, ifm2, 0, 0)));
+ });
+
+ if (jcp.with_bias) {
+PRAGMA_OMP(for)
+ for (int ofm1 = 0; ofm1 < jcp.oc / simd_w; ofm1++) {
+ for (int ithr = 0; ithr < nthreads; ithr++) {
+ float* base_bias_ptr = &(diff_bias(ofm1, 0));
+ float* base_bias_prv_ptr = &(diff_bias_prv(
+ ithr * jcp.oc + ofm1 * simd_w));
+ PRAGMA_OMP_SIMD()
+ for (int ofm2 = 0; ofm2 < simd_w; ofm2++) {
+ base_bias_ptr[ofm2] += base_bias_prv_ptr[ofm2];
+ }
+ }
+ }
+ }
+ }
+
+ _maybe_execute_diff_bias_copy(ptr_diff_bias, scratchpad);
+}
+
+}
+}
+}
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp
new file mode 100644
index 0000000000..6c76f37c72
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.hpp
@@ -0,0 +1,318 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP
+#define CPU_JIT_AVX512_COMMON_CONVOLUTION_WINOGRAD_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace winograd_avx512_common {
+inline void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_winograd_conf_t &jcp) {
+ using namespace memory_tracking::names;
+
+ size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc;
+ size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic
+ * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
+ size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc
+ * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
+
+ scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M);
+ scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M);
+ scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M);
+
+ if (jcp.sched_policy == WSCHED_WEI_S_D_G_W) {
+ const int nthr = mkldnn_get_max_threads();
+
+ size_t tr_src_sz = jcp.ver != ver_4fma ? 0 : (size_t)nthr
+ * alpha * alpha * jcp.tile_4fma * jcp.ic_simd_block;
+ scratchpad.book(key_conv_tr_src, sizeof(float) * tr_src_sz, PAGE_2M);
+
+ size_t br_sz = jcp.with_bias ? nthr * jcp.oc : 0;
+ scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M);
+
+ size_t padded_bias_sz =
+ jcp.with_bias && jcp.oc_without_padding != jcp.oc ? jcp.oc : 0;
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * padded_bias_sz);
+ }
+}
+}
+
+template <bool is_fwd>
+struct _jit_avx512_common_convolution_winograd_t {
+ _jit_avx512_common_convolution_winograd_t(
+ const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr)
+ : kernel_(nullptr), attr_(attr) {
+ kernel_ = new _jit_avx512_common_conv_winograd_data_kernel_f32(jcp);
+ }
+
+ ~_jit_avx512_common_convolution_winograd_t() { delete kernel_; }
+
+ protected:
+ void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr,
+ float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const;
+ _jit_avx512_common_conv_winograd_data_kernel_f32 *kernel_;
+ const primitive_attr_t *attr_;
+};
+
+struct jit_avx512_common_convolution_winograd_fwd_t
+ : _jit_avx512_common_convolution_winograd_t<true>
+ , public cpu_primitive_t
+ {
+ struct pd_t : public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
+ jit_avx512_common_convolution_winograd_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx512_common_conv_winograd_fwd_kernel_f32::
+ init_conf(jcp_, *desc(), *src_md(), *weights_md(), *dst_md(),
+ *attr());
+ if (status != status::success) return status;
+ set_default_alg_kind(alg_kind::convolution_winograd);
+
+ auto scratchpad = scratchpad_registry().registrar();
+ winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
+
+ return status;
+ }
+
+ jit_conv_winograd_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o;
+ return set_default_formats_common(nChw16c, wei_tag, nChw16c);
+ }
+ };
+
+ jit_avx512_common_convolution_winograd_fwd_t(const pd_t *apd)
+ : _jit_avx512_common_convolution_winograd_t<true>(apd->jcp_, apd->attr())
+ , cpu_primitive_t(apd, true) {}
+
+ ~jit_avx512_common_convolution_winograd_fwd_t(){};
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override
+ {
+ auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST);
+ this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights,
+ (float *)bias, this->scratchpad(ctx));
+ return status::success;
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+struct jit_avx512_common_convolution_winograd_bwd_data_t
+ : _jit_avx512_common_convolution_winograd_t<false>,
+ public cpu_primitive_t {
+ struct pd_t : public cpu_convolution_bwd_data_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
+ jit_avx512_common_convolution_winograd_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::undef, data_type::f32, data_type::f32)
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
+ && !has_zero_dim_memory()
+ && set_default_formats()
+ && mkldnn_thr_syncable();
+ if (!ok) return status::unimplemented;
+
+ status_t status =
+ jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf(
+ jcp_, *desc(), *diff_src_md(), *weights_md(),
+ *diff_dst_md());
+ if (status != status::success) return status;
+ set_default_alg_kind(alg_kind::convolution_winograd);
+
+ auto scratchpad = scratchpad_registry().registrar();
+ winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
+
+ return status;
+ }
+
+ jit_conv_winograd_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o;
+ return set_default_formats_common(nChw16c, wei_tag, nChw16c);
+ }
+ };
+
+ jit_avx512_common_convolution_winograd_bwd_data_t(const pd_t *apd)
+ : _jit_avx512_common_convolution_winograd_t<false>(apd->jcp_, apd->attr())
+ , cpu_primitive_t(apd, true) {}
+
+ ~jit_avx512_common_convolution_winograd_bwd_data_t(){};
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC);
+ this->_execute_data_W_S_G_D((float *)diff_dst, diff_src,
+ (float *)weights, nullptr, this->scratchpad(ctx));
+ return status::success;
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+struct jit_avx512_common_convolution_winograd_bwd_weights_t
+ : public cpu_primitive_t {
+ struct pd_t : public cpu_convolution_bwd_weights_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr,
+ hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_wino:", avx512_common, ""),
+ jit_avx512_common_convolution_winograd_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_weights
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats()
+ && mkldnn_thr_syncable();
+ if (!ok) return status::unimplemented;
+
+ status_t status =
+ jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::
+ init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(),
+ *diff_weights_md());
+ if (status != status::success) return status;
+ set_default_alg_kind(alg_kind::convolution_winograd);
+
+ auto scratchpad = scratchpad_registry().registrar();
+ winograd_avx512_common::init_scratchpad(scratchpad, jcp_);
+
+ return status;
+ }
+
+ jit_conv_winograd_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ auto wei_tag = with_groups() ? gOIhw16i16o : OIhw16i16o;
+ return set_default_formats_common(nChw16c, wei_tag, nChw16c);
+ }
+ };
+
+ jit_avx512_common_convolution_winograd_bwd_weights_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true), kernel_(nullptr)
+ {
+ kernel_ = new jit_avx512_common_conv_winograd_bwd_weights_kernel_f32(
+ pd()->jcp_);
+ }
+
+ ~jit_avx512_common_convolution_winograd_bwd_weights_t()
+ { delete kernel_; }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override
+ {
+ _execute_backward_weights_S_D_G_W(ctx, scratchpad(ctx));
+ return status::success;
+ }
+
+private:
+ void _execute_backward_weights_S_D_G_W(const exec_ctx_t &ctx,
+ const memory_tracking::grantor_t &scratchpad) const;
+ void _maybe_execute_diff_bias_copy(float *diff_bias,
+ const memory_tracking::grantor_t &scratchpad) const;
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ jit_avx512_common_conv_winograd_bwd_weights_kernel_f32 *kernel_;
+};
+
+void trans_W_4x4_3x3(float Fw_[6][6][16][16], float F[3][3][16][16]);
+void trans_O_4x4_3x3(float Mw[6][6][16], float O[4][4][16]);
+void trans_W_3x3_4x4(float Fw[6][6][16], float F[4][6][16]);
+void trans_O_3x3_4x4(float Mw[6][6][16][16], float M[3][3][16][16]);
+void trans_I_4x4_3x3(float Iw[6][6][16], float I[6][6][16]);
+void trans_W_3x3_4x4_wu(float Fw[6][6][16], float F[4][6][16]);
+void trans_O_3x3_4x4_wu(float Mw[6][6][16][16], float M[3][3][16][16]);
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp
new file mode 100644
index 0000000000..d4a451c021
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.cpp
@@ -0,0 +1,853 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_avx512_common_lrn.hpp"
+
+#include "jit_generator.hpp"
+
+#define FWD_RBC 4
+#define BWD_RBC 3
+
+#define XMM_SIZE (4*sizeof(float))
+#define ZMM_SIZE (vlen)
+#define BUFFER_BLOCK (XMM_SIZE + ZMM_SIZE + XMM_SIZE)
+#define BUFFER_NEXT_OFFSET (XMM_SIZE + ZMM_SIZE)
+#define SRC_PREV_OFFSET (vlen - XMM_SIZE)
+
+#define IRB_LOOP(statement) for(int irb = 0; irb < loop_size; irb++) { \
+ statement;\
+}
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+
+using namespace Xbyak;
+
+enum params { vsize = 16, vlen = 64};
+
+typedef struct {
+ const float *src;
+ float *dst, *ws0, *ws1;
+} jit_args_fwd_t;
+
+typedef struct {
+ const float *src, *diff_dst, *ws0, *ws1;
+ float *diff_src;
+} jit_args_bwd_t;
+
+struct nChw16c_across {
+/* version:
+ * -1: channels 0..15,
+ * 1: channels C-16 .. C-1,
+ * 0: other channels
+ * 3: channels only for this kernel (without prev and next)
+ */
+ int H, W, version;
+ nChw16c_across(int h, int w, int v) : H(h), W(w), version(v) {}
+};
+
+struct jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_kernel_f32:
+ public jit_generator {
+ int HW, W;
+ bool is_first;
+ bool is_last;
+ bool is_single;
+
+ Reg64 src = rax;
+ Reg64 dst = r8;
+ Reg64 scratch0 = rdx;
+ Reg64 scratch1 = rsi;
+ Reg64 imm_addr64 = rbx;
+
+ Zmm zalpha = zmm0;
+ Xmm xalpha = xmm0;
+ Zmm zk = zmm1;
+ Xmm xk = xmm1;
+
+ Reg64 param = abi_param1;
+ Reg64 t = rsp;
+ Reg64 hw = r9;
+
+ int xsrc_prev = 2;
+ int zsrc = 7;
+ int xsrc_next = 3;
+ int zc = 7;
+
+ int za = 2;
+ int zb = 3;
+ int zd = 5;
+ int ze = 6;
+ int zsum = 4;
+ int zdst = 2;
+ int zbase = 3;
+ int zsum2 = 5;
+
+ prop_kind_t pk;
+ int use_h_parallelism;
+
+ float alpha, k;
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32)
+
+ void (*ker)(jit_args_fwd_t *);
+ void operator()(jit_args_fwd_t *arg) { ker(arg); }
+
+ enum {
+ prf0_offt = 1*FWD_RBC,
+ prf2_offt = 8*FWD_RBC
+ };
+
+ inline void compute_loop(int loop_size_param)
+ {
+ // loop_size - param for IRB_LOOP macro
+ int loop_size = FWD_RBC;
+
+ auto xreg = [=](int irb, int i) {
+ return Xmm(irb*3 + i);
+ };
+
+ auto zreg = [=](int irb, int i) {
+ return Zmm(irb*7 + i);
+ };
+
+ if (!is_first && !is_single) {
+ IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt - HW)*vlen]));
+ IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt - HW)*vlen]));
+ }
+ IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(src, (irb + prf0_offt)*vlen)));
+ IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(src, (irb + prf2_offt)*vlen)));
+ if (!is_last && !is_single) {
+ IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt + HW)*vlen]));
+ IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt + HW)*vlen]));
+ }
+ if (pk != prop_kind::forward_inference) {
+ IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch0,
+ (irb + prf0_offt)*vlen)));
+ IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch0,
+ (irb + prf2_offt)*vlen)));
+ }
+ IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(dst, (irb + prf0_offt)*vlen)));
+ IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(dst, (irb + prf2_offt)*vlen)));
+ if (pk != prop_kind::forward_inference) {
+ IRB_LOOP(mic_prefetcht0(EVEX_compress_addr(scratch1,
+ (irb + prf0_offt) * vlen)));
+ IRB_LOOP(mic_prefetcht2(EVEX_compress_addr(scratch1,
+ (irb + prf2_offt) * vlen)));
+ }
+
+ loop_size = loop_size_param;
+ if (loop_size == 0)
+ return;
+ if (!is_first && !is_single) {
+ IRB_LOOP(vmovups(xreg(irb, xsrc_prev),
+ ptr[src + (irb - HW) * vlen + SRC_PREV_OFFSET]));
+ }
+ IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src,irb*vlen)));
+ if (!is_last && !is_single) {
+ IRB_LOOP(vmovups(xreg(irb, xsrc_next),
+ ptr[src + (irb + HW) * vlen]));
+ }
+
+ if (!is_first && !is_single) {
+ IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK],
+ xreg(irb, xsrc_prev)));
+ }
+ IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE),
+ zreg(irb, zsrc)));
+ if (!is_last && !is_single) {
+ IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
+ xreg(irb, xsrc_next)));
+ }
+
+ IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK
+ + XMM_SIZE - 2*sizeof(float))));
+ IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK
+ + XMM_SIZE - sizeof(float))));
+ IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK
+ + XMM_SIZE + sizeof(float))));
+ IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK
+ + XMM_SIZE + 2*sizeof(float))));
+
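+ /* Sum of squares over the 5-channel window (c-2 .. c+2), then
+ * base = k + alpha * sum and dst = src / base^0.75, with base^0.75
+ * computed as sqrt(sqrt(base^3)). */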
+ assert(zc == zsrc);
+ IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zc), zreg(irb, zc)));
+
+ IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, za), zreg(irb, za)));
+ IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zb), zreg(irb, zb)));
+ IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, zd), zreg(irb, zd)));
+ IRB_LOOP(vfmadd231ps(zreg(irb, zsum), zreg(irb, ze), zreg(irb, ze)));
+
+ IRB_LOOP(vfmadd132ps(zreg(irb, zsum), zk, zalpha));
+
+ IRB_LOOP(vmovaps(zreg(irb, zbase), zreg(irb, zsum)));
+
+ IRB_LOOP(vmulps(zreg(irb, zsum2), zreg(irb, zsum), zreg(irb, zsum)));
+ IRB_LOOP(vmulps(zreg(irb, zsum), zreg(irb, zsum), zreg(irb, zsum2)));
+
+ IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum)));
+ IRB_LOOP(vsqrtps(zreg(irb, zsum), zreg(irb, zsum)));
+
+ if (pk != prop_kind::forward_inference) {
+ IRB_LOOP(vmovups(EVEX_compress_addr(scratch0, irb*vlen),
+ zreg(irb, zsum)));
+ }
+ IRB_LOOP(vdivps(zreg(irb, zdst), zreg(irb, zsrc), zreg(irb, zsum)));
+ IRB_LOOP(vmovups(EVEX_compress_addr(dst, irb*vlen), zreg(irb, zdst)));
+ if (pk != prop_kind::forward_inference) {
+ /* ws1 = zdst / zbase = zsrc / (zbase^1.75) */
+ IRB_LOOP(vdivps(zreg(irb, zsum), zreg(irb, zdst), zreg(irb, zbase)));
+ IRB_LOOP(vmovups(EVEX_compress_addr(scratch1, irb*vlen),
+ zreg(irb, zsum)));
+ }
+ }
+
+ jit_avx512_common_lrn_kernel_f32(
+ const struct nChw16c_across &J,
+ prop_kind_t prop_kind,
+ int use_h_parallel,
+ float A,
+ float K,
+ void *code_ptr = nullptr,
+ size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE)
+ : jit_generator(code_ptr, code_size)
+ , pk(prop_kind)
+ , use_h_parallelism(use_h_parallel)
+ , alpha(A)
+ , k(K)
+ {
+ this->preamble();
+
+ mov(src, ptr[param + 0]);
+ mov(dst, ptr[param + 8]);
+ if (pk != prop_kind::forward_inference)
+ {
+ mov(scratch0, ptr[param + 16]);
+ mov(scratch1, ptr[param + 24]);
+ }
+ is_first = J.version == -1 || J.version == -2;
+ is_last = J.version == +1 || J.version == -2;
+ is_single = J.version == 3;
+
+ W = J.W;
+ HW = J.W*J.H;
+ int LSB = use_h_parallelism ? W : HW;
+
+ sub(t, FWD_RBC*BUFFER_BLOCK);
+ mov(imm_addr64, float2int(this->alpha));
+ movq(xalpha, imm_addr64);
+ vbroadcastss(zalpha, xalpha);
+
+ mov(imm_addr64, float2int(this->k));
+ movq(xk, imm_addr64);
+ vbroadcastss(zk, xk);
+
+ if (is_first || is_single) {
+ vxorps(xmm2, xmm2, xmm2);
+ for(int irb = 0; irb < FWD_RBC; irb++) {
+ vmovups(ptr[t + irb*BUFFER_BLOCK], xmm2);
+ }
+ }
+ if (is_last || is_single) {
+ vxorps(xmm2, xmm2, xmm2);
+ for(int irb = 0; irb < FWD_RBC; irb++) {
+ vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
+ xmm2);
+ }
+ }
+
+ int LSREST = LSB % FWD_RBC;
+ int LS = LSB - LSREST;
+
+ Label lrn_loop;
+
+ if (LS > 0) {
+ mov(hw, LS);
+
+ L(lrn_loop);
+ {
+ compute_loop(FWD_RBC);
+
+ add(src, FWD_RBC*vlen);
+ add(dst, FWD_RBC*vlen);
+ if (pk != prop_kind::forward_inference)
+ {
+ add(scratch0, FWD_RBC*vlen);
+ add(scratch1, FWD_RBC*vlen);
+ }
+
+ for(int irb = 0; irb < FWD_RBC; irb++)
+ dec(hw);
+ cmp(hw, 0);
+ jne(lrn_loop, T_NEAR);
+ }
+ }
+
+ compute_loop(LSREST);
+
+ add(t, FWD_RBC*BUFFER_BLOCK);
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+ }
+};
+
+status_t jit_avx512_common_lrn_fwd_t::pd_t::init() {
+ using namespace prop_kind;
+ using namespace alg_kind;
+
+ const memory_desc_wrapper data_d(src_md());
+ bool ok = true
+ && mayiuse(avx512_common)
+ && is_fwd()
+ && !has_zero_dim_memory()
+ && everyone_is(data_type::f32, data_d.data_type())
+ && data_d.ndims() == 4
+ && data_d.dims()[1] % vsize == 0
+ && attr()->has_default_values();
+ if (!ok) return unimplemented;
+
+ if (desc()->prop_kind == forward_training) {
+ dims_t ws_dims = { MB(), C(), H(), 2*W() };
+ mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32,
+ format_tag::nChw16c);
+ }
+
+ bool args_ok_across = true
+ && desc()->alg_kind == lrn_across_channels
+ && desc()->local_size == 5
+ && desc()->lrn_beta == 0.75
+ && data_d.matches_tag(format_tag::nChw16c);
+
+ return args_ok_across ? success : unimplemented;
+}
+
+jit_avx512_common_lrn_fwd_t::jit_avx512_common_lrn_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr)
+ , ker_last_(nullptr) {
+ using namespace alg_kind;
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const int ls = pd()->desc()->local_size;
+ const float alpha = pd()->desc()->lrn_alpha / ls;
+ const float k = pd()->desc()->lrn_k;
+
+ auto pk = pd()->desc()->prop_kind;
+
+ use_h_parallelism = H > 28 ? 1 : 0;
+
+ if (C / vsize == 1) {
+ ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3), pk,
+ use_h_parallelism, alpha, k);
+ } else {
+ ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0), pk,
+ use_h_parallelism, alpha, k);
+ ker_first_ = new jit_avx512_common_lrn_kernel_f32(
+ nChw16c_across(H, W, -1), pk, use_h_parallelism, alpha, k);
+ ker_last_ = new jit_avx512_common_lrn_kernel_f32(
+ nChw16c_across(H, W, +1), pk, use_h_parallelism, alpha, k);
+ }
+}
+
+jit_avx512_common_lrn_fwd_t::~jit_avx512_common_lrn_fwd_t()
+{ delete ker_; delete ker_first_; delete ker_last_; }
+
+void jit_avx512_common_lrn_fwd_t::execute_forward(const exec_ctx_t &ctx) const
+{
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE);
+
+ const int N = pd()->MB();
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ size_t start{0}, end{0};
+ const int C16 = C / vsize;
+ const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16;
+
+ balance211(work_amount, nthr, ithr, start, end);
+ if (use_h_parallelism) {
+ int n{0}, c16{0}, h{0};
+ nd_iterator_init(start, n, N, c16, C16, h, H);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ auto offset = n*C*H*W + c16*H*W*vsize
+ + h*W*vsize;
+ auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize
+ + h*2*W*vsize;
+ auto ws_offset1 = ws_offset0 + W*vsize;
+
+ jit_args_fwd_t args;
+ args.src = &src[offset];
+ args.dst = &dst[offset];
+ args.ws0 = &ws[ws_offset0];
+ args.ws1 = &ws[ws_offset1];
+
+ if (C16 == 1)
+ (*ker_)(&args);
+ else if (c16 == 0)
+ (*ker_first_)(&args);
+ else if (c16 == C16 - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ nd_iterator_step(n, N, c16, C16, h, H);
+ }
+ } else {
+ int n{0}, c16{0};
+ nd_iterator_init(start, n, N, c16, C16);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ auto offset = n*C*H*W + c16*H*W*vsize;
+ auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize;
+ auto ws_offset1 = ws_offset0 + H*W*vsize;
+
+ jit_args_fwd_t args;
+ args.src = &src[offset];
+ args.dst = &dst[offset];
+ args.ws0 = &ws[ws_offset0];
+ args.ws1 = &ws[ws_offset1];
+
+ if (C16 == 1)
+ (*ker_)(&args);
+ else if (c16 == 0)
+ (*ker_first_)(&args);
+ else if (c16 == C16 - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+
+ nd_iterator_step(n, N, c16, C16);
+ }
+ }
+ });
+}
+
+struct jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_kernel_f32:
+ public jit_generator {
+ int HW, W;
+ bool is_first;
+ bool is_last;
+ bool is_single;
+
+ Reg64 src = rax;
+ Reg64 diffsrc = r8;
+ Reg64 diffdst = r9;
+ Reg64 workspace0 = rdx;
+ Reg64 workspace1 = rsi;
+ Reg64 imm_addr64 = rbx;
+
+ Zmm znalphabeta = zmm0;
+ Xmm xnalphabeta = xmm0;
+
+ Reg64 param = abi_param1;
+ Reg64 t = rsp;
+ Reg64 hw = r10;
+
+ int xws1_prev = 1;
+ int xdiffdst_prev = 2;
+ int zws1 = 1;
+
+ int zsrc = 1;
+ int zdiffdst = 5;
+ int zdiffsrc = 6;
+
+ int xws1_next = 1;
+ int xdiffdst_next = 3;
+
+ int za = 1;
+ int zb = 2;
+ int zd = 3;
+ int ze = 4;
+ int zws0 = 2;
+
+ float nalphabeta;
+
+ int use_h_parallelism;
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_common_lrn_kernel_f32)
+
+ void (*ker)(jit_args_bwd_t *);
+ void operator()(jit_args_bwd_t *arg) { ker(arg); }
+
+ enum {
+ prf0_offt = 1*BWD_RBC,
+ prf2_offt = 8*BWD_RBC
+ };
+
+ inline void compute_loop(int loop_size_param, int prefetchL1,
+ int prefetchL2)
+ {
+ // loop_size - param for IRB_LOOP macro
+ int loop_size = loop_size_param;
+
+ auto xreg = [=](int irb, int i) {
+ return Xmm(irb*6 + i);
+ };
+
+ auto zreg = [=](int irb, int i) {
+ return Zmm(irb*6 + i);
+ };
+
+// ---- prefetching -------------------------------------------
+ if (!is_first && !is_single) {
+ if (prefetchL1)
+ IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt
+ - 2 * HW) * vlen]));
+ if (prefetchL1)
+ IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt
+ - HW) * vlen]));
+ }
+
+ if (prefetchL1)
+ IRB_LOOP(mic_prefetcht0(ptr[src + (irb + prf0_offt)*vlen]));
+ if (prefetchL2)
+ IRB_LOOP(mic_prefetcht2(ptr[src + (irb + prf2_offt)*vlen]));
+
+ if (prefetchL1)
+ IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt)*vlen]));
+
+ if (prefetchL1)
+ IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt)*vlen]));
+
+ if (!is_last && !is_single) {
+ if (prefetchL1)
+ IRB_LOOP(mic_prefetcht0(ptr[workspace1 + (irb + prf0_offt
+ + 2 * HW) * vlen]));
+ if (prefetchL2)
+ IRB_LOOP(mic_prefetcht2(ptr[workspace1 + (irb + prf2_offt
+ + 2 * HW) * vlen]));
+
+ if (prefetchL1)
+ IRB_LOOP(mic_prefetcht0(ptr[diffdst + (irb + prf0_offt
+ + HW) * vlen]));
+ if (prefetchL2)
+ IRB_LOOP(mic_prefetcht2(ptr[diffdst + (irb + prf2_offt
+ + HW) * vlen]));
+ }
+ if (prefetchL1)
+ IRB_LOOP(mic_prefetcht0(ptr[workspace0 + (irb + prf0_offt)*vlen]));
+ if (prefetchL2)
+ IRB_LOOP(mic_prefetcht2(ptr[workspace0 + (irb + prf2_offt)*vlen]));
+// -----------------------------------------------------------
+
+ if (loop_size_param == 0)
+ return;
+
+ if (!is_first && !is_single) {
+ IRB_LOOP(vmovups(xreg(irb, xws1_prev), ptr[workspace1 + (irb
+ - 2 * HW) * vlen + SRC_PREV_OFFSET]));
+ IRB_LOOP(vmovups(xreg(irb, xdiffdst_prev), ptr[diffdst + (irb
+ - HW) * vlen + SRC_PREV_OFFSET]));
+ IRB_LOOP(vmulps(xreg(irb, xdiffdst_prev), xreg(irb, xdiffdst_prev),
+ xreg(irb, xws1_prev)));
+ }
+
+ IRB_LOOP(vmovups(zreg(irb, zws1),
+ EVEX_compress_addr(workspace1, irb*vlen)));
+ IRB_LOOP(vmovups(zreg(irb, zdiffdst),
+ EVEX_compress_addr(diffdst, irb*vlen)));
+ IRB_LOOP(vmulps(zreg(irb, zdiffsrc), zreg(irb, zdiffdst),
+ zreg(irb, zws1)));
+
+ if (!is_last && !is_single) {
+ IRB_LOOP(vmovups(xreg(irb, xws1_next), ptr[workspace1 + (irb
+ + 2 * HW) * vlen]));
+ IRB_LOOP(vmovups(xreg(irb, xdiffdst_next), ptr[diffdst + (irb
+ + HW) * vlen]));
+ IRB_LOOP(vmulps(xreg(irb, xdiffdst_next), xreg(irb, xdiffdst_next),
+ xreg(irb, xws1_next)));
+ }
+
+ if (!is_first && !is_single) {
+ IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK],
+ xreg(irb, xdiffdst_prev)));
+ }
+ IRB_LOOP(vmovups(EVEX_compress_addr(t, irb*BUFFER_BLOCK + XMM_SIZE),
+ zreg(irb, zdiffsrc)));
+ if (!is_last && !is_single) {
+ IRB_LOOP(vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET],
+ xreg(irb, xdiffdst_next)));
+ }
+
+ IRB_LOOP(vmovups(zreg(irb, za), EVEX_compress_addr(t, irb*BUFFER_BLOCK
+ + XMM_SIZE - 2*sizeof(float))));
+ IRB_LOOP(vmovups(zreg(irb, zb), EVEX_compress_addr(t, irb*BUFFER_BLOCK
+ + XMM_SIZE - 1*sizeof(float))));
+ IRB_LOOP(vmovups(zreg(irb, zd), EVEX_compress_addr(t, irb*BUFFER_BLOCK
+ + XMM_SIZE + 1*sizeof(float))));
+ IRB_LOOP(vmovups(zreg(irb, ze), EVEX_compress_addr(t, irb*BUFFER_BLOCK
+ + XMM_SIZE + 2*sizeof(float))));
+ IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
+ zreg(irb, za)));
+ assert(zsrc == za);
+ IRB_LOOP(vmovups(zreg(irb, zsrc), EVEX_compress_addr(src, irb*vlen)));
+ IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
+ zreg(irb, zb)));
+ IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
+ zreg(irb, zd)));
+ IRB_LOOP(vaddps(zreg(irb, zdiffsrc), zreg(irb, zdiffsrc),
+ zreg(irb, ze)));
+ IRB_LOOP(vmulps(zreg(irb, zsrc), zreg(irb, zsrc), znalphabeta));
+
+ IRB_LOOP(vmovups(zreg(irb, zws0),
+ EVEX_compress_addr(workspace0, irb*vlen)));
+ IRB_LOOP(vdivps(zreg(irb, zdiffdst), zreg(irb, zdiffdst),
+ zreg(irb, zws0)));
+ IRB_LOOP(vfmadd213ps(zreg(irb, zdiffsrc), zreg(irb, zsrc),
+ zreg(irb, zdiffdst)));
+
+ Label unaligned_store, end_store;
+ test(diffsrc, vlen - 1);
+ jnz(unaligned_store, T_NEAR);
+ IRB_LOOP(uni_vmovntps(EVEX_compress_addr(diffsrc, irb*vlen),
+ zreg(irb, zdiffsrc)));
+ jmp(end_store, T_NEAR);
+ L(unaligned_store); {
+ IRB_LOOP(uni_vmovups(EVEX_compress_addr(diffsrc, irb*vlen),
+ zreg(irb, zdiffsrc)));
+ }
+ L(end_store);
+ }
+
+ jit_avx512_common_lrn_kernel_f32(
+ const struct nChw16c_across &J,
+ float A,
+ float B,
+ int use_h_parallel,
+ void *code_ptr = nullptr,
+ size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE)
+ : jit_generator(code_ptr, code_size)
+ , nalphabeta(-2*A*B)
+ , use_h_parallelism(use_h_parallel)
+ {
+ this->preamble();
+
+ mov(src, ptr[param + 0]);
+ mov(diffdst, ptr[param + 8]);
+ mov(workspace0, ptr[param + 16]);
+ mov(workspace1, ptr[param + 24]);
+ mov(diffsrc, ptr[param + 32]);
+
+ W = J.W;
+ HW = J.H*J.W;
+ int LSB = this->use_h_parallelism ? W : HW;
+
+ sub(t, BWD_RBC*BUFFER_BLOCK);
+ mov(imm_addr64, float2int(this->nalphabeta));
+ movq(xnalphabeta, imm_addr64);
+ vbroadcastss(znalphabeta, xnalphabeta);
+
+ is_first = J.version == -1 || J.version == -2;
+ is_last = J.version == +1 || J.version == +2;
+ is_single = J.version == 3;
+
+ if (is_first || is_single) {
+ vxorps(xmm1, xmm1, xmm1);
+ for(int irb = 0; irb < BWD_RBC; irb++) {
+ vmovups(ptr[t + irb*BUFFER_BLOCK], xmm1);
+ }
+ }
+ if (is_last || is_single) {
+ vxorps(xmm1, xmm1, xmm1);
+ for(int irb = 0; irb < BWD_RBC; irb++) {
+ vmovups(ptr[t + irb*BUFFER_BLOCK + BUFFER_NEXT_OFFSET], xmm1);
+ }
+ }
+
+ int LSREST = LSB % BWD_RBC;
+ int LS = LSB - LSREST;
+
+ Label lrn_loop;
+
+ if (LS > 0) {
+ mov(hw, LS);
+
+ L(lrn_loop);
+ {
+ compute_loop(BWD_RBC, 1, 1);
+
+ add(src, BWD_RBC*vlen);
+ add(diffsrc, BWD_RBC*vlen);
+ add(diffdst, BWD_RBC*vlen);
+ add(workspace0, BWD_RBC*vlen);
+ add(workspace1, BWD_RBC*vlen);
+
+ for(int irb = 0; irb < BWD_RBC; irb++)
+ dec(hw);
+ cmp(hw, 0);
+ jne(lrn_loop, T_NEAR);
+ }
+ }
+
+ compute_loop(LSREST, 1, this->use_h_parallelism ? 0 : 1);
+
+ add(t, BWD_RBC*BUFFER_BLOCK);
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+ }
+
+};
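+
+/* Illustrative note (a sketch inferred from the code above, not from the
+ * original sources): for across-channel LRN the forward pass computes
+ * om(c) = k + (alpha/n) * sum_{c' in window} src(c')^2 and
+ * dst(c) = src(c) * om(c)^(-beta).  The backward kernel evaluates the standard
+ * gradient
+ *
+ *   diff_src(c) = diff_dst(c) * om(c)^(-beta)
+ *               - 2 * (alpha/n) * beta * src(c)
+ *                 * sum_{c': c in window(c')} diff_dst(c') * src(c') * om(c')^(-beta-1)
+ *
+ * with nalphabeta = -2*A*B (A already equals lrn_alpha/local_size, see the
+ * jit_avx512_common_lrn_bwd_t constructor below), the 5-channel window sum
+ * staged through the scratch buffer t, and diff_dst divided by workspace0,
+ * which therefore appears to hold om(c)^beta. */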
+
+status_t jit_avx512_common_lrn_bwd_t::pd_t::init() {
+ using namespace alg_kind;
+
+ const memory_desc_wrapper data_d(src_md());
+ bool ok = true
+ && mayiuse(avx512_common)
+ && !is_fwd()
+ && utils::everyone_is(data_type::f32, data_d.data_type())
+ && !has_zero_dim_memory()
+ && data_d.ndims() == 4
+ && data_d.dims()[1] % vsize == 0
+ && attr()->has_default_values();
+ if (!ok) return unimplemented;
+
+ dims_t ws_dims = { MB(), C(), H(), 2*W() };
+ mkldnn_memory_desc_init_by_tag(&ws_md_, 4, ws_dims, data_type::f32,
+ format_tag::nChw16c);
+
+ if (!compare_ws(hint_fwd_pd_)) return unimplemented;
+
+ bool args_ok_across = true
+ && desc()->alg_kind == lrn_across_channels
+ && desc()->local_size == 5
+ && desc()->lrn_beta == 0.75
+ && data_d.matches_tag(format_tag::nChw16c);
+
+ return args_ok_across ? success : unimplemented;
+}
+
+jit_avx512_common_lrn_bwd_t::jit_avx512_common_lrn_bwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , use_h_parallelism(0), ker_(nullptr), ker_first_(nullptr)
+ , ker_last_(nullptr) {
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const int ls = pd()->desc()->local_size;
+ const float alpha = pd()->desc()->lrn_alpha / ls;
+ const float beta = pd()->desc()->lrn_beta;
+
+ use_h_parallelism = H > 28 ? 1 : 0;
+
+ if (C / vsize == 1) {
+ ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 3),
+ alpha, beta, use_h_parallelism);
+ } else {
+ ker_ = new jit_avx512_common_lrn_kernel_f32(nChw16c_across(H, W, 0),
+ alpha, beta, use_h_parallelism);
+ ker_first_ = new jit_avx512_common_lrn_kernel_f32(
+ nChw16c_across(H, W, -1), alpha, beta, use_h_parallelism);
+ ker_last_ = new jit_avx512_common_lrn_kernel_f32(
+ nChw16c_across(H, W, +1), alpha, beta, use_h_parallelism);
+ }
+}
+
+jit_avx512_common_lrn_bwd_t::~jit_avx512_common_lrn_bwd_t()
+{ delete ker_; delete ker_first_; delete ker_last_; }
+
+void jit_avx512_common_lrn_bwd_t::execute_backward(const exec_ctx_t &ctx) const
+{
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const int N = pd()->MB();
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ size_t start{0}, end{0};
+ const int C16 = C / vsize;
+ const size_t work_amount = use_h_parallelism ? N*C16*H : N*C16;
+
+ balance211(work_amount, nthr, ithr, start, end);
+ if (use_h_parallelism) {
+ int n{0}, c16{0}, h{0};
+ nd_iterator_init(start, n, N, h, H, c16, C16);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ auto offset = n*C*H*W + c16*H*W*vsize
+ + h*W*vsize;
+ auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize
+ + h*2*W*vsize;
+ auto ws_offset1 = ws_offset0 + W*vsize;
+
+ jit_args_bwd_t args;
+ args.src = &src[offset];
+ args.diff_dst = &diff_dst[offset];
+ args.ws0 = &ws[ws_offset0];
+ args.ws1 = &ws[ws_offset1];
+ args.diff_src = &diff_src[offset];
+
+ if (C16 == 1)
+ (*ker_)(&args);
+ else if (c16 == 0)
+ (*ker_first_)(&args);
+ else if (c16 == C16 - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ nd_iterator_step(n, N, h, H, c16, C16);
+ }
+ } else {
+ int n{0}, c16{0};
+ nd_iterator_init(start, n, N, c16, C16);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ auto offset = n*C*H*W + c16*H*W*vsize;
+ auto ws_offset0 = n*C*H*2*W + c16*H*2*W*vsize;
+ auto ws_offset1 = ws_offset0 + H*W*vsize;
+
+ jit_args_bwd_t args;
+ args.src = &src[offset];
+ args.diff_dst = &diff_dst[offset];
+ args.ws0 = &ws[ws_offset0];
+ args.ws1 = &ws[ws_offset1];
+ args.diff_src = &diff_src[offset];
+
+ if (C16 == 1)
+ (*ker_)(&args);
+ else if (c16 == 0)
+ (*ker_first_)(&args);
+ else if (c16 == C16 - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+
+ nd_iterator_step(n, N, c16, C16);
+ }
+ }
+ });
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp
new file mode 100644
index 0000000000..37fbb9b3e5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_common_lrn.hpp
@@ -0,0 +1,96 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_COMMON_LRN_HPP
+#define CPU_JIT_AVX512_COMMON_LRN_HPP
+
+#include "c_types_map.hpp"
+
+#include "cpu_isa_traits.hpp"
+#include "cpu_lrn_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_avx512_common_lrn_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_lrn_fwd_pd_t {
+ using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
+ jit_avx512_common_lrn_fwd_t);
+
+ status_t init();
+ };
+
+ jit_avx512_common_lrn_fwd_t(const pd_t *apd);
+ ~jit_avx512_common_lrn_fwd_t();
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ int use_h_parallelism;
+ struct jit_avx512_common_lrn_kernel_f32;
+ jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_;
+};
+
+struct jit_avx512_common_lrn_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_lrn_bwd_pd_t {
+ using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", avx512_common, ""),
+ jit_avx512_common_lrn_bwd_t);
+
+ status_t init();
+ };
+
+ jit_avx512_common_lrn_bwd_t(const pd_t *apd);
+ ~jit_avx512_common_lrn_bwd_t();
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ int use_h_parallelism;
+ struct jit_avx512_common_lrn_kernel_f32;
+ jit_avx512_common_lrn_kernel_f32 *ker_, *ker_first_, *ker_last_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp
new file mode 100644
index 0000000000..c58d3fa0a6
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp
@@ -0,0 +1,1103 @@
+/*******************************************************************************
+ * Copyright 2018 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_avx512_core_fp32_wino_conv_2x3.hpp"
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::format_kind;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+using namespace Xbyak;
+
+/// SRC TRANSFORMS /////////////////////////////////////////////////////////////
+struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(
+ jit_avx512_core_fp32_wino_conv_2x3_src_trans_t)
+
+ jit_conv_conf_2x3_wino_t jcp;
+
+ struct call_params_t {
+ const void *src;
+ const void *wino_src;
+ const void *v_y_masks;
+ const void *v_x_masks;
+ };
+ void (*ker_)(const call_params_t *);
+
+ jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
+ jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
+ : jcp(ajcp) {
+ generate();
+ ker_ =
+ reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
+ }
+
+ void generate();
+
+ Zmm vreg_inp(int i) {
+ assert(i < jcp.alpha * jcp.alpha);
+ return Zmm(31 - i);
+ }
+
+ Zmm vreg_tmp(int i) {
+ assert(i < jcp.alpha * jcp.alpha);
+ return Zmm(15 - i);
+ }
+
+ Zmm vreg_out(int i) {
+ assert(i < jcp.alpha * jcp.alpha);
+ return Zmm(31 - i);
+ }
+
+ Opmask y_mask = Opmask(1);
+ Opmask r_mask = Opmask(2);
+ Opmask x_mask(int id) {
+ assert (id < 4);
+ return Opmask(3 + id);
+ }
+
+ Reg64 reg_ptr_v_y_masks = r12;
+ Reg64 reg_ptr_v_x_masks = r11;
+
+ Reg64 reg_aux_ptr_src = r10;
+ Reg64 reg_aux_ptr_dst = r9;
+
+ Reg64 reg_ic_block = r8;
+
+};
+
+void jit_avx512_core_fp32_wino_conv_2x3_src_trans_t::generate() {
+ Label ic_block_label;
+
+ const int load_block = 16;
+ int out_offset = 0, inp_offset = 0;
+ preamble();
+
+#define READ_PARAM(reg, field) \
+ mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
+ READ_PARAM(reg_aux_ptr_src, src);
+ READ_PARAM(reg_aux_ptr_dst, wino_src);
+ READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
+ READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
+#undef READ_PARAM
+
+ for (int i = 0; i < jcp.alpha; i++) {
+ kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
+ }
+ mov(reg_ic_block, jcp.ic / load_block);
+ L(ic_block_label);
+ {
+ for (int y = 0; y < jcp.alpha; y++) {
+ kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
+ for (int x = 0; x < jcp.alpha; x++) {
+ Zmm zmm = vreg_inp(y * jcp.alpha + x);
+
+ vxorps(zmm, zmm, zmm);
+ kandw(r_mask, y_mask, x_mask(x));
+ inp_offset = sizeof(float)
+ * ((-jcp.t_pad + y) * jcp.iw * load_block
+ + (-jcp.l_pad + x) * load_block);
+ vmovups(zmm | r_mask,
+ EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
+ }
+ }
+ for (int y = 0; y < jcp.alpha; y++) {
+ vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0),
+ vreg_inp(y * jcp.alpha + 2));
+ vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1),
+ vreg_inp(y * jcp.alpha + 2));
+ vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2),
+ vreg_inp(y * jcp.alpha + 1));
+ vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1),
+ vreg_inp(y * jcp.alpha + 3));
+ }
+ for (int x = 0; x < jcp.alpha; x++) {
+ vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0),
+ vreg_tmp(x + jcp.alpha * 2));
+ vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
+ vreg_tmp(x + jcp.alpha * 2));
+ vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2),
+ vreg_tmp(x + jcp.alpha * 1));
+ vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
+ vreg_tmp(x + jcp.alpha * 3));
+ }
+
+ for (int i = 0; i < 16; i++) {
+ out_offset = sizeof(float) * (jcp.inp_stride * i);
+ vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset),
+ vreg_out(i));
+ }
+
+ add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block);
+ add(reg_aux_ptr_dst, sizeof(float) * load_block);
+ }
+ dec(reg_ic_block);
+ cmp(reg_ic_block, 0);
+ jg(ic_block_label, T_NEAR);
+ postamble();
+}
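+
+/* Illustrative note (not from the original sources): the two butterfly passes
+ * above are the standard F(2x2, 3x3) Winograd input transform V = Bt * d * B,
+ * applied first along rows and then along columns, with
+ *
+ *   Bt = [ 1  0 -1  0 ]
+ *        [ 0  1  1  0 ]
+ *        [ 0 -1  1  0 ]
+ *        [ 0  1  0 -1 ]
+ *
+ * The masked loads merely zero out halo pixels that fall into the padding. */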
+
+/// DST TRANSFORMS /////////////////////////////////////////////////////////////
+struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(
+ jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t)
+
+ jit_conv_conf_2x3_wino_t jcp;
+ const primitive_attr_t &attr_;
+
+ struct call_params_t {
+ const void *wino_dst;
+ const void *dst;
+ const void *v_y_masks;
+ const void *v_x_masks;
+
+ const void *bias;
+ const void *scales;
+ };
+ void (*ker_)(const call_params_t *);
+
+ jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
+ jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr) {
+ generate();
+ ker_ = reinterpret_cast<decltype(ker_)>(
+ const_cast<uint8_t *>(getCode()));
+ }
+
+ void generate();
+ bool maybe_relu(int position);
+
+ Zmm vreg_inp(int i) { // 16
+ assert(i < jcp.alpha * jcp.alpha);
+ return Zmm(31 - i);
+ }
+
+ Zmm vreg_stg(int id) { // 8
+ const int id_reg_stg = jcp.alpha * jcp.alpha + id;
+ assert(id_reg_stg < jcp.alpha * jcp.alpha + 8);
+ return Zmm(31 - id_reg_stg);
+ }
+
+ Zmm vreg_out(int id) { // 4
+ const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
+ assert(id_reg_out < jcp.alpha * jcp.alpha + 12);
+ return Zmm(31 - id_reg_out);
+ }
+
+ Zmm vreg_tmp(int id) { // 2
+ const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
+ assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14);
+ return Zmm(31 - id_reg_tmp);
+ }
+
+ Zmm vreg_zero = Zmm(0);
+ Zmm vreg_prev_dst = Zmm(0);
+ Zmm vreg_bias = Zmm(2);
+
+ Opmask y_mask = Opmask(1);
+ Opmask r_mask = Opmask(2);
+ Opmask x_mask(int id) {
+ assert (id < 4);
+ return Opmask(3 + id);
+ }
+
+ Reg64 reg_ptr_v_y_masks = r12;
+ Reg64 reg_ptr_v_x_masks = r11;
+
+ Reg64 reg_aux_ptr_src = r10;
+ Reg64 reg_aux_ptr_dst = r9;
+
+ Reg64 reg_oc_block = r8;
+
+ Reg64 reg_ptr_bias = rbx;
+ Reg64 reg_ptr_scales = abi_not_param1;
+ Reg64 reg_ptr_sum_scale = rdx;
+};
+
+bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) {
+ using namespace primitive_kind;
+ const auto &p = attr_.post_ops_;
+
+ if (position == 0) {
+ /* relu before sum */
+ return false
+ || p.contain(eltwise, 0);
+ } else if (position == 1) {
+ /* relu after sum */
+ const int sum_idx = p.contain(sum, 0)
+ ? 0 : (p.contain(sum, 1) ? 1 : -1);
+ if (sum_idx == -1)
+ return false;
+
+ return false
+ || p.contain(eltwise, sum_idx + 1);
+ }
+
+ return false;
+}
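+
+/* Illustrative note (not from the original sources): given the post-op chains
+ * accepted by post_ops_ok() further below, maybe_relu(0) is true when an
+ * eltwise entry precedes the sum (e.g. "relu" or "relu, sum, relu") and
+ * maybe_relu(1) is true when an eltwise entry directly follows the sum
+ * (e.g. "sum, relu" or "relu, sum, relu"), so the generated code can clamp
+ * both before and after accumulating the previous dst contents. */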
+
+void jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::generate() {
+ Label oc_block_label;
+
+ const int load_block = 16;
+
+ auto loop_body = [=]() {
+ const auto &p = attr_.post_ops_;
+ const int sum_idx = p.find(primitive_kind::sum);
+ const float *p_sum_scale = (sum_idx != -1)
+ ? &p.entry_[sum_idx].sum.scale
+ : nullptr;
+ if (p_sum_scale && *p_sum_scale != 1.f)
+ mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
+
+ for (int i = 0; i < 16; i++) {
+ int internal_offset = sizeof(float) * jcp.out_stride * i;
+ vmovups(vreg_inp(i),
+ EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
+ }
+ for (int y = 0; y < jcp.alpha; y++) {
+ vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1));
+ vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2));
+
+ vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2));
+ vsubps(vreg_stg(y * 2+1), vreg_tmp(1), vreg_inp(y * 4 + 3));
+ }
+ for (int x = 0; x < jcp.m; x++) {
+ vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2 * 1));
+ vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x+2 * 2));
+
+ vsubps(vreg_tmp(1), vreg_stg(x+2 * 1), vreg_stg(x+2 * 2));
+ vsubps(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2 * 3));
+ }
+
+
+ if (jcp.with_bias) {
+ auto bias_addr = ptr [ reg_ptr_bias ];
+ vmovups(vreg_bias, bias_addr);
+ }
+ for (int y = 0; y < jcp.m; y++) {
+ kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]);
+ for (int x = 0; x < jcp.m; x++) {
+ kandw(r_mask, y_mask, x_mask(x));
+
+ int i = y * jcp.m + x;
+ int offset = sizeof(float) *
+ (y * jcp.ow * jcp.oc_block + x * jcp.oc_block);
+ Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);
+
+ Zmm zmm = vreg_out(i);
+ if (jcp.with_bias)
+ vaddps(zmm, zmm, vreg_bias);
+ vmulps(zmm, zmm, ptr [reg_ptr_scales]);
+
+ if (maybe_relu(0)) {
+ vxorps(vreg_zero, vreg_zero, vreg_zero);
+ vmaxps(zmm, vreg_zero, zmm);
+ }
+ if (p_sum_scale) { // post_op: sum
+ vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
+ vmovups(vreg_prev_dst | r_mask, addr);
+ if (*p_sum_scale == 1.f)
+ vaddps(zmm, vreg_prev_dst);
+ else
+ vfmadd231ps(zmm, vreg_prev_dst,
+ zword_b[reg_ptr_sum_scale]);
+ }
+ if (maybe_relu(1)) {
+ vxorps(vreg_zero, vreg_zero, vreg_zero);
+ vmaxps(zmm, vreg_zero, zmm);
+ }
+
+ vmovups(addr, zmm | r_mask);
+ }
+ }
+ };
+
+ preamble();
+
+#define READ_PARAM(reg, field) \
+ mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
+ READ_PARAM(reg_aux_ptr_src, wino_dst);
+ READ_PARAM(reg_aux_ptr_dst, dst);
+ READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
+ READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
+ READ_PARAM(reg_ptr_bias, bias);
+ READ_PARAM(reg_ptr_scales, scales);
+#undef READ_PARAM
+
+ for (int i = 0; i < jcp.alpha * jcp.alpha; i++)
+ vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i));
+
+ for (int i = 0; i < jcp.alpha; i++)
+ kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
+
+ int oc_blocks = 1;
+ oc_blocks = jcp.oc / load_block;
+ mov(reg_oc_block, oc_blocks);
+ L(oc_block_label);
+ {
+ loop_body();
+ add(reg_aux_ptr_src, sizeof(float) * load_block);
+ add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block);
+
+ add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
+ add(reg_ptr_bias, jcp.typesize_bia * load_block);
+ }
+ dec(reg_oc_block);
+ cmp(reg_oc_block, 0);
+ jg(oc_block_label, T_NEAR);
+
+ sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
+ sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block);
+
+ postamble();
+
+}
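+
+/* Illustrative note (not from the original sources): the reduction in loop_body
+ * above is the matching F(2x2, 3x3) Winograd output transform Y = At * M * A,
+ * again applied along rows and then columns, with
+ *
+ *   At = [ 1  1  1  0 ]
+ *        [ 0  1 -1 -1 ]
+ *
+ * Bias, per-oc scales, the optional ReLU/sum post-ops and the boundary masks
+ * are then applied per 2x2 output tile before the masked store. */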
+
+/// GEMM kernel ////////////////////////////////////////////////////////////////
+struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t)
+ jit_conv_conf_2x3_wino_t jcp;
+
+ struct call_params_t {
+ const void *src;
+ const void *dst;
+ const void *wei;
+ const void *dst_b;
+ };
+ void (*ker_)(const call_params_t *);
+
+ void generate();
+ static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp,
+ const primitive_attr_t &attr);
+
+ jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
+ jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
+ : jcp(ajcp) {
+ generate();
+ ker_ = reinterpret_cast<decltype(ker_)>(
+ const_cast<uint8_t *>(getCode()));
+ }
+
+ static status_t init_conf(
+ jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
+ memory_desc_t &src_md, memory_desc_t &weights_md,
+ memory_desc_t &dst_md, memory_desc_t &bias_md,
+ const primitive_attr_t &attr,
+ memory_desc_t& expect_wei_md);
+
+ Zmm vreg_out(int n, int m) {
+ const int id_reg_out = n * jcp.m_block + m;
+ assert(id_reg_out < jcp.n2_block * jcp.m_block);
+ return Zmm(31 - id_reg_out);
+ }
+ Zmm vreg_wei(int i) {
+ assert (31 - jcp.n2_block * jcp.m_block - i > 1);
+ return Zmm(31 - jcp.n2_block * jcp.m_block - i);
+ }
+
+ Zmm vreg_src = Zmm(0);
+ Zmm vreg_one = Zmm(1);
+ Zmm vreg_tmp = Zmm(2);
+
+ Reg64 reg_ptr_src = r15;
+
+ Reg64 reg_aux_dst = r12;
+ Reg64 reg_aux_dst2 = r11;
+ Reg64 reg_aux_wei = r10;
+ Reg64 reg_aux_wei2 = r9;
+ Reg64 reg_aux_src = r8;
+ Reg64 reg_aux_src2 = rax;
+
+ Reg64 reg_mb = rbx;
+ Reg64 reg_nnb = rdx;
+ Reg64 reg_K = rsi;
+
+};
+
+bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok(
+ jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
+ using namespace primitive_kind;
+ const auto &p = attr.post_ops_;
+
+ auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
+
+ switch (p.len_) {
+ case 0: return true;
+ case 1: return is_relu(0) || p.contain(sum, 0);
+ case 2: return (p.contain(sum, 0) && is_relu(1)) ||
+ (p.contain(sum, 1) && is_relu(0));
+ case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
+ default: return false;
+ }
+
+ return false;
+}
+
+void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() {
+ Label nnb_loop_label, K_loop_label, mb_loop_label;
+
+ preamble();
+#define READ_PARAM(reg, field) \
+ mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
+ READ_PARAM(reg_ptr_src, src);
+ READ_PARAM(reg_aux_dst, dst);
+ READ_PARAM(reg_aux_wei, wei);
+#undef READ_PARAM
+
+ if (!jcp.small_mb) {
+ mov(reg_nnb, jcp.n_chunks);
+ L(nnb_loop_label);
+ }
+ mov(reg_aux_dst2, reg_aux_dst);
+ mov(reg_aux_src, reg_ptr_src);
+ mov(reg_mb, jcp.M / jcp.m_block);
+ L(mb_loop_label);
+ {
+ int nb2 = 0;
+ for (nb2 = 0; nb2 < jcp.n2_block; nb2++) {
+ for (int m = 0; m < jcp.m_block; m++) {
+ vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m));
+ }
+ }
+ mov(reg_aux_src2, reg_aux_src);
+ mov(reg_aux_wei2, reg_aux_wei);
+
+ mov(reg_K, jcp.k_chunks);
+ L(K_loop_label); {
+ int wei_offset = 0;
+ for (int _i = 0; _i < jcp.k2_block; _i++) {
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
+ if (jcp.small_mb) {
+ int wei_offset = sizeof(float)
+ * ((nb2 * jcp.nb_ic * jcp.ic_block
+ * jcp.oc_block)
+ + _i * jcp.oc_block);
+ vmovups(vreg_wei(nb2),
+ EVEX_compress_addr(reg_aux_wei2, wei_offset));
+ } else {
+ vmovups(vreg_wei(nb2),
+ EVEX_compress_addr(reg_aux_wei2,
+ sizeof(float) * wei_offset));
+ wei_offset += jcp.oc_block;
+ }
+ }
+ for (int m = 0; m < jcp.m_block; m++) {
+ int inp_offset = sizeof(float) * (m * jcp.K + _i);
+ if (jcp.n2_block > 1) {
+ vbroadcastss(vreg_src,
+ EVEX_compress_addr(reg_aux_src2, inp_offset));
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
+ vfmadd231ps(vreg_out(nb2, m), vreg_wei(nb2),
+ vreg_src);
+ } else {
+ vfmadd231ps(vreg_out(0, m), vreg_wei(0),
+ EVEX_compress_addr(reg_aux_src2, inp_offset, true));
+ }
+ }
+ }
+ add(reg_aux_src2, sizeof(float) * jcp.ic_block);
+ if (jcp.small_mb)
+ add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block);
+ else
+ add(reg_aux_wei2,
+ sizeof(float) * jcp.k2_block * jcp.n2_block
+ * jcp.oc_block);
+ }
+ dec(reg_K);
+ cmp(reg_K, 0);
+ jg(K_loop_label, T_NEAR);
+
+ for (int m = 0; m < jcp.m_block; m++) {
+ int nb2 = 0;
+ for (nb2 = 0; nb2 < jcp.n2_block; nb2++) {
+ int offset = sizeof(float) *
+ (m * jcp.N + nb2 * jcp.oc_block);
+ vmovups(EVEX_compress_addr(reg_aux_dst2,offset),
+ vreg_out(nb2, m));
+ }
+ }
+ add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K);
+ add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N);
+ }
+ dec(reg_mb);
+ cmp(reg_mb, 0);
+ jg(mb_loop_label, T_NEAR);
+
+ if (!jcp.small_mb) {
+ add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block);
+ add(reg_aux_wei,
+ sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block
+ * jcp.oc_block);
+
+ dec(reg_nnb);
+ cmp(reg_nnb, 0);
+ jg(nnb_loop_label, T_NEAR);
+ }
+ postamble();
+}
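+
+/* Illustrative scalar sketch (names here are illustrative, not library API):
+ * for a single Winograd component the generated code amounts to a dense GEMM
+ * over the (M, N, K) block it is handed, with N vectorized in chunks of
+ * oc_block = 16:
+ *
+ *   for (int m = 0; m < M; ++m)
+ *       for (int n = 0; n < N; ++n) {
+ *           float acc = 0.f;
+ *           for (int k = 0; k < K; ++k)
+ *               acc += src[m * K + k] * wei_reordered(k, n);
+ *           dst[m * N + n] = acc;
+ *       }
+ *
+ * where wei_reordered stands for the mkldnn_wino_wei_aaOio / aaOBiOo layouts
+ * chosen in init_conf() depending on jcp.small_mb. */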
+
+namespace {
+bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
+ return jcp.mb >= 4;
+}
+}
+
+status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
+ jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
+ memory_desc_t &src_md, memory_desc_t &wei_md,
+ memory_desc_t &dst_md, memory_desc_t &bias_md,
+ const primitive_attr_t &attr, memory_desc_t &expect_wei_md) {
+ const memory_desc_wrapper src_d(&src_md);
+ const memory_desc_wrapper wei_d(&wei_md);
+ const memory_desc_wrapper dst_d(&dst_md);
+ const memory_desc_wrapper bias_d(&bias_md);
+
+ const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
+
+ jcp.nthr = mkldnn_get_max_threads();
+
+ jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.ih = src_d.dims()[2];
+ jcp.iw = src_d.dims()[3];
+ jcp.oh = dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[3];
+ jcp.kh = wei_d.dims()[with_groups + 2];
+ jcp.kw = wei_d.dims()[with_groups + 3];
+ jcp.t_pad = cd.padding[0][0];
+ jcp.b_pad = cd.padding[1][0];
+ jcp.l_pad = cd.padding[0][1];
+ jcp.r_pad = cd.padding[1][1];
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+ jcp.dilate_h = cd.dilates[0];
+ jcp.dilate_w = cd.dilates[1];
+
+ jcp.m = 2;
+ jcp.r = 3;
+ jcp.alpha = jcp.m + jcp.r - 1;
+ int simdw = 16;
+
+ format_tag_t dat_tag = format_tag::nChw16c;
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+
+ if (jcp.src_tag != dat_tag) return status::unimplemented;
+ if (jcp.dst_tag != dat_tag) return status::unimplemented;
+
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ bool ok_to_pad_channels = jcp.ngroups == 1;
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simdw);
+ jcp.ic = rnd_up(jcp.ic, simdw);
+ }
+
+ jcp.ver = ver_avx512_core;
+ if (!(mayiuse(avx512_core)))
+ return status::unimplemented;
+
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
+ if (src_d.data_type() != data_type::f32)
+ return status::unimplemented;
+ if (wei_d.data_type() != data_type::f32)
+ return status::unimplemented;
+ if (dst_d.data_type() != data_type::f32)
+ return status::unimplemented;
+
+ jcp.ic_block = simdw;
+ jcp.oc_block = simdw;
+
+ bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1
+ && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
+ && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0
+ && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad
+ && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0
+ && jcp.l_pad < 2 && jcp.l_pad >= 0;
+ if (!ok)
+ return status::unimplemented;
+
+ const int L2_cap = get_cache_size(2, true) / sizeof(float);
+ const int L3_capacity = get_cache_size(3, false) / sizeof(float);
+ int a = jcp.alpha;
+ int aa = a * a;
+ int mb = jcp.mb;
+ int ic = jcp.ic;
+ int oc = jcp.oc;
+ int ih = jcp.ih;
+ int iw = jcp.iw;
+ auto wei_sz = (float)aa * ic * oc;
+ auto inp_sz = (float)mb * ih * iw * ic;
+ auto sp_sz = (float)mb * ih * iw;
+
+    /* Heuristics here. The numbers '28' and '196' are observations from the data. */
+ if (wei_sz / inp_sz > 5)
+ jcp.small_mb = true;
+ else
+ jcp.small_mb = false;
+
+ if (mb > nstl::min(jcp.nthr, 28)
+ || (!jcp.small_mb
+ && (wei_sz >= 0.9f * L2_cap
+ || inp_sz > L2_cap * jcp.nthr + L3_capacity))
+ || (jcp.small_mb && sp_sz > 196))
+ return status::unimplemented;
+
+ jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+ jcp.dst_dt = cd.dst_desc.data_type;
+
+ jcp.typesize_bia
+ = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
+
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+
+ const int skx_free_regs = 30;
+
+ auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block,
+ int &n2_block, float &reg_eff) {
+ M = (xb * yb) / jcp.alpha;
+ int max_m_block = m_block = nstl::min(M, skx_free_regs);
+ int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs);
+ reg_eff = 0;
+ for (int im = max_m_block; im > 0; im--) {
+ for (int in2 = max_n2_block; in2 > 0; in2--) {
+ int used_regs = in2 * im + in2;
+ float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f;
+ if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs
+ || cur_reg_eff <= reg_eff)
+ continue;
+ reg_eff = cur_reg_eff;
+ m_block = im;
+ n2_block = in2;
+ }
+ }
+ };
+
+ int oh = jcp.oh;
+ int ow = jcp.ow;
+ int nb_oc = jcp.nb_oc;
+ int Z = ic + oc;
+ int Y = ic * oc;
+ const int L3_cap_per_core = get_cache_size(3, true) / sizeof(float);
+
+ /* Selecting xb and yb blocking */
+ int min_yb = jcp.alpha;
+ int min_xb = jcp.alpha;
+ int max_yb = nstl::max(min_yb, rnd_up(ih, 2));
+ int max_xb = nstl::max(min_xb, rnd_up(iw, 2));
+ float best_eff = 0.f;
+ for (int ix = max_xb; ix >= min_xb; ix -= 2) {
+ if (rnd_up(ow, ix) < iw - 2)
+ continue;
+ for (int iy = max_yb; iy >= min_yb; iy -= 2) {
+ if (rnd_up(oh, iy) < ih - 2)
+ continue;
+ int ex_y = rnd_up(oh, iy);
+ int ex_x = rnd_up(ow, ix);
+ float work_eff = (float)(ih * iw) / (ex_y * ex_x);
+
+ int M, m_block, n2_b;
+ float reg_eff, thr_eff, par_eff, mem_eff, req_mem;
+
+ find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff);
+
+ /* outer parallelization */
+ int nblocks = mb * div_up(ih, iy) * div_up(iw, ix);
+ thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
+
+ mem_eff = 1.f;
+ req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y;
+ if (req_mem > L2_cap / 2) {
+ if (req_mem > ((L2_cap + L3_cap_per_core) * 4) / 7)
+ mem_eff /= (n2_b + 1) / 2.f;
+ else
+ mem_eff /= (n2_b + 1) / 3.f;
+ }
+
+ float outer_eff = thr_eff + work_eff + reg_eff + mem_eff;
+
+ /* inner parallelization */
+ int bsz = iy * ix / a;
+ int gemmw = aa * (nb_oc / n2_b);
+ int bsz_r = rnd_up(bsz, jcp.nthr);
+ int gemmw_r = rnd_up(gemmw, jcp.nthr);
+ thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y);
+
+ req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic;
+ mem_eff = nstl::min(1.f, L2_cap / req_mem);
+ int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr));
+ int oc_per_thr =
+ nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr));
+ req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z;
+ if (req_mem > L2_cap)
+ mem_eff = 0.1f;
+ par_eff = 1 / (2.f * nblocks);
+
+ float inner_eff = thr_eff + work_eff + mem_eff + par_eff;
+
+ float eff = jcp.small_mb ? inner_eff : outer_eff;
+ if (eff > best_eff) {
+ best_eff = eff;
+ jcp.yb = iy;
+ jcp.xb = ix;
+ jcp.M = M;
+ jcp.m_block = m_block;
+ jcp.n2_block = n2_b;
+ }
+ }
+ }
+
+ assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
+
+ jcp.inp_stride = jcp.M * jcp.ic;
+ jcp.out_stride = jcp.M * jcp.oc;
+ jcp.wei_stride = jcp.ic * jcp.oc;
+ jcp.bia_stride = jcp.oc;
+
+ jcp.N = jcp.oc;
+ jcp.K = jcp.ic;
+
+ jcp.n_block = jcp.oc_block;
+ jcp.k_block = jcp.ic_block;
+
+ assert(jcp.M % jcp.m_block == 0);
+ assert(jcp.nb_oc % jcp.n2_block == 0);
+
+ jcp.n_chunks = jcp.nb_oc / jcp.n2_block;
+ jcp.k2_block = jcp.ic_block;
+ jcp.k_chunks = jcp.K / jcp.k2_block;
+
+ const auto &oscales = attr.output_scales_;
+ jcp.is_oc_scale = oscales.mask_ == 1 << 1;
+ assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
+
+ /* re-create weights primitive descriptor
+ and set weights wino_blocking */
+ expect_wei_md.format_kind = format_kind::wino;
+ expect_wei_md.data_type = data_type::f32;
+ mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
+ wd.wino_format
+ = jcp.small_mb ? mkldnn_wino_wei_aaOio : mkldnn_wino_wei_aaOBiOo;
+ wd.r = jcp.r;
+ wd.alpha = jcp.alpha;
+ wd.ic = jcp.ic;
+ wd.oc = jcp.oc;
+ wd.ic_block = jcp.ic_block;
+ wd.oc_block = jcp.oc_block;
+ wd.oc2_block = jcp.n2_block;
+ wd.ic2_block = 1;
+ wd.adj_scale = 1.f;
+ size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
+ wd.size = max_size;
+
+ return status::success;
+}
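+
+/* Illustrative worked example (the sizes are hypothetical): with M = 16 output
+ * tiles per block and nb_oc = 4, find_m_n2_blocks() keeps candidates that
+ * satisfy M % m_block == 0, nb_oc % n2_block == 0 and
+ * n2_block * m_block + n2_block <= skx_free_regs (30), scoring each by
+ * (n2_block * m_block) / (m_block + n2_block) / 2.5.  Here (m_block, n2_block)
+ * = (4, 4) wins with 20 registers (16 accumulators plus 4 weight vectors) over
+ * (8, 2), which uses 18 registers but yields fewer FMAs per loaded weight. */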
+////////////////////////////////////////////////////////////////////////////////
+
+status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t
+ ::pd_t::jit_conf(memory_desc_t& expect_wei_md) {
+ return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf(
+ jcp_, *this->desc(), this->src_md_, this->weights_md_,
+ this->dst_md_,this->bias_md_, *this->attr(), expect_wei_md);
+}
+
+jit_avx512_core_fp32_wino_conv_2x3_fwd_t::
+ jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+{
+ kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
+ pd()->jcp_, *pd()->attr());
+ src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
+ pd()->jcp_, *pd()->attr());
+ dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
+ pd()->jcp_, *pd()->attr());
+}
+
+jit_avx512_core_fp32_wino_conv_2x3_fwd_t
+ ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() {
+ delete kernel_;
+ delete src_trans_;
+ delete dst_trans_;
+}
+
+void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN(
+ const float *src, const float *wei, const float *bia, float *dst,
+ const memory_tracking::grantor_t &scratchpad) const
+{
+ const auto &jcp = kernel_->jcp;
+ const auto &oscales = pd()->attr()->output_scales_;
+
+ const size_t wino_size_offset =
+ (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb);
+ const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16;
+ const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16;
+
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+ utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bia = padded_bias;
+ }
+
+ auto ptr_V = scratchpad.get<float>(key_wino_V);
+ auto ptr_M = scratchpad.get<float>(key_wino_M);
+
+ parallel_nd(jcp.mb, div_up(jcp.oh,jcp.yb), div_up(jcp.ow, jcp.xb),
+ [&](int mb, int tile_y_b, int tile_x_b) {
+ int tile_y = tile_y_b * jcp.yb;
+ int tile_x = tile_x_b * jcp.xb;
+
+ int ithr = mkldnn_get_thread_num();
+ auto wino_src = ptr_V + size_wino_src * ithr;
+ auto wino_dst = ptr_M + size_wino_dst * ithr;
+
+ auto src_trans_p =
+ jit_avx512_core_fp32_wino_conv_2x3_src_trans_t
+ ::call_params_t();
+ auto dst_trans_p =
+ jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t
+ ::call_params_t();
+ auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
+ call_params_t();
+
+ /* transformation of input tensor to winograd domain */
+ for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
+ for (int x_in_block = 0; x_in_block < jcp.xb;
+ x_in_block += 2) {
+
+ unsigned short v_y_masks[4], v_x_masks[4];
+
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (y_in_block / 2) * (jcp.xb / 2)
+ + (x_in_block / 2);
+
+ int v_ys = nstl::max(0, jcp.t_pad - y);
+ int v_ye = nstl::min(jcp.alpha,
+ nstl::max(0, jcp.ih + jcp.t_pad - y));
+
+ int v_xs = nstl::max(0, jcp.l_pad - x);
+ int v_xe = nstl::min(jcp.alpha,
+ nstl::max(0, jcp.iw + jcp.l_pad - x));
+
+#pragma unroll(4)
+ for (int i = 0; i < jcp.alpha; i++) {
+ v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
+ v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
+ }
+ auto local_s = src
+ + mb * jcp.nb_ic * jcp.ih * jcp.iw
+ * jcp.ic_block
+ + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
+ auto local_w = wino_src + m * jcp.ic;
+
+ src_trans_p.src = local_s;
+ src_trans_p.wino_src = local_w;
+ src_trans_p.v_y_masks = v_y_masks;
+ src_trans_p.v_x_masks = v_x_masks;
+
+ src_trans_->ker_(&src_trans_p);
+ }
+ }
+ /* gemms */
+ for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
+ int offset = (tile_ij + ithr) % 16;
+ gemm_p.src = wino_src + jcp.inp_stride * offset;
+ gemm_p.dst = wino_dst + jcp.out_stride * offset;
+ gemm_p.wei = wei + jcp.wei_stride * offset;
+
+ kernel_->ker_(&gemm_p);
+ }
+
+ /* transformation from winograd domain to output tensor */
+ for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
+ for (int x_in_block = 0; x_in_block < jcp.xb;
+ x_in_block += 2) {
+ unsigned short v_y_masks[2], v_x_masks[2];
+
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (y_in_block / 2) * (jcp.xb / 2)
+ + (x_in_block / 2);
+
+#pragma unroll(2)
+ for (int i = 0; i < jcp.m; i++) {
+ v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
+ v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
+ }
+ auto local_d = dst
+ + mb * jcp.nb_oc * jcp.oh * jcp.ow
+ * jcp.oc_block
+ + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
+ auto local_w = wino_dst + m * jcp.oc;
+
+ auto scales = oscales.scales_;
+ dst_trans_p.dst = local_d;
+ dst_trans_p.wino_dst = local_w;
+ dst_trans_p.v_y_masks = v_y_masks;
+ dst_trans_p.v_x_masks = v_x_masks;
+
+ dst_trans_p.scales = scales;
+ dst_trans_p.bias = bia;
+
+ dst_trans_->ker_(&dst_trans_p);
+ }
+ }
+ });
+}
+
+void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb(
+ const float *src, const float *wei, const float *bia, float *dst,
+ const memory_tracking::grantor_t &scratchpad) const
+{
+ const auto &jcp = kernel_->jcp;
+ const auto &oscales = pd()->attr()->output_scales_;
+
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+ utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bia = padded_bias;
+ }
+
+ auto ptr_V = scratchpad.get<float>(key_wino_V);
+ auto ptr_M = scratchpad.get<float>(key_wino_M);
+
+ for (int mb = 0; mb < jcp.mb; mb++) {
+ for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
+ for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
+ /* transformation of input tensor to winograd domain */
+ parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
+ [&](int y_in_block_b, int x_in_block_b) {
+ int y_in_block = y_in_block_b * 2;
+ int x_in_block = x_in_block_b * 2;
+
+ auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t ::
+ call_params_t();
+
+ unsigned short v_y_masks[4], v_x_masks[4];
+
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+
+ int v_ys = nstl::max(0, jcp.t_pad - y);
+ int v_ye = nstl::min(
+ jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));
+
+ int v_xs = nstl::max(0, jcp.l_pad - x);
+ int v_xe = nstl::min(
+ jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));
+
+#pragma unroll(4)
+ for (int i = 0; i < jcp.alpha; i++) {
+ v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
+ v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
+ }
+ auto local_s = src
+ + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block
+ + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
+ auto local_w = ptr_V + m * jcp.ic;
+
+ src_trans_p.src = local_s;
+ src_trans_p.wino_src = local_w;
+ src_trans_p.v_y_masks = v_y_masks;
+ src_trans_p.v_x_masks = v_x_masks;
+
+ src_trans_->ker_(&src_trans_p);
+ });
+
+ /* gemms */
+ parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
+ auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
+ call_params_t();
+
+ gemm_p.src = ptr_V + jcp.inp_stride * tile_ij;
+ gemm_p.dst = ptr_M + jcp.out_stride * tile_ij
+ + nnb * jcp.n2_block * jcp.n_block;
+ gemm_p.wei = wei + jcp.wei_stride * tile_ij
+ + nnb * jcp.n2_block * jcp.n_block * jcp.K;
+
+ kernel_->ker_(&gemm_p);
+ });
+
+ /* transformation from winograd domain to output tensor */
+
+ parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
+ [&](int y_in_block_b, int x_in_block_b) {
+ int y_in_block = y_in_block_b * 2;
+ int x_in_block = x_in_block_b * 2;
+
+ auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t ::
+ call_params_t();
+
+ unsigned short v_y_masks[2], v_x_masks[2];
+
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+
+#pragma unroll(2)
+ for (int i = 0; i < jcp.m; i++) {
+ v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
+ v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
+ }
+ auto local_d = dst
+ + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block
+ + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
+ auto local_w = ptr_M + m * jcp.oc;
+
+ auto scales = oscales.scales_;
+ dst_trans_p.dst = local_d;
+ dst_trans_p.wino_dst = local_w;
+ dst_trans_p.v_y_masks = v_y_masks;
+ dst_trans_p.v_x_masks = v_x_masks;
+
+ dst_trans_p.scales = scales;
+ dst_trans_p.bias = bia;
+
+ dst_trans_->ker_(&dst_trans_p);
+ });
+ }}}
+}
+
+} // namespace cpu
+} // namespace impl
+} // namespace mkldnn
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp
new file mode 100644
index 0000000000..7e38b07f5a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.hpp
@@ -0,0 +1,144 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP
+#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_2x3_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_primitive_conf.hpp"
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t;
+struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t;
+struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t;
+
+struct jit_avx512_core_fp32_wino_conv_2x3_fwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_fp32_wino_2x3:", avx512_core, ""),
+ jit_avx512_core_fp32_wino_conv_2x3_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::forward_inference
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ memory_desc_t expect_wei_md = *weights_md();
+ status_t jit_conf_result = jit_conf(expect_wei_md);
+ if (jit_conf_result != status::success) return jit_conf_result;
+ set_default_alg_kind(alg_kind::convolution_winograd);
+
+ if (weights_md_.format_kind == format_kind::any)
+ weights_md_ = expect_wei_md;
+ if (weights_md_ != expect_wei_md)
+ return status::unimplemented;
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ jit_conv_conf_2x3_wino_t jcp_;
+
+ protected:
+ status_t jit_conf(memory_desc_t& expect_wei_md);
+
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+
+ auto scratchpad = scratchpad_registry().registrar();
+
+ int wino_size_offset = (jcp_.yb / 2) * (jcp_.xb / 2) + jcp_.xb;
+
+ size_t V_sz = (size_t)jcp_.ic * 16 * wino_size_offset * jcp_.nthr;
+ scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_4K);
+
+ size_t M_sz = (size_t)jcp_.oc * 16 * wino_size_offset * jcp_.nthr;
+ scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_4K);
+
+ if (wants_padded_bias()) {
+ assert(jcp_.ngroups == 1);
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp_.oc);
+ }
+ }
+
+ bool set_default_formats() {
+ using namespace format_tag;
+ return set_default_formats_common(nChw16c, any, nChw16c);
+ }
+ };
+
+ jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd);
+ ~jit_avx512_core_fp32_wino_conv_2x3_fwd_t();
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
+ auto wei = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
+ auto bia = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST);
+
+ if (pd()->jcp_.small_mb)
+ execute_forward_small_mb(src, wei, bia, dst, this->scratchpad(ctx));
+ else
+ execute_forward_mbN(src, wei, bia, dst, this->scratchpad(ctx));
+
+ return status::success;
+ }
+
+private:
+ void execute_forward_small_mb(const float *src, const float *wei,
+ const float *bia, float *dst,
+ const memory_tracking::grantor_t &scratchpad) const;
+ void execute_forward_mbN(const float *src, const float *wei,
+ const float *bia, float *dst,
+ const memory_tracking::grantor_t &scratchpad) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t *kernel_;
+ jit_avx512_core_fp32_wino_conv_2x3_src_trans_t *src_trans_;
+ jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t *dst_trans_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp
new file mode 100644
index 0000000000..96325e3ade
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp
@@ -0,0 +1,1020 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifdef __INTEL_COMPILER
+#include <immintrin.h>
+#endif
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_avx512_core_fp32_wino_conv_4x3.hpp"
+
+#ifndef _MSC_VER
+#define pragma_unroll _Pragma("unroll")
+#else
+#define pragma_unroll
+#endif
+
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+template <bool is_fwd>
+void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
+::weight_transform_data(const jit_conv_winograd_conf_t &jcp,
+ float *wp, float *twp) const
+{
+ float G[] = {0.26890756302521f, 0.688403361344538f, 0.119514472455649f,
+ 1.13777777777778f, 0.430252100840336f, 0.179271708683473f};
+ const int kh = 3;
+ const int kw = 3;
+ float Fw[alpha][alpha][simd_w][simd_w];
+ float F[kh][kw][simd_w][simd_w];
+ float T[alpha][3][simd_w];
+ auto p = jit_wino_transform_call_s();
+
+ p.src = wp;
+ p.dst = twp;
+ p.G = G;
+ p.M = F;
+ p.Mw = Fw;
+ p.T = T;
+
+ kernel_->weights_transform_data_ker(&p);
+}
+
+template<bool is_fwd>
+void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::output_transform_data
+(int image, const jit_conv_winograd_conf_t &jcp,
+ const post_ops_t &p_ops, float *toutp, float *pout_b, float *bias) const {
+
+ float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
+ float Ow[alpha][alpha][simd_w];
+ float O[tile_size][tile_size][simd_w];
+ float T[tile_size][alpha][simd_w];
+
+ auto p = jit_wino_transform_call_s();
+ p.src = toutp;
+ p.dst = pout_b;
+ p.G = G;
+ p.M = O;
+ p.Mw = Ow;
+ p.T = T;
+ p.bias = bias;
+
+ int tile_base_index = image * jcp.itiles * jcp.jtiles;
+ int tile_block_ur = tile_base_index % jcp.tile_block_ur;
+ int nb_tile_block_ur =
+ (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
+ int tile_block =
+ (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
+
+ for (int tj = 0; tj < jcp.jtiles; tj++) {
+ for (int ti = 0; ti < jcp.itiles; ti++) {
+
+ p.tile_block_ur = tile_block_ur;
+ p.nb_tile_block_ur = nb_tile_block_ur;
+ p.tile_block = tile_block;
+ p.tj = tj;
+ p.ti = ti;
+
+ kernel_->output_transform_data_ker(&p);
+
+ tile_block_ur++;
+ if (tile_block_ur >= jcp.tile_block_ur) {
+ tile_block_ur = 0;
+ nb_tile_block_ur++;
+ }
+ if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
+ nb_tile_block_ur = 0;
+ tile_block++;
+ }
+ }
+ }
+}
+
+template<bool is_fwd>
+void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
+::output_transform_tileblock_data(int tile_block,
+ const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
+ float *toutp, float *outp, float *bias) const {
+
+ float G[] = {0.625f, 1.5f, 0.390625f, 2.25f, 0.244140625f, 3.375f};
+ float Ow[alpha][alpha][simd_w];
+ float O[tile_size][tile_size][simd_w];
+ float T[tile_size][alpha][simd_w];
+
+ auto p = jit_wino_transform_call_s();
+ p.src = toutp;
+ p.dst = outp;
+ p.G = G;
+ p.M = O;
+ p.Mw = Ow;
+ p.T = T;
+ p.bias = bias;
+
+ int outw = is_fwd ? jcp.ow : jcp.iw;
+ int outh = is_fwd ? jcp.oh : jcp.ih;
+
+ int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
+
+ for (int nb_tile_block_ur = 0;
+ nb_tile_block_ur < jcp.nb_tile_block_ur;
+ nb_tile_block_ur++) {
+
+ for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
+ tile_block_ur++) {
+ int img = tile_index / (jcp.jtiles * jcp.itiles);
+ int ti = tile_index % jcp.itiles;
+ int tj = (tile_index / jcp.itiles) % jcp.jtiles;
+
+ p.tile_block_ur = tile_block_ur;
+ p.nb_tile_block_ur = nb_tile_block_ur;
+ p.tile_block = tile_block;
+ p.tj = tj;
+ p.ti = ti;
+ p.dst = outp + img * (jcp.dimM / jcp.dimM_simd_block)
+ * outh * outw * jcp.dimM_simd_block;
+
+ kernel_->output_transform_data_ker(&p);
+
+ tile_index++;
+ }
+ }
+}
+
+
+template<bool is_fwd>
+void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
+ ::input_transform_data(int image, const jit_conv_winograd_conf_t &jcp,
+ float *inp, float *tinp) const
+{
+ float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
+ 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
+
+ float Iw[alpha][alpha][simd_w];
+ float I[alpha][alpha][simd_w];
+ float T[alpha][alpha][simd_w];
+
+ auto p = jit_wino_transform_call_s();
+
+ p.src = inp;
+ p.dst = tinp;
+ p.G = G;
+ p.M = I;
+ p.Mw = Iw;
+ p.T = T;
+
+ int tile_base_index = image * jcp.itiles * jcp.jtiles;
+ int tile_block_ur = tile_base_index % jcp.tile_block_ur;
+ int nb_tile_block_ur =
+ (tile_base_index / jcp.tile_block_ur) % jcp.nb_tile_block_ur;
+ int tile_block =
+ (tile_base_index / jcp.tile_block_ur) / jcp.nb_tile_block_ur;
+
+ for (int tj = 0; tj < jcp.jtiles; tj++) {
+ for (int ti = 0; ti < jcp.itiles; ti++) {
+
+ p.tile_block_ur = tile_block_ur;
+ p.nb_tile_block_ur = nb_tile_block_ur;
+ p.tile_block = tile_block;
+ p.tj = tj;
+ p.ti = ti;
+
+ kernel_->input_transform_data_ker(&p);
+
+ tile_block_ur++;
+ if (tile_block_ur >= jcp.tile_block_ur) {
+ tile_block_ur = 0;
+ nb_tile_block_ur++;
+ }
+ if (nb_tile_block_ur >= jcp.nb_tile_block_ur) {
+ nb_tile_block_ur = 0;
+ tile_block++;
+ }
+ }
+ }
+}
+
+template <bool is_fwd>
+void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>
+ ::input_transform_tileblock_data(int tile_block,
+ const jit_conv_winograd_conf_t &jcp,
+ float *inp, float *tinp) const
+{
+ float G[] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
+ 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
+ float Iw[alpha][alpha][simd_w];
+ float I[alpha][alpha][simd_w];
+ float T[alpha][alpha][simd_w];
+
+ const int inph = is_fwd ? jcp.ih : jcp.oh;
+ const int inpw = is_fwd ? jcp.iw : jcp.ow;
+
+ array_offset_calculator<float, 5> input(inp,
+ jcp.mb, jcp.dimK / simd_w, inph, inpw, simd_w);
+ array_offset_calculator<float, 7> output(tinp,
+ alpha, alpha,
+ jcp.dimN_block, jcp.dimK_nb_block, jcp.dimK_block,
+ jcp.dimN_reg_block, jcp.dimK_reg_block);
+
+ auto p = jit_wino_transform_call_s();
+
+ p.dst = tinp;
+ p.G = G;
+ p.M = I;
+ p.Mw = Iw;
+ p.T = T;
+
+
+ int tile_index = tile_block * jcp.nb_tile_block_ur * jcp.tile_block_ur;
+
+ for (int nb_tile_block_ur = 0;
+ nb_tile_block_ur < jcp.nb_tile_block_ur;
+ nb_tile_block_ur++) {
+
+ for (int tile_block_ur = 0; tile_block_ur < jcp.tile_block_ur;
+ tile_block_ur++) {
+
+ int img = tile_index / (jcp.jtiles * jcp.itiles);
+ int ti = tile_index % jcp.itiles;
+ int tj = (tile_index / jcp.itiles) % jcp.jtiles;
+ float *pinp_b = &(input(img, 0, 0, 0, 0));
+
+ p.src = pinp_b;
+ p.tile_block_ur = tile_block_ur;
+ p.nb_tile_block_ur = nb_tile_block_ur;
+ p.tj = tj;
+ p.ti = ti;
+
+ kernel_->input_transform_data_ker(&p);
+
+ tile_index++;
+ }
+ }
+}
+
+template <bool is_fwd>
+void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_S_G_D(
+ float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const auto &jcp = kernel_->jcp;
+ const auto &p_ops = attr_->post_ops_;
+
+ const int inph = is_fwd ? jcp.ih : jcp.oh;
+ const int inpw = is_fwd ? jcp.iw : jcp.ow;
+ const int outh = is_fwd ? jcp.oh : jcp.ih;
+ const int outw = is_fwd ? jcp.ow : jcp.iw;
+
+ /* Notation:
+ FWD: dimM:oc, dimN:ntiles, dimK:ic,
+ BWD: dimM:ic, dimN:ntiles, dimK:oc,
+ FWD/BWD: V: src/diff_dst transform, U:weight transform,
+ M:dst/diff_src transform */
+ array_offset_calculator<float, 5> input(inp_ptr,
+ jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw,
+ jcp.dimK_reg_block);
+ array_offset_calculator<float, 5> output(out_ptr,
+ jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw,
+ jcp.dimM_simd_block);
+ array_offset_calculator<float, 6> weights(wei_ptr,
+ jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
+ jcp.ic_simd_block, jcp.oc_simd_block);
+ array_offset_calculator<float, 2> bias(bias_ptr,
+ jcp.dimM/jcp.dimM_simd_block, jcp.dimM_simd_block);
+
+ array_offset_calculator<float, 8> M(is_fwd
+ ? scratchpad.template get<float>(key_wino_M)
+ : scratchpad.template get<float>(key_wino_V),
+ jcp.dimN_nb_block, jcp.dimM_nb_block,
+ alpha, alpha,
+ jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
+ jcp.dimN_reg_block, jcp.dimM_simd_block);
+
+ auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
+ ? wei_ptr
+ : scratchpad.template get<float>(key_wino_U);
+
+ array_offset_calculator<float, 8> U(wino_wei,
+ jcp.dimM_nb_block,
+ alpha, alpha,
+ jcp.dimK_nb_block,
+ jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
+ jcp.dimK_reg_block, jcp.dimM_simd_block);
+ array_offset_calculator<float, 8> V(is_fwd
+ ? scratchpad.template get<float>(key_wino_V)
+ : scratchpad.template get<float>(key_wino_M),
+ jcp.dimN_nb_block, alpha, alpha,
+ jcp.dimN_block, jcp.dimK_nb_block,
+ jcp.dimK_block, jcp.dimN_reg_block, jcp.dimK_reg_block);
+
+ const bool wants_padded_bias = jcp.with_bias
+ && jcp.oc_without_padding != jcp.oc;
+ float last_slice_bias[simd_w] = {0};
+ if (wants_padded_bias) {
+ for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
+ last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
+ }
+
+ {
+
+ parallel_nd(jcp.mb, jcp.dimK_nb_block, jcp.dimK_block,
+ [&](int img, int K_blk1, int K_blk2) {
+ input_transform_data(img, jcp,
+ &(input(img, K_blk1 * jcp.dimK_block + K_blk2,
+ 0, 0, 0)),
+ &(V(0, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
+ });
+
+ if (jcp.prop_kind != prop_kind::forward_inference) {
+ parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block),
+ (jcp.ic_block * jcp.ic_reg_block),
+ [&](int ofm1, int ifm1, int ofm2, int ifm2) {
+ float *U_base_ptr = is_fwd
+ ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
+ : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
+ weight_transform_data(jcp,
+ &(weights(
+ ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
+ ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
+ 0, 0, 0, 0)),
+ U_base_ptr);
+ });
+ }
+
+ parallel_nd(jcp.dimN_nb_block, alpha, alpha, jcp.dimM_nb_block,
+ [&](int N_blk1, int oj, int oi, int M_blk1) {
+ for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block;
+ K_blk1++)
+ for (int N_blk2 = 0; N_blk2 < jcp.dimN_block; N_blk2++)
+ kernel_->gemm_loop_ker(
+ (float *)&(M(N_blk1, M_blk1, oj, oi,
+ N_blk2, 0, 0, 0)),
+ (const float *)&(U(M_blk1, oj, oi,
+ K_blk1, 0, 0, 0, 0)),
+ (const float *)&(V(N_blk1, oj, oi,
+ N_blk2, K_blk1, 0, 0, 0)), K_blk1);
+ });
+
+ parallel_nd(jcp.mb, jcp.dimM_nb_block, (jcp.dimM_block * jcp.dimM_reg_block),
+ [&](int img, int M_blk1, int M_blk2) {
+ const int M_blk =
+ M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
+
+ float *bias_ptr = wants_padded_bias
+ && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
+ ? last_slice_bias : &bias(M_blk, 0);
+ output_transform_data(img, jcp, p_ops,
+ &(M(0, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
+ &(output(img, M_blk, 0, 0, 0)), bias_ptr);
+ });
+
+ }
+}
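+
+/* Illustrative note (not from the original sources): _execute_data_W_S_G_D runs
+ * each stage over the whole batch before starting the next: transform every
+ * input tile into V, transform the weights into U (skipped for
+ * forward_inference, where wei_ptr is used directly), run the per-component
+ * GEMMs M = U x V for all (oj, oi), and finally inverse-transform M into the
+ * output.  _execute_data_W_SGD below instead fuses the S, G and D stages per
+ * tile block, so each thread keeps its V and M slices local and
+ * cache-resident. */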
+
+template <bool is_fwd>
+void _jit_avx512_core_fp32_wino_conv_4x3_t<is_fwd>::_execute_data_W_SGD(
+ float *inp_ptr, float *out_ptr, float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const auto &jcp = kernel_->jcp;
+ const auto &p_ops = attr_->post_ops_;
+
+ const int inph = is_fwd ? jcp.ih : jcp.oh;
+ const int inpw = is_fwd ? jcp.iw : jcp.ow;
+ const int outh = is_fwd ? jcp.oh : jcp.ih;
+ const int outw = is_fwd ? jcp.ow : jcp.iw;
+
+ array_offset_calculator<float, 5> input(inp_ptr,
+ jcp.mb, jcp.dimK/jcp.dimK_reg_block, inph, inpw, jcp.dimK_reg_block);
+ array_offset_calculator<float, 5> output(out_ptr,
+ jcp.mb, jcp.dimM/jcp.dimM_simd_block, outh, outw, jcp.dimM_simd_block);
+ array_offset_calculator<float, 6> weights(wei_ptr,
+ jcp.oc/jcp.oc_simd_block, jcp.ic/jcp.ic_simd_block, jcp.kh, jcp.kw,
+ jcp.ic_simd_block, jcp.oc_simd_block);
+ array_offset_calculator<float, 2> bias(bias_ptr,
+ jcp.oc/jcp.oc_simd_block, jcp.oc_simd_block);
+
+ auto wino_wei = (jcp.prop_kind == prop_kind::forward_inference)
+ ? wei_ptr
+ : scratchpad.template get<float>(key_wino_U);
+
+ array_offset_calculator<float, 8> U(wino_wei,
+ jcp.dimM_nb_block,
+ alpha, alpha,
+ jcp.dimK_nb_block,
+ jcp.dimM_block * jcp.dimM_reg_block, jcp.dimK_block,
+ jcp.dimK_reg_block, jcp.dimM_simd_block);
+
+ array_offset_calculator<float, 8> M(is_fwd
+ ? scratchpad.template get<float>(key_wino_M)
+ : scratchpad.template get<float>(key_wino_V),
+ 0, jcp.dimM_nb_block, alpha, alpha,
+ jcp.dimN_block, jcp.dimM_block * jcp.dimM_reg_block,
+ jcp.dimN_reg_block, jcp.dimM_simd_block);
+ array_offset_calculator<float, 8> V(is_fwd
+ ? scratchpad.template get<float>(key_wino_V)
+ : scratchpad.template get<float>(key_wino_M),
+ 0, alpha, alpha, jcp.dimN_block,
+ jcp.dimK_nb_block, jcp.dimK_block,
+ jcp.dimN_reg_block, jcp.dimK_reg_block);
+
+ const bool wants_padded_bias = jcp.with_bias
+ && jcp.oc_without_padding != jcp.oc;
+ float last_slice_bias[simd_w] = {0};
+ if (wants_padded_bias) {
+ for (int oc = 0; oc < jcp.oc_without_padding % jcp.oc_simd_block; ++oc)
+ last_slice_bias[oc] = bias(jcp.dimM / jcp.dimM_simd_block - 1, oc);
+ }
+
+ if (jcp.prop_kind != prop_kind::forward_inference) {
+
+ parallel_nd(jcp.nb_oc, jcp.nb_ic, (jcp.oc_block * jcp.oc_reg_block), (jcp.ic_block * jcp.ic_reg_block),
+ [&](int ofm1, int ifm1, int ofm2, int ifm2) {
+ float *U_base_ptr = is_fwd
+ ? &(U(ofm1, 0, 0, ifm1, ofm2, ifm2, 0, 0))
+ : &(U(ifm1, 0, 0, ofm1, ifm2, ofm2, 0, 0));
+ weight_transform_data(jcp,
+ &(weights(
+ ofm1 * jcp.oc_block * jcp.oc_reg_block + ofm2,
+ ifm1 * jcp.ic_block * jcp.ic_reg_block + ifm2,
+ 0, 0, 0, 0)),
+ U_base_ptr);
+ });
+ }
+
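+    // W_SGD schedule: each thread processes one tile block end to end
+    // (input transform -> GEMM -> output transform) using the per-thread
+    // V and M scratch buffers indexed by ithr.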
+ parallel_nd(jcp.tile_block, [&](int tile_block) {
+ int ithr = mkldnn_get_thread_num();
+
+ for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++) {
+ for (int K_blk2 = 0; K_blk2 < jcp.dimK_block; K_blk2++) {
+
+ input_transform_tileblock_data(
+ tile_block, jcp,
+ &(input(0, K_blk1 * jcp.dimK_block + K_blk2, 0, 0, 0)),
+ &(V(ithr, 0, 0, 0, K_blk1, K_blk2, 0, 0)));
+ }
+ }
+
+ for (int oj = 0; oj < alpha; oj++) {
+ for (int oi = 0; oi < alpha; oi++) {
+ for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++)
+ for (int K_blk1 = 0; K_blk1 < jcp.dimK_nb_block; K_blk1++)
+ for (int N_blk = 0; N_blk < jcp.dimN_block; N_blk++)
+ kernel_->gemm_loop_ker(
+ (float *)&(M(ithr, M_blk1, oj, oi,
+ N_blk, 0, 0, 0)),
+ (const float *)&(U(M_blk1, oj, oi, K_blk1,
+ 0, 0, 0, 0)),
+ (const float *)&(V(ithr, oj, oi,
+ N_blk, K_blk1, 0, 0, 0)), K_blk1);
+ }
+ }
+
+ for (int M_blk1 = 0; M_blk1 < jcp.dimM_nb_block; M_blk1++) {
+ for (int M_blk2 = 0; M_blk2 < jcp.dimM_block * jcp.dimM_reg_block;
+ M_blk2++) {
+ const int M_blk =
+ M_blk1 * jcp.dimM_block * jcp.dimM_reg_block + M_blk2;
+
+ float *bias_ptr = wants_padded_bias
+ && M_blk == jcp.dimM / jcp.dimM_simd_block - 1
+ ? last_slice_bias : &bias(M_blk, 0);
+
+ output_transform_tileblock_data(tile_block, jcp, p_ops,
+ &(M(ithr, M_blk1, 0, 0, 0, M_blk2, 0, 0)),
+ &(output(0, M_blk, 0, 0, 0)), bias_ptr);
+ }
+ }
+ });
+}
+
+template struct _jit_avx512_core_fp32_wino_conv_4x3_t<true>;
+template struct _jit_avx512_core_fp32_wino_conv_4x3_t<false>;
+
+namespace {
+
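+// Blockwise parallel reduction of several per-thread arrays into 'output'.
+// Only the [input_starts[0], input_ends[0]) subrange of the first array is
+// copied (elements outside it are zeroed); the remaining arrays are
+// accumulated over their own valid subranges, clipped to the current block.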
+void subarray_sum(size_t num_arrs, float *output, size_t nelems,
+ float *input_ptrs[], size_t input_starts[], size_t input_ends[]) {
+ using namespace nstl;
+ const size_t block_size = 16 * 1024 / sizeof(float);
+ const size_t blocks_number = nelems / block_size;
+ const size_t tail = nelems % block_size;
+
+PRAGMA_OMP(parallel)
+ {
+ const int ithr = mkldnn_get_thread_num();
+ const int nthr = mkldnn_get_num_threads();
+ size_t start{ 0 }, end{ 0 };
+ balance211(blocks_number, nthr, ithr, start, end);
+
+ for (size_t nb = start; nb < end; ++nb) {
+ size_t start_e = nb * block_size;
+ size_t end_e = start_e + block_size;
+ size_t input_start = max(start_e, min(input_starts[0], end_e));
+ size_t input_end = max(start_e, min(input_ends[0], end_e));
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < input_start; e++) {
+ output[e] = 0.f;
+ }
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = input_start; e < input_end; e++) {
+ output[e] = input_ptrs[0][e];
+ }
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = input_end; e < end_e; e++) {
+ output[e] = 0.f;
+ }
+
+ for (size_t a = 1; a < num_arrs; a++) {
+ input_start = max(start_e, input_starts[a]);
+ input_end = min(input_ends[a], end_e);
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = input_start; e < input_end; e++) {
+ output[e] += input_ptrs[a][e];
+ }
+ }
+ }
+
+ if (tail != 0 && ithr == nthr - 1) {
+ size_t start_e = nelems - tail;
+ size_t end_e = nelems;
+ size_t input_start = max(start_e, min(input_starts[0], end_e));
+ size_t input_end = max(start_e, min(input_ends[0], end_e));
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < input_start; e++) {
+ output[e] = 0.f;
+ }
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = input_start; e < input_end; e++) {
+ output[e] = input_ptrs[0][e];
+ }
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = input_end; e < end_e; e++) {
+ output[e] = 0.f;
+ }
+
+ for (size_t a = 1; a < num_arrs; a++) {
+ input_start = max(start_e, input_starts[a]);
+ input_end = min(input_ends[a], end_e);
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = input_start; e < input_end; e++) {
+ output[e] += input_ptrs[a][e];
+ }
+ }
+ }
+ }
+}
+
+const int max_threads_number = 1024;
+
+// Sum the arrays into 'output'. With reduce_to_first (the default), 'output'
+// is assumed to already hold the first array and only the remaining arrays
+// are accumulated; otherwise it is first initialized from input_ptrs[0].
+void array_sum(size_t num_arrs, float *output,
+ size_t nelems, float *input_ptrs[], bool reduce_to_first = true) {
+ const size_t block_size = 16 * 1024 / sizeof(float);
+ const size_t blocks_number = nelems / block_size;
+ const size_t tail = nelems % block_size;
+
+PRAGMA_OMP(parallel)
+ {
+ const size_t ithr = mkldnn_get_thread_num();
+ const size_t nthr = mkldnn_get_num_threads();
+ size_t start{ 0 }, end{ 0 };
+ balance211(blocks_number, nthr, ithr, start, end);
+
+ for (size_t nb = start; nb < end; ++nb) {
+ size_t start_e = nb * block_size;
+ size_t end_e = start_e + block_size;
+ if (!reduce_to_first) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < end_e; e++) {
+ output[e] = input_ptrs[0][e];
+ }
+ }
+ for (size_t a = 1; a < num_arrs; a++) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < end_e; e++) {
+ output[e] += input_ptrs[a][e];
+ }
+ }
+ }
+
+ if (tail != 0 && ithr == nthr - 1) {
+ size_t start_e = nelems - tail;
+ size_t end_e = nelems;
+ if (!reduce_to_first) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < end_e; e++) {
+ output[e] = input_ptrs[0][e];
+ }
+ }
+ for (size_t a = 1; a < num_arrs; a++) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < end_e; e++) {
+ output[e] += input_ptrs[a][e];
+ }
+ }
+ }
+ }
+}
+} // anonymous namespace (backward-weights reduction helpers)
+
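+// SDGtWo schedule: threads stream over tile blocks; for each tile block the
+// src and diff_dst tiles are transformed, multiplied (GEMM) into a per-thread
+// U buffer and immediately inverse-transformed into per-thread diff_weights,
+// which are reduced across threads at the end together with the bias.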
+void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
+_execute_backward_weights_SDGtWo(const float *ptr_src,
+ const float *ptr_diff_dst, float *ptr_diff_weights,
+ float *ptr_diff_bias,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const auto &jcp = kernel_->jcp;
+ const int nthreads = jcp.nthr;
+
+ array_offset_calculator<float, 5> src((float *)ptr_src,
+ jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
+ array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst,
+ jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
+ array_offset_calculator<float, 6> diff_weights(ptr_diff_weights,
+ jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
+
+ array_offset_calculator<float, 8> Us(scratchpad.get<float>(key_wino_U),
+ 0, alpha, alpha,
+ jcp.oc_block, jcp.ic_block,
+ jcp.ic_simd_block,
+ jcp.oc_reg_block,
+ jcp.oc_simd_block);
+
+ const int U_sz = nthreads * alpha * alpha * jcp.oc / jcp.nb_oc
+ * jcp.ic / jcp.nb_ic;
+ array_offset_calculator<float, 7>diff_weights_prv(
+ scratchpad.get<float>(key_wino_U) + U_sz,
+ 0, jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
+
+ array_offset_calculator<float, 8> M(scratchpad.get<float>(key_wino_M),
+ 0, alpha, alpha,
+ jcp.oc_block,
+ jcp.nb_tile_block_ur,
+ jcp.tile_block_ur,
+ jcp.oc_reg_block,
+ jcp.oc_simd_block);
+
+ array_offset_calculator<float, 7> V(scratchpad.get<float>(key_wino_V),
+ 0, alpha, alpha,
+ jcp.ic_block,
+ jcp.nb_tile_block_ur,
+ jcp.tile_block_ur,
+ jcp.ic_simd_block);
+
+ array_offset_calculator<float, 2> diff_bias_prv(
+ scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
+
+ auto trans_ker_p = jit_wino_transform_call_s();
+ float I[alpha][alpha][simd_w];
+ float T[alpha][alpha][simd_w];
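+    // Winograd F(4x4, 3x3) transform coefficient tables, consumed via
+    // trans_ker_p.G by the JIT transform kernels below (src, diff_dst and
+    // diff_weights transforms, respectively).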
+ float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
+ 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
+ float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f, 0.119514472455649f,
+ 0.430252100840336f, 0.168067226890756f, 0.179271708683473f, 0.403361344537815f,
+ 1.13777777777778f};
+ float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
+
+PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p, I, T))
+{
+ if (jcp.with_bias) {
+ parallel_nd_in_omp(nthreads, jcp.oc / simd_w,
+ [&](int ithr, int ofm){
+ float *pdbias = &(diff_bias_prv(ithr, ofm * simd_w));
+ PRAGMA_OMP_SIMD()
+ for (int v = 0; v < simd_w; v++) {
+ pdbias[v] = 0.0f;
+ }
+ });
+ }
+
+ int ithr = mkldnn_get_thread_num();
+ for (int ifm1 = 0; ifm1 < jcp.nb_ic; ++ifm1) {
+ int first_tblk = 0;
+PRAGMA_OMP(for)
+ for (int tblk1 = 0; tblk1 < jcp.tile_block; ++tblk1) {
+ int tile_index = tblk1 * jcp.nb_tile_block_ur * jcp.tile_block_ur;
+ int img = tile_index / (jcp.itiles * jcp.jtiles);
+ trans_ker_p.ti = tile_index % jcp.itiles;
+ trans_ker_p.tj = (tile_index / jcp.itiles) % jcp.jtiles;
+ trans_ker_p.M = I;
+ trans_ker_p.T = T;
+ trans_ker_p.G = G_I_3x3_4x4;
+ for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
+ int ifm = ifm1 * jcp.ic_block + ifm2;
+ trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
+ trans_ker_p.dst = (float *)&(V(ithr, 0, 0, ifm2, 0, 0, 0));
+ kernel_->src_transform(&trans_ker_p);
+ }
+
+ for (int ofm1 = 0; ofm1 < jcp.nb_oc; ++ofm1) {
+ trans_ker_p.G = G_W_3x3_4x4;
+ for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
+ int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
+ trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
+ trans_ker_p.dst = (float *)&(M(ithr, 0, 0, ofm2, 0, 0, 0, 0));
+ if (jcp.with_bias && ifm1 == 0) {
+ trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
+ kernel_->diff_dst_transform_wbias(&trans_ker_p);
+ } else {
+ kernel_->diff_dst_transform(&trans_ker_p);
+ }
+ }
+
+ for (int oj = 0; oj < alpha; ++oj) {
+ for (int oi = 0; oi < alpha; ++oi) {
+ kernel_->gemm_loop_ker_first_iter(
+ &(Us(ithr, oj, oi, 0, 0, 0, 0, 0)),
+ &(M(ithr, oj, oi, 0, 0, 0, 0, 0)),
+ &(V(ithr, oj, oi, 0, 0, 0, 0)));
+ }
+ }
+ trans_ker_p.G = G_O_3x3_4x4;
+ for (int ofm2 = 0; ofm2 < jcp.oc_block; ++ofm2) {
+ for (int ofm3 = 0; ofm3 < jcp.oc_reg_block; ++ofm3) {
+ int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block
+ + ofm3;
+ for (int ifm2 = 0; ifm2 < jcp.ic_block; ++ifm2) {
+ int ifm = ifm1 * jcp.ic_block + ifm2;
+ trans_ker_p.src = (float *)&(Us(ithr, 0, 0,
+ ofm2, ifm2, 0, ofm3, 0));
+ trans_ker_p.dst = (float *)&(diff_weights_prv(ithr,
+ ofm, ifm, 0, 0, 0, 0));
+ if (first_tblk == 0) {
+ kernel_->diff_weights_transform(&trans_ker_p);
+ } else {
+ kernel_->diff_weights_transform_accum(&trans_ker_p);
+ }
+ }
+ }
+ }
+ }
+ ++first_tblk;
+ }
+ }
+}
+
+ // Reduce diff-weights
+ {
+ float *output = ptr_diff_weights;
+ float *input_base = scratchpad.get<float>(key_wino_U) + U_sz;
+ int nelems = jcp.oc * jcp.ic * jcp.kh * jcp.kw;
+ float *input_ptrs[max_threads_number];
+ for (int i = 0; i < nthreads; ++i) {
+ input_ptrs[i] = input_base + nelems * i;
+ }
+ array_sum(nthreads, output, nelems, input_ptrs, false);
+
+ if (jcp.with_bias) {
+ output = ptr_diff_bias;
+ input_base = scratchpad.get<float>(key_conv_bia_reduction);
+ for (int i = 0; i < nthreads; ++i) {
+ input_ptrs[i] = input_base + jcp.oc * i;
+ }
+ array_sum(nthreads, output, jcp.oc_without_padding, input_ptrs,
+ false);
+ }
+ }
+}
+
+void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t::
+_execute_backward_weights_S_D_Giot_W(const float *ptr_src,
+ const float *ptr_diff_dst, float *ptr_diff_weights,
+ float *ptr_diff_bias,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const auto &jcp = kernel_->jcp;
+ const int nthreads = jcp.nthr;
+
+ array_offset_calculator<float, 5> src((float *)ptr_src,
+ jcp.mb, jcp.ic / simd_w, jcp.ih, jcp.iw, simd_w);
+ array_offset_calculator<float, 5> diff_dst((float *)ptr_diff_dst,
+ jcp.mb, jcp.oc / simd_w, jcp.oh, jcp.ow, simd_w);
+ array_offset_calculator<float, 6> diff_weights((float *)ptr_diff_weights,
+ jcp.oc / simd_w, jcp.ic / simd_w, jcp.kh, jcp.kw, simd_w, simd_w);
+ array_offset_calculator<float, 1> diff_bias((float *)ptr_diff_bias, jcp.oc);
+
+ array_offset_calculator<float, 9> U(scratchpad.get<float>(key_wino_U),
+ jcp.nb_ic, jcp.nb_oc,
+ alpha, alpha,
+ jcp.oc_block, jcp.ic_block,
+ jcp.ic_simd_block,
+ jcp.oc_reg_block,
+ jcp.oc_simd_block);
+
+ const int U_size = jcp.oc * jcp.ic * alpha * alpha;
+ array_offset_calculator<float, 10> Us(
+ scratchpad.get<float>(key_wino_U) + U_size,
+ 0, jcp.nb_ic, jcp.nb_oc,
+ alpha, alpha,
+ jcp.oc_block, jcp.ic_block,
+ jcp.ic_simd_block,
+ jcp.oc_reg_block,
+ jcp.oc_simd_block);
+
+ array_offset_calculator<float, 9> M(scratchpad.get<float>(key_wino_M),
+ jcp.nb_oc,
+ jcp.tile_block,
+ alpha, alpha,
+ jcp.oc_block,
+ jcp.nb_tile_block_ur,
+            jcp.tile_block_ur,
+ jcp.oc_reg_block,
+ jcp.oc_simd_block);
+
+ array_offset_calculator<float, 8> V(scratchpad.get<float>(key_wino_V),
+ jcp.nb_ic,
+ jcp.tile_block,
+ alpha, alpha,
+ jcp.ic_block,
+ jcp.nb_tile_block_ur, jcp.tile_block_ur,
+ jcp.ic_simd_block);
+
+ array_offset_calculator<float, 2> diff_bias_prv(
+ scratchpad.get<float>(key_conv_bia_reduction), nthreads, jcp.oc);
+
+ size_t input_starts[max_threads_number] = {0};
+ size_t input_ends[max_threads_number] = {0};
+ size_t first_tblk = 0;
+
+ auto trans_ker_p = jit_wino_transform_call_s();
+ float G_I_3x3_4x4[9] = {-2.25f, -0.390625f, 0.87890625f, -2.640625f,
+ 0.625f, -0.625f, 1.5f, -1.5f, -2.640625f};
+ float G_W_3x3_4x4[8] = {0.26890756302521f, -0.688403361344538f,
+ 0.119514472455649f, 0.430252100840336f, 0.168067226890756f,
+ 0.179271708683473f, 0.403361344537815f, 1.13777777777778f};
+ float G_O_3x3_4x4[4] = {2.25f, 0.625f, 1.5f, 0.390625f};
+ float I[alpha][alpha][simd_w];
+ float T[alpha][alpha][simd_w];
+
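+    // S_D_Giot_W schedule: transform all src and diff_dst tiles first, run the
+    // GEMMs accumulating into per-thread Us buffers while tracking the subrange
+    // each thread actually wrote, reduce Us into U with subarray_sum, and
+    // finally inverse-transform U into diff_weights.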
+PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T))
+{
+ if (jcp.with_bias) {
+ parallel_nd_in_omp(nthreads, jcp.oc, [&](int ithr, int ofm) {
+ diff_bias_prv(ithr, ofm) = 0.0f;
+ });
+ }
+
+ trans_ker_p.G = G_I_3x3_4x4;
+ trans_ker_p.M = I;
+ trans_ker_p.T = T;
+
+ parallel_nd_in_omp(jcp.nb_ic, jcp.ic_block, jcp.mb,
+ [&](int ifm1, int ifm2, int img){
+ size_t ifm = ifm1 * jcp.ic_block + ifm2;
+ size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
+ size_t tblk3 = tile_base_index % jcp.tile_block_ur;
+ size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
+ % jcp.nb_tile_block_ur;
+ size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
+ / jcp.nb_tile_block_ur;
+ trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
+ trans_ker_p.src = (float *)&(src(img, ifm, 0, 0, 0));
+ trans_ker_p.dst = (float *)&(V(ifm1, tblk1, 0, 0, ifm2, 0, 0, 0));
+ kernel_->src_transform(&trans_ker_p);
+ });
+
+ int ithr = mkldnn_get_thread_num();
+ trans_ker_p.G = G_W_3x3_4x4;
+ parallel_nd_in_omp(jcp.nb_oc, jcp.oc_block, jcp.mb,
+ [&](int ofm1, int ofm2, int img){
+ int ofm = (ofm1 * jcp.oc_block + ofm2) * jcp.oc_reg_block;
+ size_t tile_base_index = img * (jcp.itiles * jcp.jtiles);
+ size_t tblk3 = tile_base_index % jcp.tile_block_ur;
+ size_t tblk2 = (tile_base_index / jcp.tile_block_ur)
+ % jcp.nb_tile_block_ur;
+ size_t tblk1 = (tile_base_index / jcp.tile_block_ur)
+ / jcp.nb_tile_block_ur;
+ trans_ker_p.tile_count = tblk2 * jcp.tile_block_ur + tblk3;
+ trans_ker_p.src = (float *)&(diff_dst(img, ofm, 0, 0, 0));
+ trans_ker_p.dst = (float *)&(M(ofm1, tblk1, 0, 0, ofm2, 0, 0, 0, 0));
+ if (jcp.with_bias) {
+ trans_ker_p.bias = (float *)&(diff_bias_prv(ithr, ofm * simd_w));
+ kernel_->diff_dst_transform_wbias(&trans_ker_p);
+ } else {
+ kernel_->diff_dst_transform(&trans_ker_p);
+ }
+ });
+
+ PRAGMA_OMP(barrier)
+
+ parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, alpha, alpha, jcp.tile_block,
+ [&](int ifm1, int ofm1, int oj, int oi, int tblk1){
+ if (first_tblk == 0) {
+ input_starts[ithr] =
+ (float *)&(Us(ithr, ifm1, ofm1, oj, oi, 0, 0, 0,
+ 0, 0))
+ - (float *)&(Us(ithr, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0));
+ input_ends[ithr] = input_starts[ithr]
+ + jcp.oc_block * jcp.ic_block
+ * jcp.ic_simd_block * jcp.oc_reg_block
+ * jcp.oc_simd_block;
+ }
+ else if (tblk1 == 0) {
+ input_ends[ithr] += jcp.oc_block * jcp.ic_block
+ * jcp.ic_simd_block * jcp.oc_reg_block
+ * jcp.oc_simd_block;
+ }
+
+ if (first_tblk == 0 || tblk1 == 0) {
+ kernel_->gemm_loop_ker_first_iter(
+ &(Us(ithr, ifm1, ofm1, oj, oi,
+ 0, 0, 0, 0, 0)),
+ &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
+ &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
+ } else {
+ kernel_->gemm_loop_ker(
+ &(Us(ithr, ifm1, ofm1, oj, oi,
+ 0, 0, 0, 0, 0)),
+ &(M(ofm1, tblk1, oj, oi, 0, 0, 0, 0, 0)),
+ &(V(ifm1, tblk1, oj, oi, 0, 0, 0, 0)));
+ }
+ ++first_tblk;
+ });
+}
+
+ // Reduce diff-weights
+ {
+ float *output = &(U(0, 0, 0, 0, 0, 0, 0, 0, 0));
+ size_t nelems = jcp.ic * jcp.oc * alpha * alpha;
+ float *input_ptrs[max_threads_number];
+ for (int i = 0; i < nthreads; ++i)
+ input_ptrs[i] = output + nelems * (i + 1);
+ subarray_sum(nthreads, output, nelems, input_ptrs,
+ input_starts, input_ends);
+ }
+
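+    // Inverse-transform the reduced U into the final diff_weights.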
+ trans_ker_p.G = G_O_3x3_4x4;
+PRAGMA_OMP(parallel firstprivate(trans_ker_p))
+ {
+ parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block, jcp.oc_reg_block,
+ [&](int ifm1, int ofm1, int ofm2, int ifm2, int ofm3){
+ int ofm = (ofm1 * jcp.oc_block + ofm2)
+ * jcp.oc_reg_block + ofm3;
+ int ifm = ifm1 * jcp.ic_block + ifm2;
+ trans_ker_p.src = (float *)&(U(ifm1, ofm1, 0, 0,
+ ofm2, ifm2, 0, ofm3, 0));
+ trans_ker_p.dst = (float *)&(diff_weights(ofm, ifm,
+ 0, 0, 0, 0));
+ kernel_->diff_weights_transform(&trans_ker_p);
+ });
+ }
+
+ if (jcp.with_bias) {
+ parallel_nd(jcp.oc / simd_w, [&](int ofm1) {
+ float* pbias = &(diff_bias(ofm1 * simd_w));
+ float *pbias_prv = &(diff_bias_prv(0, ofm1 * simd_w));
+
+ const int blk_sz = ofm1 == jcp.oc / simd_w - 1
+ ? jcp.oc_without_padding - ofm1 * simd_w : simd_w;
+
+ PRAGMA_OMP_SIMD()
+ for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
+ pbias[ofm2] = pbias_prv[ofm2];
+ }
+
+ for (int ithr = 1; ithr < nthreads; ++ithr) {
+ pbias_prv = &(diff_bias_prv(ithr, ofm1 * simd_w));
+ PRAGMA_OMP_SIMD()
+ for (int ofm2 = 0; ofm2 < blk_sz; ++ofm2) {
+ pbias[ofm2] += pbias_prv[ofm2];
+ }
+ }
+ });
+ }
+}
+
+}
+}
+}
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp
new file mode 100644
index 0000000000..f1a56aac70
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.hpp
@@ -0,0 +1,386 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP
+#define CPU_JIT_AVX512_CORE_FP32_WINO_CONV_4x3_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace winograd_avx512_core {
+inline void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_winograd_conf_t &jcp) {
+ using namespace utils;
+ using namespace memory_tracking::names;
+
+ size_t U_sz = (size_t)alpha * alpha * jcp.ic * jcp.oc;
+ size_t V_sz = (size_t)alpha * alpha * jcp.mb * jcp.ic * jcp.itiles
+ * jcp.jtiles;
+ size_t M_sz = (size_t)alpha * alpha * jcp.mb * jcp.oc * jcp.itiles
+ * jcp.jtiles;
+
+ switch (jcp.sched_policy) {
+ case WSCHED_DATA_W_SGD:
+ V_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur
+ * jcp.tile_block_ur * jcp.ic;
+ M_sz = (size_t)jcp.nthr * alpha * alpha * jcp.nb_tile_block_ur
+ * jcp.tile_block_ur * jcp.oc;
+ break;
+ case WSCHED_WEI_SDGtWo:
+ U_sz = (size_t)jcp.nthr * (alpha * alpha * jcp.oc
+ * (jcp.ic / jcp.nb_ic) + jcp.ic * jcp.oc * jcp.kh * jcp.kw);
+ M_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block)
+ * (jcp.oc / jcp.nb_oc);
+ V_sz = (size_t)jcp.nthr * alpha * alpha * (jcp.ntiles / jcp.tile_block)
+ * (jcp.ic / jcp.nb_ic);
+ break;
+ case WSCHED_WEI_S_D_Giot_W:
+ U_sz = (size_t)(jcp.nthr + 1) * alpha * alpha * jcp.ic * jcp.oc;
+ M_sz = (size_t)alpha * alpha * jcp.oc * jcp.ntiles;
+ V_sz = (size_t)alpha * alpha * jcp.ic * jcp.ntiles;
+ break;
+ default: break;
+ }
+
+ scratchpad.book(key_wino_U, sizeof(float) * U_sz, PAGE_2M);
+ scratchpad.book(key_wino_V, sizeof(float) * V_sz, PAGE_2M);
+ scratchpad.book(key_wino_M, sizeof(float) * M_sz, PAGE_2M);
+
+ if (one_of(jcp.sched_policy, WSCHED_WEI_SDGtWo, WSCHED_WEI_S_D_Giot_W)) {
+ size_t br_sz = (size_t)jcp.nthr * jcp.oc;
+ scratchpad.book(key_conv_bia_reduction, sizeof(float) * br_sz, PAGE_2M);
+ }
+}
+}
+
+template <bool is_fwd>
+struct _jit_avx512_core_fp32_wino_conv_4x3_t {
+
+ _jit_avx512_core_fp32_wino_conv_4x3_t(
+ const jit_conv_winograd_conf_t &jcp, const primitive_attr_t *attr)
+ : kernel_(nullptr), attr_(attr) {
+ kernel_ = new _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(jcp);
+ }
+
+ ~_jit_avx512_core_fp32_wino_conv_4x3_t() { delete kernel_; }
+
+ protected:
+ void weight_transform_data(const jit_conv_winograd_conf_t &jcp,
+ float *wp, float *twp) const;
+ void input_transform_data(int image,
+ const jit_conv_winograd_conf_t &jcp,
+ float *inp, float *tinp) const;
+ void input_transform_tileblock_data(int tile_block,
+ const jit_conv_winograd_conf_t &jcp,
+ float *inp, float *tinp) const;
+ void output_transform_data(int image,
+ const jit_conv_winograd_conf_t &jcp,
+ const post_ops_t &p_ops, float *toutp, float *pout_b,
+ float *bias) const;
+ void output_transform_tileblock_data(int tile_block,
+ const jit_conv_winograd_conf_t &jcp, const post_ops_t &p_ops,
+ float *toutp, float *outp, float *bias) const;
+ void _execute_data_W_S_G_D(float *inp_ptr, float *out_ptr,
+ float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const;
+ void _execute_data_W_SGD(float *inp_ptr, float *out_ptr,
+ float *wei_ptr, float *bias_ptr,
+ const memory_tracking::grantor_t &scratchpad) const;
+ _jit_avx512_core_fp32_wino_conv_4x3_data_kernel *kernel_;
+ const primitive_attr_t *attr_;
+};
+
+struct jit_avx512_core_fp32_wino_conv_4x3_fwd_t
+ : _jit_avx512_core_fp32_wino_conv_4x3_t<true>
+ , public cpu_primitive_t
+ {
+ struct pd_t : public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
+ jit_avx512_core_fp32_wino_conv_4x3_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::
+ init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_,
+ *attr());
+ if (status != status::success) return status;
+ set_default_alg_kind(alg_kind::convolution_winograd);
+
+ auto scratchpad = scratchpad_registry().registrar();
+ winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
+
+ return status;
+ }
+
+ jit_conv_winograd_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ auto wei_fmt = desc()->prop_kind == prop_kind::forward_training
+ ? (with_groups() ? gOIhw16i16o : OIhw16i16o) : any;
+ return set_default_formats_common(nChw16c, wei_fmt, nChw16c);
+ }
+ };
+
+ jit_avx512_core_fp32_wino_conv_4x3_fwd_t(const pd_t *apd)
+ : _jit_avx512_core_fp32_wino_conv_4x3_t<true>(apd->jcp_, apd->attr())
+ , cpu_primitive_t(apd, true)
+ {}
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(float *, MKLDNN_ARG_DST);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ switch ((pd()->jcp_).sched_policy) {
+ case WSCHED_DATA_W_S_G_D:
+ this->_execute_data_W_S_G_D((float *)src, dst, (float *)weights,
+ (float *)bias, scratchpad);
+ break;
+ case WSCHED_DATA_W_SGD:
+ this->_execute_data_W_SGD((float *)src, dst, (float *)weights,
+ (float *)bias, scratchpad);
+ break;
+ default:
+ break;
+ }
+ return status::success;
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t
+ : _jit_avx512_core_fp32_wino_conv_4x3_t<false>,
+ public cpu_primitive_t {
+ struct pd_t : public cpu_convolution_bwd_data_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && mkldnn_thr_syncable()
+ && desc()->prop_kind == prop_kind::backward_data
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::undef, data_type::f32, data_type::f32)
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel
+ ::init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(),
+ *diff_dst_md());
+ if (status != status::success) return status;
+ set_default_alg_kind(alg_kind::convolution_winograd);
+
+ auto scratchpad = scratchpad_registry().registrar();
+ winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
+
+ return status;
+ }
+
+ jit_conv_winograd_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o;
+ return set_default_formats_common(nChw16c, wei_fmt, nChw16c);
+ }
+ };
+
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_data_t(const pd_t *apd)
+ : _jit_avx512_core_fp32_wino_conv_4x3_t<false>(apd->jcp_, apd->attr())
+ , cpu_primitive_t(apd, true)
+ {}
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const float *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ switch ((pd()->jcp_).sched_policy) {
+ case WSCHED_DATA_W_S_G_D:
+ this->_execute_data_W_S_G_D((float *)diff_dst, diff_src,
+ (float *)weights, NULL, scratchpad);
+ break;
+
+ case WSCHED_DATA_W_SGD:
+ this->_execute_data_W_SGD((float *)diff_dst, diff_src,
+ (float *)weights, NULL, scratchpad);
+ break;
+
+ default:
+ break;
+ }
+
+ return status::success;
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t
+ : public cpu_primitive_t {
+ struct pd_t : public cpu_convolution_bwd_weights_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_wino_4x3:", avx512_core, ""),
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && mkldnn_thr_syncable()
+ && desc()->prop_kind == prop_kind::backward_weights
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && set_default_formats();
+ if (!ok)
+ return status::unimplemented;
+
+ status_t status =
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
+ init_conf(jcp_, *desc(), *src_md(), *diff_dst_md(),
+ *diff_weights_md());
+ if (status != status::success) return status;
+ set_default_alg_kind(alg_kind::convolution_winograd);
+
+ auto scratchpad = scratchpad_registry().registrar();
+ winograd_avx512_core::init_scratchpad(scratchpad, jcp_);
+
+ return status;
+ }
+
+ jit_conv_winograd_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ auto wei_fmt = with_groups() ? gOIhw16i16o : OIhw16i16o;
+ return set_default_formats_common(nChw16c, wei_fmt, nChw16c);
+ }
+ };
+
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true)
+ , kernel_(nullptr)
+ {
+ kernel_ = new jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel(
+ pd()->jcp_);
+ }
+
+ ~jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_t()
+ {
+ delete kernel_;
+ }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto diff_dst = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST);
+ auto src = CTX_IN_MEM(const float *, MKLDNN_ARG_SRC);
+ auto diff_weights = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);
+
+ switch (kernel_->jcp.sched_policy) {
+ case WSCHED_WEI_SDGtWo:
+ _execute_backward_weights_SDGtWo(src, diff_dst, diff_weights,
+ diff_bias, scratchpad(ctx));
+ break;
+ case WSCHED_WEI_S_D_Giot_W:
+ _execute_backward_weights_S_D_Giot_W(src, diff_dst, diff_weights,
+ diff_bias, scratchpad(ctx));
+ break;
+ default:
+ assert(kernel_->jcp.sched_policy != WSCHED_INVALID);
+ break;
+ }
+ return status::success;
+ }
+
+private:
+ void _execute_backward_weights_SDGtWo(const float *src,
+ const float *diff_dst, float *diff_weights, float *diff_bias,
+ const memory_tracking::grantor_t &scratchpad) const;
+ void _execute_backward_weights_S_D_Giot_W(const float *src,
+ const float *diff_dst, float *diff_weights, float *diff_bias,
+ const memory_tracking::grantor_t &scratchpad) const;
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel *kernel_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp
new file mode 100644
index 0000000000..0d64a2d13a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp
@@ -0,0 +1,2596 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include <math.h>
+
+#include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp"
+
+#define GET_OFF(field) offsetof(jit_wino_transform_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace {
+
+using namespace mkldnn::impl::utils;
+
+unsigned int L1_cache_size = get_cache_size(1, true);
+unsigned int L2_cache_size = get_cache_size(2, true);
+unsigned int LLC_data_size = get_cache_size(3, false);
+
+// The test function takes jcp, the candidate divisor and the current best;
+// it returns true if the new candidate is better.
+int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
+ int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
+{
+ int best_divisor = default_best;
+ auto test_num
+ = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
+ if (test(jcp, num, best_divisor)) {
+ best_divisor = num;
+ }
+ };
+
+ for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
+ if (number % divisor == 0) {
+ test_num(jcp, divisor);
+ test_num(jcp, number / divisor);
+ }
+ }
+
+ return best_divisor;
+}
+
+namespace {
+bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
+    /* Determines whether the current Winograd implementation is faster than
+       the direct one. The conditions below are empirical, based on performance data */
+ unsigned int ncores_per_socket =
+ cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
+ unsigned int nthreads = mkldnn_get_max_threads();
+
+ if (jcp.prop_kind == prop_kind::forward_inference) {
+ return jcp.mb >= 4;
+ } else if (nthreads > ncores_per_socket) {
+ double src_dst_transforms_per_core = alpha * alpha
+ * (jcp.ic + jcp.oc)
+ * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size)
+ * ((jcp.ow + tile_size - 1) / tile_size)
+ * sizeof(float) / 1024. / 1024. / nthreads;
+ double wei_transform = alpha * alpha
+ * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.;
+
+ if (jcp.prop_kind == prop_kind::backward_weights) {
+ if (src_dst_transforms_per_core < 0.3
+ || (src_dst_transforms_per_core <= 28 && wei_transform < 4))
+ return false;
+ else
+ return true;
+ } else {
+ if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02)
+ return false;
+ }
+ }
+
+ return jcp.mb > 8;
+}
+}
+
+/* assumes 512-bit registers */
+/* TODO: add support for strides */
+/* TODO: handle the prefetch distance automatically */
+typedef enum cache_t_ { L1, L2, L3 } cache_t;
+
+template <typename data_t>
+struct prefetcher_t {
+ prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
+ cache_t cache_type, size_t block_size, /* in number of elements*/
+ int nb_instructions_in_block, int fma_ipc)
+ : cg_(generator)
+ , reg_base_addr_(reg_base_addr)
+ , cache_type_(cache_type)
+ , cache_block_size_(block_size)
+ {
+ nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
+ prefetch_spread_
+ = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
+ prefetch_blk_
+ = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
+
+        /* assumption: when fetching from Li, the data is already in L(i+1) */
+ int cache_latency;
+ switch (cache_type_) {
+ case L1: cache_latency = 14; break;
+ case L2: cache_latency = 250; break;
+ case L3: cache_latency = 250; break;
+ }
+
+ prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
+ }
+
+ void prefetch(int instruction_number)
+ {
+ if (instruction_number % prefetch_spread_ == 0) {
+ for (int i = 0; (i < prefetch_blk_)
+ && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
+ i++, prefetches_issued_++) {
+ prefetch_inst_(cg_->EVEX_compress_addr(
+ reg_base_addr_, (cache_block_size_ * prefetch_distance_)
+ * sizeof(data_t)
+ + (prefetches_issued_ * 64)));
+ }
+ }
+ }
+
+private:
+ void prefetch_inst_(const Xbyak::Address &addr)
+ {
+ switch (cache_type_) {
+ case L1: cg_->prefetcht0(addr); break;
+ case L2: cg_->prefetcht1(addr); break;
+ case L3: cg_->prefetcht2(addr); break;
+ default:
+ break; // TODO: raise an exception or put an assert
+ }
+ }
+
+ jit_generator *cg_;
+ Xbyak::Reg64 reg_base_addr_;
+ cache_t cache_type_;
+ int cache_block_size_ = 0;
+ int nb_cache_lines_to_prefetch_ = 0;
+ int prefetches_issued_ = 0;
+ int prefetch_spread_ = 0;
+ int prefetch_blk_ = 0;
+ int prefetch_distance_ = 0;
+};
+
+// utilities to support kernel parameter selection
+bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp,
+ int dimN_block, float C2_min, float C2_max) {
+ float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic)
+ * dimN_block * jcp.dimN_reg_block
+ + div_up(jcp.ic * jcp.oc,mkldnn_get_max_threads())) * (float)sizeof(float);
+ float L2_lb = C2_min * L2_cache_size;
+ float L2_ub = C2_max * L2_cache_size;
+ return (block_size > L2_lb && block_size < L2_ub);
+}
+
+bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block,
+ int dimM_block, float C1_min, float C1_max) {
+ float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block
+ * jcp.dimK_reg_block * jcp.dimM_reg_block
+ + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
+ + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block)
+ * (float)sizeof(float);
+ float L1_lb = C1_min * L1_cache_size;
+ float L1_ub = C1_max * L1_cache_size;
+ return (gemm_block_size > L1_lb && gemm_block_size < L1_ub);
+}
+bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
+ int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
+{
+ float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block
+ + dimM_block * dimK_block * dimK_reg_block
+ * dimM_simd_block * dimM_reg_block
+ + dimK_block * dimN_reg_block * dimK_reg_block)
+ * (float)sizeof(float);
+ float rhs = C * L1_cache_size;
+ return (lhs < rhs);
+}
+bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
+ int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
+{
+ float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block
+ * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block)
+ * (float)sizeof(float);
+ float rhs = C * L1_cache_size;
+ return (lhs < rhs);
+}
+bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
+ int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block,
+ int dimM_simd_block, float C)
+{
+ float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block
+ * dimM_simd_block * dimM_reg_block
+ + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
+ * dimM_simd_block * dimM_reg_block
+ + nb_dimN_reg_block * dimK_nb_block * dimK_block
+ * dimN_reg_block * dimK_reg_block)
+ * (float)sizeof(float);
+ float rhs = C * L2_cache_size;
+ return (lhs < rhs);
+}
+
+bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block,
+ int dimN_block, int dimN_reg_block, int dimK, float C1, float C2)
+{
+ float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK
+ * (float)sizeof(float);
+ float B_size = dimN_block * dimN_reg_block * dimK
+ * (float)sizeof(float);
+ return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size);
+}
+}
+
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::utils;
+using namespace Xbyak;
+
+void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::gemm_loop_generate()
+{
+ // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
+ // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block;
+ // dimM_reg_block++) // unrolled
+ // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
+ // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
+ // dimK_reg_block++) // unrolled
+ // for (int tile =0; tile < jcp.dimN_reg_block; tile++)
+ // C[dimM_block][dimM_reg_block][tile] +=
+ // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block]
+ // * broadcast(B[dimK_block][tile][dimK_reg_block]);
+ // Notes:
+ // jcp.kernel_kind defines embedded or explicit broadcast
+ // dimM_reg_block=1 for embedded bcast kernel
+
+ auto zmm_srcA = [=]() {
+ return Xbyak::Zmm(0);
+ };
+ auto zmm_srcB = [=](int tile) {
+ int idx = 1 + tile;
+ assert(idx < 1 + jcp.dimN_reg_block);
+ return Xbyak::Zmm(idx);
+ };
+ auto zmm_dstC = [=](int dimM_reg_block, int tile) {
+ int idx{0};
+ if (jcp.kernel_kind == embd_bcast)
+ idx = 1 + tile;
+ else
+ idx = 1 + jcp.dimN_reg_block
+ + dimM_reg_block * jcp.dimN_reg_block + tile;
+ assert(idx < 32);
+ return Xbyak::Zmm(idx);
+ };
+
+ auto prepare_output = [=]() {
+ for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
+ dimM_reg_block++) {
+ for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
+ Zmm zmm = zmm_dstC(dimM_reg_block, tile);
+ vpxord(zmm, zmm, zmm);
+ }
+ }
+ };
+ auto store_output = [=](bool output_is_aligned) {
+ Label save;
+ cmp(reg_is_beta_zero, 0);
+ je(save, T_NEAR);
+
+ for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
+ dimM_reg_block++) {
+ for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
+ Zmm zmm = zmm_dstC(dimM_reg_block,tile);
+ int output_offset
+ = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
+ vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset));
+ }
+ }
+
+ L(save);
+ for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
+ dimM_reg_block++) {
+ for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
+ Zmm zmm = zmm_dstC(dimM_reg_block,tile);
+ int output_offset
+ = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
+
+ // In W_SGD, output will be reused.
+ if (output_is_aligned
+ && jcp.dimK_nb_block == 1
+ && jcp.sched_policy == WSCHED_DATA_W_S_G_D
+ && (jcp.dimN * jcp.dimM * alpha * alpha
+ * sizeof(float) > 2 * LLC_data_size))
+ vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm);
+ else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm);
+ }
+ }
+ };
+
+ auto inner_loops = [=]() {
+ Label dimM_block_loop, dimK_block_loop;
+
+ if (jcp.dimM_block > 1) {
+ mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
+ L(dimM_block_loop);
+ }
+
+ prepare_output();
+
+ if (jcp.dimK_block > 1) {
+ mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
+ L(dimK_block_loop);
+ }
+
+ for (int dimK_reg_block = 0;
+ dimK_reg_block < jcp.dimK_reg_block;
+ dimK_reg_block ++) {
+
+ if (jcp.kernel_kind == expl_bcast) {
+ for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
+ vbroadcastss(zmm_srcB(tile),
+ ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]);
+ }
+ }
+
+ /* Performing the fmas */
+
+ for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
+ dimM_reg_block++) {
+
+ vmovups(zmm_srcA(),
+ zword[reg_srcA
+ + jcp.dimK_reg_block * jcp.dimK_block * 64
+ * dimM_reg_block
+ + dimK_reg_block * 64]
+ );
+
+ for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
+ if (jcp.kernel_kind == expl_bcast)
+ vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
+ zmm_srcB(tile));
+ else
+ vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
+ EVEX_compress_addr(reg_srcB,
+ 64 * tile + dimK_reg_block * 4, true));
+ }
+ }
+ }
+ add(reg_srcA, jcp.dimK_reg_block * 64);
+ add(reg_srcB, jcp.dimN_reg_block * 64);
+ if (jcp.dimK_block > 1) {
+ sub(reg_dimK_block_loop_cnt, 1);
+ jnz(dimK_block_loop);
+ }
+
+ Label unaligned_store, end_store;
+ test(reg_dstC, cpu_isa_traits<avx512_core>::vlen - 1);
+ jnz(unaligned_store, T_NEAR);
+ store_output(true);
+ jmp(end_store, T_NEAR);
+ L(unaligned_store); {
+ store_output(false);
+ }
+ L(end_store);
+
+ if (jcp.dimM_block > 1) {
+ sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
+ add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64);
+ if (jcp.kernel_kind == expl_bcast) {
+ add(reg_srcA,
+ (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64
+ * jcp.dimK_block);
+ }
+ sub(reg_dimM_block_loop_cnt, 1);
+ jnz(dimM_block_loop);
+ }
+ };
+
+ /* Preamble */
+ preamble();
+
+ /* kernel */
+ inner_loops();
+
+ /* Postamble */
+ postamble();
+ ret();
+}
+
+void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
+ ::weights_transform_data_ker_generate()
+{
+ bool is_fwd = one_of(jcp.prop_kind,
+ mkldnn_forward_training, mkldnn_forward_inference);
+ int kh = jcp.kh;
+ int kw = jcp.kw;
+
+ auto zmm_temp = Xbyak::Zmm(31);
+ auto zmm_zero = Xbyak::Zmm(30);
+
+ auto zmm_M = [=](int i) {
+ return Xbyak::Zmm(i);
+ };
+ auto zmm_MT = [=](int i) {
+ return Xbyak::Zmm(i + simd_w);
+ };
+
+ auto zmm_G = [=](int i) {
+ return Xbyak::Zmm(i);
+ };
+ auto zmm_F = [=](int i) {
+ return Xbyak::Zmm(alpha + i);
+ };
+ auto zmm_T = [=](int i) {
+ return Xbyak::Zmm(alpha + 3 + i);
+ };
+ auto zmm_t = [=](int i) {
+ return Xbyak::Zmm(2 * alpha + 3 + i);
+ };
+
+ auto zmm_load = [=](int i) {
+ return Xbyak::Zmm(i);
+ };
+
+ auto init_G = [=]() {
+ mov(wreg_temp, ptr[param1 + GET_OFF(G)]);
+ for (int i = 0; i < alpha; i++) {
+ vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]);
+ }
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+ };
+
+ auto trans16x16 = [=]() {
+ for (int i = 0; i < simd_w; i+=2 ) {
+ vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]);
+ vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]);
+ vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1));
+ vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1));
+ }
+ for (int i = 0; i < simd_w; i+=4 ) {
+ vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2));
+ vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2));
+ vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3));
+ vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3));
+ }
+ for (int i = 0; i < simd_w; i += 8) {
+ vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88);
+ vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88);
+ vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88);
+ vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88);
+ vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd);
+ vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd);
+ vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd);
+ vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd);
+ }
+ {
+ int i = 0;
+ int mask = 0x88;
+ vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask);
+ vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0));
+ vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask);
+ vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1));
+ vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask);
+ vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2));
+ vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask);
+ vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3));
+ vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask);
+ vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4));
+ vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask);
+ vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5));
+ vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask);
+ vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6));
+ vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask);
+ vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7));
+ mask = 0xdd;
+ vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask);
+ vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8));
+ vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask);
+ vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9));
+ vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask);
+ vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10));
+ vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask);
+ vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11));
+ vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask);
+ vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12));
+ vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask);
+ vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13));
+ vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask);
+ vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14));
+ vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask);
+ vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15));
+ }
+ };
+
+ auto load_src = [=]() {
+ mov(wreg_src, ptr[param1 + GET_OFF(src)]);
+ mov(wreg_F, ptr[param1 + GET_OFF(M)]);
+ for (int j = 0; j < kh; j++) {
+ for (int i = 0; i < kw; i++) {
+ if (is_fwd) {
+ for (int v1 = 0; v1 < simd_w; v1++) {
+ int offset_src = (j * kw * simd_w * simd_w
+ + i * simd_w * simd_w + v1 * simd_w) * typesize;
+ int offset_F = (j * kw * simd_w * simd_w
+ + i * simd_w * simd_w + v1 * simd_w) * typesize;
+ vmovups(zmm_temp, ptr[wreg_src + offset_src]);
+ vmovups(ptr[wreg_F + offset_F], zmm_temp);
+ }
+ } else {
+ int offset_src = ((2 - j) * kw * simd_w * simd_w
+ + (2 - i) * simd_w * simd_w) * typesize;
+ int offset_F = (j * kw * simd_w * simd_w
+ + i * simd_w * simd_w) * typesize;
+ lea(wreg_M, ptr[wreg_src + offset_src]);
+ lea(wreg_MT, ptr[wreg_F + offset_F]);
+ trans16x16();
+ }
+ }
+ }
+ };
+
+ auto store_dst = [=]() {
+ mov(wreg_dst, ptr[param1 + GET_OFF(dst)]);
+ mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
+
+ Label Loop_j;
+ mov(wreg_cnt_j, 0);
+ mov(wreg_dst_aux, wreg_dst);
+ mov(wreg_Fw_aux, wreg_Fw);
+
+ int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block)
+ * jcp.dimK_block * simd_w * simd_w;
+
+ L(Loop_j);
+ {
+ for (int i = 0; i < alpha; i++) {
+ // touch pages
+ vmovups(zmm_load(0), ptr[wreg_Fw_aux
+ + (i * simd_w * simd_w) * typesize]);
+ mov(wreg_dst_idx, i * dim5 * typesize);
+ vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0));
+ }
+ for (int i = 0; i < alpha; i++) {
+ for (int v1 = 1; v1 < simd_w; v1++) {
+ int offset_Fw = (i * simd_w * simd_w + v1 * simd_w)
+ * typesize;
+ vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]);
+ }
+ mov(wreg_dst_idx, i * dim5 * typesize);
+ for (int v1 = 1; v1 < simd_w; v1++) {
+ int offset_dst = v1 * simd_w * typesize;
+ vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst],
+ zmm_load(v1));
+ }
+ }
+ add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize);
+ add(wreg_dst_aux, alpha * dim5 * typesize);
+ add(wreg_cnt_j, 1);
+ cmp(wreg_cnt_j, alpha);
+ jl(Loop_j, T_NEAR);
+ }
+ };
+
+ auto trans_W_4x4_3x3 = [=]() {
+ auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
+ vmovups(dst, a);
+ vfmadd231ps(dst, b, c);
+ };
+ auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
+ vmulps(zmm_temp, b, c);
+ vsubps(dst, a, zmm_temp);
+ };
+ auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
+ vsubps(dst, zmm_zero, a);
+ vfnmadd231ps(dst, b, c);
+ };
+
+ mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
+ mov(wreg_F, ptr[param1 + GET_OFF(M)]);
+ mov(wreg_T, ptr[param1 + GET_OFF(T)]);
+
+ Label Loop_j;
+ mov(wreg_cnt_j, 0);
+ L(Loop_j);
+ mov(wreg_F_aux, wreg_F);
+ mov(wreg_Fw_aux, wreg_Fw);
+ mov(wreg_temp, wreg_cnt_j);
+ shl(wreg_temp, 4 + 2);
+ lea(wreg_F_aux, ptr[wreg_F + wreg_temp]);
+ lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]);
+
+ for (int i = 0; i < 3; i++) {
+ for (int idx = 0; idx < 3; idx ++) {
+ vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w
+ * simd_w + i * simd_w * simd_w) * typesize]);
+ }
+ vmulps(zmm_t(0), zmm_G(0), zmm_F(2));
+ fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0));
+ fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0));
+
+ vmulps(zmm_T(0), zmm_G(3), zmm_F(0));
+ fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1));
+ fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1));
+ fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1));
+ fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1));
+ vmovaps(zmm_T(5), zmm_F(2));
+
+ for (int idx = 0; idx < 6; idx ++) {
+ vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w)
+ * typesize], zmm_T(idx));
+ }
+ }
+ for (int i = 0; i < 6; i++) {
+
+ for (int idx = 0; idx < 3; idx ++) {
+ vmovups(zmm_T(idx), ptr[wreg_T
+ + (i * 3 * simd_w + idx * simd_w) * typesize]);
+ }
+ vmulps(zmm_t(0), zmm_G(0), zmm_T(2));
+ fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0));
+ fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0));
+
+ vmulps(zmm_F(0), zmm_G(3), zmm_T(0));
+ fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1));
+ fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1));
+ fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1));
+ fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1));
+ vmovaps(zmm_F(5), zmm_T(2));
+
+ for (int l = 0; l < 6; l++) {
+ vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w
+ + l * simd_w * simd_w) * typesize], zmm_F(l));
+ }
+ }
+ add(wreg_cnt_j, 1);
+ cmp(wreg_cnt_j, 16);
+ jl(Loop_j, T_NEAR);
+ };
+
+ auto inner_loops = [=]() {
+ load_src();
+ init_G();
+ trans_W_4x4_3x3();
+ store_dst();
+ };
+
+ preamble();
+ inner_loops();
+ postamble();
+}
+
+void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
+ ::output_transform_data_ker_generate()
+{
+ bool is_fwd = one_of(jcp.prop_kind,
+ mkldnn_forward_training, mkldnn_forward_inference);
+ int outw = is_fwd ? jcp.ow : jcp.iw;
+ int outh = is_fwd ? jcp.oh : jcp.ih;
+ bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
+ bool with_bias = jcp.with_bias;
+ bool with_relu = jcp.with_eltwise;
+ bool with_relu_postsum = jcp.with_relu_postsum;
+ bool with_sum = jcp.with_sum;
+
+ auto zmm_zero = Xbyak::Zmm(0);
+ auto zmm_temp = Xbyak::Zmm(31);
+ auto zmm_G = [=](int i) {
+ return Xbyak::Zmm(1 + i);
+ };
+ auto zmm_O = [=](int i) {
+ return Xbyak::Zmm(1 + alpha + i);
+ };
+ auto zmm_T = [=](int i) {
+ return Xbyak::Zmm(1 + 2 * alpha + i);
+ };
+ auto zmm_t = [=](int i) {
+ return Xbyak::Zmm(1 + 3 * alpha + i);
+ };
+
+ auto init_G = [=]() {
+ mov(oreg_temp, ptr[param1 + GET_OFF(G)]);
+ for (int i = 0; i < 6; i++) {
+ vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]);
+ }
+ };
+
+ auto load_src = [=]() {
+ mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
+ mov(oreg_src, ptr[param1 + GET_OFF(src)]);
+
+ mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
+ imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur,
+ (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
+ * jcp.dimM_simd_block * typesize);
+ add(oreg_src, oreg_nb_tile_block_ur);
+
+ mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
+ imul(oreg_tile_block_ur, oreg_tile_block_ur,
+ jcp.dimM_simd_block * typesize);
+ add(oreg_src, oreg_tile_block_ur);
+
+ if (not_tiled) {
+ mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
+ imul(oreg_tile_block, oreg_tile_block,
+ jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block
+ * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
+ * jcp.dimM_simd_block * typesize);
+ add(oreg_src, oreg_tile_block);
+ }
+
+ int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block)
+ * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize;
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ int j_base_offset = j * alpha * last4dim;
+ int i_base_offset = i * last4dim;
+ vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]);
+ vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w)
+ * typesize], zmm_temp);
+ }
+ }
+ };
+
+ auto store_dst = [=]() {
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+ mov(oreg_dst, ptr[param1 + GET_OFF(dst)]);
+ mov(oreg_O, ptr[param1 + GET_OFF(M)]);
+ mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]);
+ shl(oreg_ydim, 2); // tj * tile_size (==4)
+ mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]);
+        shl(oreg_xdim, 2); // ti * tile_size (==4)
+
+ if (with_bias)
+ mov(oreg_bias, ptr[param1 + GET_OFF(bias)]);
+
+ auto store_one = [=](int j, int i, bool is_aligned) {
+ auto zmm_O = Xbyak::Zmm(31);
+ auto zmm_relu_ns = Xbyak::Zmm(30);
+ auto xmm_relu_ns = Xbyak::Xmm(30);
+ int offset = (j * tile_size * simd_w + i * simd_w) * typesize;
+
+ vmovups(zmm_O, ptr[oreg_O + offset]);
+ if (is_fwd) {
+ if (with_bias) {
+ vaddps(zmm_O, zmm_O, ptr[oreg_bias]);
+ }
+ if (with_relu) {
+ if (jcp.eltwise.alpha == 0) {
+ vmaxps(zmm_O, zmm_O, zmm_zero);
+ } else {
+ Opmask kmask = Opmask(7);
+ mov(imm_addr64, float2int(jcp.eltwise.alpha));
+ vmovq(xmm_relu_ns, imm_addr64);
+ vbroadcastss(zmm_relu_ns, xmm_relu_ns);
+ vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
+ vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
+ }
+ }
+ }
+ if (with_sum) {
+ vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]);
+ if (with_relu_postsum) // orig: with_relu_postsum
+ vmaxps(zmm_O, zmm_O, zmm_zero);
+ }
+ if (is_aligned)
+ vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O);
+ else
+ vmovups(ptr[oreg_out_j + oreg_temp], zmm_O);
+ };
+
+ auto i_loop = [=](int j, bool is_aligned) {
+ for (int i = 0; i < tile_size; i++) {
+ Label next;
+ mov(oreg_temp, oreg_xdim);
+ add(oreg_temp, i);
+ cmp(oreg_temp, outw);
+ jge(next, T_NEAR);
+ shl(oreg_temp, 4 + 2); // * 16 * 4
+
+ store_one(j, i, is_aligned);
+
+ L(next);
+ }
+ };
+
+
+ for (int j = 0; j < tile_size; j++) {
+ Label next, unaligned;
+ mov(oreg_temp, oreg_ydim);
+ add(oreg_temp, j);
+ cmp(oreg_temp, outh);
+ jge(next, T_NEAR);
+
+ mov(oreg_out_j, oreg_dst);
+ imul(oreg_temp, oreg_temp, outw * simd_w * typesize);
+ add(oreg_out_j, oreg_temp);
+
+ test(oreg_dst, 63);
+ jnz(unaligned, T_NEAR);
+
+ i_loop(j, true);
+ jmp(next, T_NEAR);
+
+ L(unaligned);
+ i_loop(j, false);
+
+ L(next);
+ }
+ };
+
+ auto trans_O_4x4_3x3 = [=]() {
+ auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){
+ vmulps(dst, v1, u1);
+ vfmadd231ps(dst, v2, u2);
+ };
+ mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
+ mov(oreg_T, ptr[param1 + GET_OFF(T)]);
+ mov(oreg_O, ptr[param1 + GET_OFF(M)]);
+
+ for (int i = 0; i < alpha; i++) {
+ for (int j = 0; j < alpha; j++) {
+ vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w
+ + i * simd_w) * typesize]);
+ }
+
+ vaddps(zmm_t(0), zmm_O(1), zmm_O(2));
+ vaddps(zmm_t(1), zmm_O(3), zmm_O(4));
+ vsubps(zmm_t(2), zmm_O(1), zmm_O(2));
+ vsubps(zmm_t(3), zmm_O(3), zmm_O(4));
+
+ vaddps(zmm_T(0), zmm_t(0), zmm_t(1));
+ vaddps(zmm_T(0), zmm_T(0), zmm_O(0));
+ fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
+ fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
+ fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
+ vaddps(zmm_T(3), zmm_T(3), zmm_O(5));
+
+ for (int j = 0; j < tile_size; j++) {
+ vmovups(ptr[oreg_T + (j * alpha * simd_w
+ + i * simd_w) * typesize], zmm_T(j));
+ }
+ }
+ for (int j = 0; j < tile_size; j++) {
+ for (int i = 0; i < alpha; i++) {
+ vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w
+ + i * simd_w) * typesize]);
+ }
+ vaddps(zmm_t(0), zmm_T(1), zmm_T(2));
+ vaddps(zmm_t(1), zmm_T(3), zmm_T(4));
+ vsubps(zmm_t(2), zmm_T(1), zmm_T(2));
+ vsubps(zmm_t(3), zmm_T(3), zmm_T(4));
+
+ vaddps(zmm_O(0), zmm_t(0), zmm_t(1));
+ vaddps(zmm_O(0), zmm_O(0), zmm_T(0));
+ fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
+ fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
+ fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
+ vaddps(zmm_O(3), zmm_O(3), zmm_T(5));
+
+ for (int i = 0; i < tile_size; i++) {
+ vmovups(ptr[oreg_O + (j * tile_size * simd_w
+ + i * simd_w) * typesize], zmm_O(i));
+ }
+ }
+ };
+
+ auto inner_loops = [=]() {
+ init_G();
+ load_src();
+ trans_O_4x4_3x3();
+ store_dst();
+ };
+
+ preamble();
+ inner_loops();
+ postamble();
+}
+
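+/* Input transform for a 4x4 Winograd tile: gathers the 6x6 (alpha x alpha)
+ * input patch, zeroing out-of-bounds rows/columns with opmasks, applies the
+ * two-pass transform and scatters the result into the layout expected by the
+ * GEMM kernel (using streaming stores when the buffer exceeds the LLC). */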
+void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
+ ::input_transform_data_ker_generate()
+{
+ bool is_fwd = one_of(jcp.prop_kind,
+ mkldnn_forward_training, mkldnn_forward_inference);
+ int inpw = is_fwd ? jcp.iw : jcp.ow;
+ int inph = is_fwd ? jcp.ih : jcp.oh;
+ int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
+ int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
+ int wp_max = inpw + l_pad;
+ int hp_max = inph + t_pad;
+ bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
+ int G_size = 9;
+
+ auto zmm_zero = Xbyak::Zmm(0);
+ auto zmm_temp = Xbyak::Zmm(31);
+ auto zmm_G = [=](int i) {
+ return Xbyak::Zmm(1 + i);
+ };
+ auto zmm_I = [=](int i) {
+ return Xbyak::Zmm(1 + G_size + i);
+ };
+ auto zmm_T = [=](int i) {
+ return Xbyak::Zmm(1 + G_size + alpha + i);
+ };
+ auto zmm_t = [=](int i) {
+ return Xbyak::Zmm(1 + G_size + 2 * alpha + i);
+ };
+
+ auto init_G = [=]() {
+ mov(ireg_temp, ptr[param1 + GET_OFF(G)]);
+ for (int i = 0; i < G_size; i++) {
+ vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]);
+ }
+ };
+
+ auto load_src = [=]() {
+ mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp
+ mov(ireg_I, ptr[param1 + GET_OFF(M)]);
+
+ xor_(ireg_zero, ireg_zero);
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+
+ mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]);
+ shl(ireg_ydim, 2); // tj * tile_size (==4)
+ mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]);
+ shl(ireg_xdim, 2); // ti * tile_size (==4)
+
+ for (int j = 0; j < alpha; j++) {
+ mov(ireg_temp, ireg_ydim);
+ add(ireg_temp, j);
+
+ mov(ireg_mask_j, 0xffff);
+ cmp(ireg_temp, t_pad);
+ cmovl(ireg_mask_j, ireg_zero);
+ cmp(ireg_temp, hp_max);
+ cmovge(ireg_mask_j, ireg_zero);
+
+ sub(ireg_temp, t_pad);
+ imul(ireg_temp, ireg_temp, inpw * simd_w * typesize);
+ mov(ireg_inp_j, ireg_src);
+ add(ireg_inp_j, ireg_temp);
+
+ for (int i = 0; i < alpha; i++) {
+
+ mov(ireg_temp, ireg_xdim);
+ add(ireg_temp, i);
+
+ mov(ireg_mask, 0xffff);
+ cmp(ireg_temp, l_pad);
+ cmovl(ireg_mask, ireg_zero);
+ cmp(ireg_temp, wp_max);
+ cmovge(ireg_mask, ireg_zero);
+ and_(ireg_mask, ireg_mask_j);
+
+ sub(ireg_temp, l_pad);
+ shl(ireg_temp, 4 + 2);
+
+ vpxord(zmm_temp, zmm_temp, zmm_temp);
+ Opmask kmask = Opmask(7);
+ kmovw(kmask, ireg_mask_32);
+ vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]);
+ vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w)
+ * typesize], zmm_temp);
+ }
+ }
+ };
+
+ auto store_Iw = [=]() {
+
+ mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
+ mov(ireg_output, ptr[param1 + GET_OFF(dst)]);
+
+ bool streamout = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
+ > 2 * LLC_data_size;
+
+ if (not_tiled) {
+ mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
+ imul(ireg_tile_block, ireg_tile_block,
+ alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block
+ * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
+ * typesize);
+ }
+
+ mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
+ imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur,
+ jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block
+ * jcp.dimK_reg_block * typesize);
+
+ mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
+ imul(ireg_tile_block_ur, ireg_tile_block_ur,
+ jcp.dimK_reg_block * typesize);
+
+ add(ireg_output, ireg_nb_tile_block_ur);
+ add(ireg_output, ireg_tile_block_ur);
+ if (not_tiled)
+ add(ireg_output, ireg_tile_block);
+
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ vmovups(zmm_temp, ptr[ireg_Iw + (j * alpha * simd_w
+ + i * simd_w) * typesize]);
+
+ int j_base_offset =
+ j * alpha * jcp.dimN_block * jcp.dimK_nb_block
+ * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
+ * typesize;
+ int i_base_offset =
+ i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
+ * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
+
+ if (not_tiled && streamout)
+ vmovntps(ptr[ireg_output + j_base_offset + i_base_offset],
+ zmm_temp);
+ else
+ vmovups(ptr[ireg_output + j_base_offset + i_base_offset],
+ zmm_temp);
+ }
+ }
+ };
+
+ auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
+ vmulps(zmm_temp, a, b);
+ vaddps(dst, zmm_temp, c);
+ };
+
+ auto trans_I_4x4_3x3 = [=]() {
+ mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
+ mov(ireg_T, ptr[param1 + GET_OFF(T)]);
+ mov(ireg_I, ptr[param1 + GET_OFF(M)]);
+
+ mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch
+ for (int i = 0; i < alpha; i++) {
+ for (int idx = 0; idx < alpha; idx++) {
+ vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w
+ + i * simd_w) * typesize]);
+ int j_base_offset =
+ i * alpha * jcp.dimN_block * jcp.dimK_nb_block
+ * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
+ * typesize;
+ int idx_base_offset =
+ idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
+ * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
+ prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]);
+ }
+
+ fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
+ fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
+ fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
+ fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
+ fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
+ fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
+
+ fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
+ fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
+ fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
+ fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
+ fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
+ fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
+
+ for (int idx = 0; idx < alpha; idx++) {
+ vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w)
+ * typesize], zmm_T(idx));
+ }
+ }
+ for (int i = 0; i < alpha; i++) {
+ for (int idx = 0; idx < alpha; idx++) {
+ vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx
+ * simd_w) * typesize]);
+ }
+
+ fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
+ fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
+ fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
+ fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
+ fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
+ fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
+
+ fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
+ fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
+ fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
+ fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
+ fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
+ fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
+
+ for (int idx = 0; idx < alpha; idx++) {
+ vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w)
+ * typesize], zmm_I(idx));
+ }
+ }
+ };
+
+ auto inner_loops = [=]() {
+ init_G();
+ load_src();
+ trans_I_4x4_3x3();
+ store_Iw();
+ };
+
+ preamble();
+ inner_loops();
+ postamble();
+}
+
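+/* Checks shared by the fwd and bwd-data kernels: requires avx512_core, a
+ * single group, a 3x3 kernel, unit stride, no dilation, channel counts that
+ * are multiples of simd_w and nChw16c data layouts. */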
+status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common(
+ jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d)
+{
+ if (!mayiuse(avx512_core)) {
+ return status::unimplemented;
+ }
+
+ jcp.nthr = mkldnn_get_max_threads();
+
+ jcp.ver = ver_avx512_core;
+ jcp.prop_kind = cd.prop_kind;
+
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.ih = src_d.dims()[2];
+ jcp.iw = src_d.dims()[3];
+ jcp.oh = dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[3];
+ jcp.kh = weights_d.dims()[with_groups + 2];
+ jcp.kw = weights_d.dims()[with_groups + 3];
+ jcp.t_pad = cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][1];
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+ jcp.dilate_h = cd.dilates[0];
+ jcp.dilate_w = cd.dilates[1];
+ jcp.r_pad = nstl::max(
+ 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
+ jcp.b_pad = nstl::max(
+ 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
+ jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
+ jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
+ jcp.ohp = jcp.oh;
+ jcp.owp = jcp.ow;
+
+ bool ok_to_pad_channels = jcp.ngroups == 1;
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+ }
+
+ // Checking conditions not supported by these kernels
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
+ if (jcp.ngroups != 1)
+ return status::unimplemented;
+ if ((jcp.kh != 3) || (jcp.kw != 3))
+ return status::unimplemented;
+ if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
+ return status::unimplemented;
+ if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
+ return status::unimplemented;
+ if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
+ return status::unimplemented;
+
+ format_tag_t dat_tag = nChw16c;
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+
+ if (jcp.src_tag != dat_tag) return status::unimplemented;
+ if (jcp.dst_tag != dat_tag) return status::unimplemented;
+
+ if (!one_of(weights_d.format_kind(), format_kind::any, format_kind::wino)) {
+ format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
+ jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+ if (jcp.wei_tag != wei_tag)
+ return status::unimplemented;
+ }
+
+ bool layout_consistency = true
+ && jcp.ic <= src_d.padded_dims()[1]
+ && jcp.oc <= dst_d.padded_dims()[1]
+ && (one_of(weights_d.format_kind(),
+ format_kind::any, format_kind::wino)
+ || (jcp.ic <= weights_d.padded_dims()[with_groups + 1]
+ && jcp.oc <= weights_d.padded_dims()[with_groups + 0]));
+ if (!layout_consistency)
+ return status::unimplemented;
+
+ return status::success;
+}
+
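+/* Chooses the register-level blocking along dimM and dimN so that the
+ * accumulators plus broadcast registers fit into the 32 zmm registers. */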
+void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) {
+
+ /* ----------- dimM reg block ---------------------*/
+ auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp,
+ int dimM_reg_block, int current_best) {
+ int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4;
+ return (dimM_reg_block >= 1)
+ && (dimM_reg_block <= max_dimM_reg_block )
+ && (dimM_reg_block > current_best);
+ };
+ jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp,
+ jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block);
+
+ /* ----------- dimN reg block ---------------------*/
+
+ auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
+ int dimN_reg_block, int current_best) {
+ return jcp.kernel_kind == embd_bcast
+ ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best
+ : dimN_reg_block >= 1
+ && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block)
+ < jcp.nb_reg
+ && dimN_reg_block > current_best;
+ };
+ jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp,
+ jcp.dimN, 1, test_cond_dimN_reg_block);
+}
+
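+/* W_SGD schedule: blocks dimN so that each thread works on its own group of
+ * tiles, then derives L1 blocking for the GEMM; returns unimplemented when
+ * the cache-blocking heuristics cannot be satisfied. */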
+status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {
+ if (jcp.ver != ver_avx512_core)
+ return status::unimplemented;
+
+ jcp.kernel_kind = embd_bcast;
+
+ set_kernel_dims_reg_block(jcp);
+
+ /*-------------- L2 blocking for dimN block ---------*/
+
+ auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp,
+ int dimN_block, int current_best) {
+ return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0)
+ && (dimN_block > current_best)
+ && ((jcp.dimN / dimN_block / jcp.dimN_reg_block)
+ >= 1.5 * mkldnn_get_max_threads());
+ };
+
+ jcp.dimN_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
+ jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
+
+ if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2)
+ && (jcp.dimN_nb_block >= 1.5 * mkldnn_get_max_threads())) {
+
+ /* ------------------- L1 blocking for GEMM --------------*/
+ /* -------------------- Choose dimK block ----------------*/
+
+ auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp,
+ int dimK_block, int current_best) {
+ return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5)
+ && (dimK_block > current_best);
+ };
+
+ jcp.dimK_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block);
+
+ if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) {
+ jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
+
+ /* -------------- Choose dimM block -------------------*/
+ auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp,
+ int dimM_block, int current_best) {
+ return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block,
+ 0.2, 0.5) && (dimM_block > current_best);
+ };
+
+ jcp.dimM_block = get_divisor_satisfying_cond(jcp,
+ jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1,
+ test_cond_dimM_block);
+ jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
+ / jcp.dimM_simd_block;
+
+ jcp.sched_policy = WSCHED_DATA_W_SGD;
+ return status::success;
+ }
+
+ }
+ return status::unimplemented;
+}
+
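+/* Cache blocking for the W_S_G_D schedule: dimK and dimM blocks are chosen
+ * against the L1/L2 conditions (check_cond1 / check_cond1_bis), the dimN
+ * block against check_cond2. */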
+void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) {
+
+ set_kernel_dims_reg_block(jcp);
+
+ //********************* Choosing dimK_block **********************//
+ auto test_cond1_dimK_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
+ return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
+ 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f)
+ && (dimK_block > current_best);
+ };
+
+ auto test_cond1_bis_dimK_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
+ return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
+ jcp.dimK_reg_block, 1, jcp.dimM_reg_block,
+ jcp.dimM_simd_block, .9f)
+ && (dimK_block > current_best);
+ };
+
+ jcp.dimK_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
+ // If we are not able to use streams, we fall back to condition [1]
+ if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
+ jcp.dimK_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
+ jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
+
+ //********************* Choosing dimM_block **********************//
+ auto test_cond1_dimM_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
+ return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
+ jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
+ jcp.dimM_simd_block, .5f)
+ && (dimM_block > current_best);
+ };
+
+ auto test_cond1_bis_dimM_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
+ return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
+ jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
+ jcp.dimM_simd_block, .3f)
+ && (dimM_block > current_best);
+ };
+
+ if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
+ jcp.dimM_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
+ test_cond1_dimM_block);
+ else
+ jcp.dimM_block = get_divisor_satisfying_cond(jcp,
+ jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
+ test_cond1_bis_dimM_block);
+ jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block
+ * jcp.dimM_reg_block);
+
+ //******************* Choosing dimN_block *******************//
+ auto test_cond2_dimN_block = [](
+ jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
+ return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
+ jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
+ jcp.dimM_reg_block, jcp.dimM_simd_block, .9f)
+ && (dimN_block > current_best);
+ };
+
+ jcp.dimN_block = get_divisor_satisfying_cond(
+ jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
+ jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
+}
+
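+/* Default W_S_G_D schedule: prefers the explicit-broadcast GEMM kernel and
+ * falls back to the embedded-broadcast variant when check_kernel_cond
+ * rejects the resulting blocking. */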
+status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) {
+
+ jcp.kernel_kind = expl_bcast;
+ set_kernel_blocking_DATA_W_S_G_D(jcp);
+ if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block,
+ jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK,
+ .1f, .35f))) {
+ jcp.kernel_kind = embd_bcast;
+ set_kernel_blocking_DATA_W_S_G_D(jcp);
+ }
+ jcp.sched_policy = WSCHED_DATA_W_S_G_D;
+ return status::success;
+}
+
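+/* Maps the convolution onto GEMM dimensions dimM x dimN x dimK with a fixed
+ * simd blocking of 16, then picks a schedule: W_SGD if feasible, otherwise
+ * the W_S_G_D fallback. */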
+status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_kernel(
+ jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
+{
+ jcp.nb_reg = 32;
+ jcp.dimN = dimN;
+ jcp.dimK = dimK;
+ jcp.dimM = dimM;
+ jcp.sched_policy = WSCHED_INVALID;
+
+ jcp.dimK_reg_block = 16;
+ jcp.dimM_simd_block = 16;
+
+ if (jcp.kernel_kind == embd_bcast) {
+ jcp.dimM_reg_block = 1;
+ }
+
+ if (set_wsched_DATA_W_SGD_avx512_core(jcp) != status::success)
+ set_wsched_DATA_W_S_G_D_avx512_core(jcp);
+
+ assert(jcp.sched_policy != WSCHED_INVALID);
+ return status::success;
+}
+
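+/* Post-op chains accepted by the forward kernel: none, relu, sum, sum+relu,
+ * relu+sum, or relu+sum+relu. */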
+bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok(
+ jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+ switch (p.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_relu(0) || is_sum(0); // relu or sum
+ case 2: return (is_sum(0) && is_relu(1))
+ || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
+ case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
+ default: return false;
+ }
+
+ return false;
+}
+
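+/* Forward configuration: computes tile counts, records bias/eltwise/sum
+ * post-ops, maps the dim* blocking onto ic/oc/tile fields and, for inference,
+ * requests a wino-blocked weights format so weights can be pre-transformed. */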
+status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf(
+ jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_t &src_md, memory_desc_t &weights_md,
+ const memory_desc_t &dst_md, const primitive_attr_t &attr) {
+
+ status_t st = init_conf_common(jcp, cd, src_md, weights_md, dst_md);
+
+ if (st != status::success)
+ return st;
+
+ // Winograd specific initialization
+ jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
+ jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
+ jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
+
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise)
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+
+ jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
+ jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;
+
+ status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
+
+ jcp.ic_simd_block = jcp.dimK_reg_block;
+ jcp.ic_block = jcp.dimK_block;
+ jcp.nb_ic = jcp.dimK_nb_block;
+ jcp.oc_simd_block = jcp.dimM_simd_block;
+ jcp.oc_block = jcp.dimM_block;
+ jcp.oc_reg_block = jcp.dimM_reg_block;
+ jcp.ic_reg_block = 1;
+ jcp.nb_oc = jcp.dimM_nb_block;
+ jcp.tile_block_ur = jcp.dimN_reg_block;
+ jcp.nb_tile_block_ur = jcp.dimN_block;
+ jcp.tile_block = jcp.dimN_nb_block;
+
+ /* re-create weights primitive descriptor
+ and set weights wino_blocking */
+ if (cd.prop_kind == mkldnn_forward_inference) {
+ memory_desc_t expect_wei_md = weights_md;
+
+ expect_wei_md.format_kind = format_kind::wino;
+ expect_wei_md.data_type = data_type::f32;
+ mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
+ wd.wino_format = mkldnn_wino_wei_OBaaIBOIio;
+ wd.r = 3;
+ wd.alpha = 6;
+
+ wd.ic = jcp.ic;
+ wd.oc = jcp.oc;
+ wd.ic_block = jcp.dimK_reg_block;
+ wd.oc_block = jcp.dimM_simd_block;
+ wd.ic2_block = jcp.dimK_block;
+ wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block;
+ size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc;
+ wd.size = max_size;
+ wd.adj_scale = 1.f;
+
+ if (weights_md.format_kind == format_kind::any)
+ weights_md = expect_wei_md;
+ if (weights_md != expect_wei_md)
+ return status::unimplemented;
+ }
+
+ return res;
+}
+
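+/* Backward-data configuration: mirrors the forward path with the roles of
+ * ic and oc swapped (dimM = ic, dimK = oc) and tiles counted over diff_src. */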
+status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf(
+ jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d)
+{
+ status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
+
+ if (st != status::success)
+ return st;
+
+ jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
+ jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
+ jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
+
+ status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
+
+ jcp.oc_simd_block = jcp.dimK_reg_block;
+ jcp.oc_block = jcp.dimK_block;
+ jcp.nb_oc = jcp.dimK_nb_block;
+ jcp.ic_simd_block = jcp.dimM_simd_block;
+ jcp.ic_block = jcp.dimM_block;
+ jcp.ic_reg_block = jcp.dimM_reg_block;
+ jcp.oc_reg_block = 1;
+ jcp.nb_ic = jcp.dimM_nb_block;
+ jcp.tile_block_ur = jcp.dimN_reg_block;
+ jcp.nb_tile_block_ur = jcp.dimN_block;
+ jcp.tile_block = jcp.dimN_nb_block;
+
+ return res;
+}
+
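+/* Backward-weights src transform: converts src tiles into the GEMM input
+ * buffer; one loop variant serves the SDGtWo schedule and one the default
+ * schedule, differing in how tile blocks are walked. */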
+void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
+src_transform_generate() {
+ constexpr int G_size = 9;
+ const size_t ifwp = jcp.iw + jcp.l_pad;
+ const size_t ifhp = jcp.ih + jcp.t_pad;
+
+ auto zmm_G = [=](int i) {
+ return Xbyak::Zmm(i);
+ };
+ auto zmm_I = [=](int i) {
+ return Xbyak::Zmm(G_size + i);
+ };
+ auto zmm_T = [=](int i) {
+ return Xbyak::Zmm(G_size + alpha + i);
+ };
+ auto zmm_t = [=](int i) {
+ return Xbyak::Zmm(G_size + 2 * alpha + i);
+ };
+
+ auto init_G = [=]() {
+ mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
+ for (int i = 0; i < G_size; i++) {
+ vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
+ }
+ };
+
+ auto load_src = [=]() {
+ mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
+ xor_(reg_zero, reg_zero);
+
+ mov(reg_ydim, reg_tj);
+ shl(reg_ydim, 2); //tj * tile_size(=4)
+
+ for (int j = 0; j < alpha; j++) {
+ /* check if tile index is within physical spatial boundaries*/
+ mov(reg_maskj, 0xffff);
+ cmp(reg_ydim, jcp.t_pad);
+ cmovl(reg_maskj, reg_zero);
+ cmp(reg_ydim, ifhp);
+ cmovge(reg_maskj, reg_zero);
+
+ /*address offset for tile in src*/
+ mov(reg_src_offset, reg_ydim);
+ sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad
+ imul(reg_src_offset, reg_src_offset, jcp.iw);
+
+ mov(reg_xdim, reg_ti);
+ shl(reg_xdim, 2); // xdim = ti * tile_size
+
+ add(reg_src_offset, reg_xdim);
+ sub(reg_src_offset, jcp.l_pad);
+ imul(reg_src_offset, reg_src_offset, simd_w * typesize);
+ for (int i = 0; i < alpha; i++) {
+ /* check if tile index is within physical spatial boundaries*/
+ mov(reg_maski, 0xffff);
+ cmp(reg_xdim, jcp.l_pad);
+ cmovl(reg_maski, reg_zero);
+ cmp(reg_xdim, ifwp);
+ cmovge(reg_maski, reg_zero);
+ and_(reg_maski, reg_maskj);
+
+ Opmask kmask_src = Xbyak::Opmask(7);
+ auto zmm_src = Xbyak::Zmm(31);
+ kmovw(kmask_src, reg_maski_32);
+ vpxord(zmm_src, zmm_src, zmm_src);
+ vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]);
+ vmovups(ptr[reg_I], zmm_src);
+
+ add(reg_xdim, 1); //xdim = ti * tile_size + i
+ add(reg_src_offset, simd_w * typesize);
+ add(reg_I, simd_w * typesize);
+ }
+ add(reg_ydim, 1);
+ }
+ };
+
+ auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) {
+ vmovups(dst, c);
+ vfmadd231ps(dst, a, b);
+ };
+
+ auto trans_I_3x3_4x4 = [=]() {
+ //Use 24 registers
+ mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
+ mov(reg_T, ptr[reg_transp + GET_OFF(T)]);
+ for (int i = 0; i < alpha; i++) {
+ for (int j = 0; j < alpha; j++) {
+ size_t I_off = (j * alpha + i) * simd_w * typesize;
+ vmovups(zmm_I(j), ptr[reg_I + I_off]);
+ }
+
+ fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
+ fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
+ fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
+ fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
+ fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
+ fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
+
+ fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
+ fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
+ fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
+ fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
+ fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
+ fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
+
+ for (int j = 0; j < alpha; j++) {
+ vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize],
+ zmm_T(j));
+ }
+
+ }
+
+ for (int j = 0; j < alpha; j++) {
+ for (int i = 0; i < alpha; i++) {
+ vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]);
+ }
+
+ fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
+ fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
+ fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
+ fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
+ fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
+ fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
+
+ fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
+ fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
+ fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
+ fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
+ fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
+ fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
+
+ for (int i = 0; i < alpha; i++) {
+ size_t dst_off = (j * alpha * jcp.ic_block
+ * jcp.nb_tile_block_ur * jcp.tile_block_ur
+ + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur)
+ * simd_w * typesize;
+ vmovups(ptr[reg_dst + dst_off], zmm_I(i));
+ }
+ }
+ };
+
+ auto compute_transform_SDGtWo = [=]() {
+ mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
+ mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
+ mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
+ mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
+ xor_(reg_tile_count, reg_tile_count);
+ Label loop_mb, loop_jtiles, loop_itiles, done;
+ L(loop_mb);
+ {
+ L(loop_jtiles);
+ {
+ L(loop_itiles);
+ {
+ load_src();
+
+ trans_I_3x3_4x4();
+
+ add(reg_tile_count, 1);
+ cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
+ jge(done);
+
+ add(reg_dst, simd_w * typesize);
+ add(reg_ti, 1);
+ cmp(reg_ti, jcp.itiles);
+ jl(loop_itiles);
+ }
+ xor_(reg_ti, reg_ti);
+ add(reg_tj, 1);
+ cmp(reg_tj, jcp.jtiles);
+ jl(loop_jtiles);
+ }
+ xor_(reg_tj, reg_tj);
+ add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize);
+ jmp(loop_mb);
+ }
+ L(done);
+ };
+
+ auto compute_transform = [=]() {
+ mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
+ xor_(reg_ti, reg_ti);
+ xor_(reg_tj, reg_tj);
+
+ mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
+ mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
+ imul(reg_temp, reg_tile_count, simd_w * typesize);
+ add(reg_dst, reg_temp);
+
+ Label loop_jtiles, loop_itiles, next_tile_block, next_tile;
+ L(loop_jtiles);
+
+ {
+ L(loop_itiles);
+ {
+ load_src();
+
+ trans_I_3x3_4x4();
+
+ add(reg_tile_count, 1);
+ cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
+ jge(next_tile_block);
+ add(reg_dst, simd_w * typesize);
+ jmp(next_tile);
+
+ L(next_tile_block);
+ sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
+ * simd_w * typesize);
+ size_t tblk_off = alpha * alpha * jcp.ic_block
+ * jcp.nb_tile_block_ur * jcp.tile_block_ur
+ * simd_w * typesize;
+ add(reg_dst, tblk_off);
+ xor_(reg_tile_count, reg_tile_count);
+
+ L(next_tile);
+ add(reg_ti, 1);
+ cmp(reg_ti, jcp.itiles);
+ jl(loop_itiles);
+ }
+ xor_(reg_ti, reg_ti);
+ add(reg_tj, 1);
+ cmp(reg_tj, jcp.jtiles);
+ jl(loop_jtiles);
+ }
+ };
+
+ preamble();
+ init_G();
+ if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
+ compute_transform_SDGtWo();
+ else
+ compute_transform();
+ postamble();
+}
+
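+/* Backward-weights diff_dst transform: converts diff_dst tiles into the
+ * second GEMM input and, when with_bias is set, accumulates the bias
+ * gradient on the fly. */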
+void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
+diff_dst_transform_generate(bool with_bias) {
+
+ constexpr int G_size = 8;
+ auto zmm_G = [](int i) {
+ // each coefficient is re-broadcast into a single scratch register (zmm31)
+ // right before it is used, so the index is intentionally ignored here
+ return Xbyak::Zmm(31);
+ };
+
+ auto zmm_src = [=](int j, int i) {
+ return Xbyak::Zmm(G_size + j * 4 + i);
+ };
+
+ auto zmm_bias = Xbyak::Zmm(31);
+
+ auto load_src = [=]() {
+ if (with_bias) vmovups(zmm_bias, ptr[reg_bias]);
+ mov(reg_ydim, reg_tj);
+ shl(reg_ydim, 2); //tj * tile_size(=4)
+ for (int j = 0; j < tile_size; j++) {
+ /* check if tile index is within physical spatial boundaries*/
+ mov(reg_maskj, 0xffff);
+ cmp(reg_ydim, jcp.oh);
+ cmovge(reg_maskj, reg_zero);
+
+ /*address offset for tile in src*/
+ mov(reg_src_offset, reg_ydim);
+ imul(reg_src_offset, reg_src_offset, jcp.ow);
+
+ mov(reg_xdim, reg_ti);
+ shl(reg_xdim, 2); // xdim = ti * tile_size
+
+ add(reg_src_offset, reg_xdim);
+ imul(reg_src_offset, reg_src_offset, simd_w * typesize);
+ for (int i = 0; i < tile_size; i++) {
+ /* check if tile index is within physical spatial boundaries*/
+ mov(reg_maski, 0xffff);
+ cmp(reg_xdim, jcp.ow);
+ cmovge(reg_maski, reg_zero);
+ and_(reg_maski, reg_maskj);
+
+ Opmask kmask_src = Xbyak::Opmask(7);
+ kmovw(kmask_src, reg_maski_32);
+ vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i));
+ vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]);
+ if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias,
+ ptr[reg_src + reg_src_offset]);
+
+ add(reg_xdim, 1); //xdim = ti * tile_size + i
+ add(reg_src_offset, simd_w * typesize);
+ }
+ add(reg_ydim, 1);
+ }
+ if (with_bias) vmovups(ptr[reg_bias], zmm_bias);
+ };
+
+ auto zmm_t = [=](int i) {
+ return Xbyak::Zmm(G_size + 16 + i);
+ };
+
+ auto zmm_T = [=](int j, int i) {
+ return Xbyak::Zmm(j * 4 + i);
+ };
+
+ auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) {
+ if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
+ vmovups(ptr[reg_dst + dst_off], a);
+ else
+ vmovntps(ptr[reg_dst + dst_off], a);
+ };
+
+ auto trans_W_3x3_4x4 = [=]() {
+ mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
+ for (int i = 0; i < tile_size; i++) {
+ vbroadcastss(zmm_G(0), ptr[reg_G]);
+ vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0));
+
+ vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
+ vmovups(zmm_t(1), zmm_t(0));
+ vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1));
+
+ vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
+ vmovups(zmm_t(2), zmm_t(0));
+ vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2));
+
+ vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
+ vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3));
+
+ vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
+ vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4));
+
+ vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
+ vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5));
+
+ vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
+ vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6));
+
+ vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
+ vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7));
+ vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3));
+ vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3));
+ vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4));
+ vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4));
+ vmovups(zmm_T(5, i), zmm_src(3, i));
+ }
+
+ for (int j = 0; j < alpha; j++) {
+ vbroadcastss(zmm_G(0), ptr[reg_G]);
+ vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0));
+
+ vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
+ vmovups(zmm_t(1), zmm_t(0));
+ vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1));
+
+ vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
+ vmovups(zmm_t(2), zmm_t(0));
+ vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2));
+
+ vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
+ vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3));
+
+ vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
+ vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4));
+
+ vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
+ vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5));
+
+ vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
+ vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6));
+
+ vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
+ vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7));
+ vsubps(zmm_t(5), zmm_t(1), zmm_t(3));
+ vaddps(zmm_t(1), zmm_t(1), zmm_t(3));
+ vaddps(zmm_t(6), zmm_t(2), zmm_t(4));
+ vsubps(zmm_t(2), zmm_t(2), zmm_t(4));
+ vmovups(zmm_t(3), zmm_T(j, 3));
+
+ int alpha_offset = (jcp.oc / jcp.nb_oc)
+ * (jcp.ntiles / jcp.tile_block) * typesize;
+ int dst_off = j * alpha * alpha_offset;
+ movps(reg_dst, dst_off, zmm_t(0));
+ dst_off += alpha_offset;
+ movps(reg_dst, dst_off, zmm_t(5));
+ dst_off += alpha_offset;
+ movps(reg_dst, dst_off, zmm_t(1));
+ dst_off += alpha_offset;
+ movps(reg_dst, dst_off, zmm_t(6));
+ dst_off += alpha_offset;
+ movps(reg_dst, dst_off, zmm_t(2));
+ dst_off += alpha_offset;
+ movps(reg_dst, dst_off, zmm_t(3));
+ }
+
+ };
+ auto compute_transform_SDGtWo = [=]() {
+ mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
+ mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
+ if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
+
+ xor_(reg_zero, reg_zero);
+ xor_(reg_oc_ur, reg_oc_ur);
+ Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done;
+
+ L(loop_oc_ur);
+ {
+ mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
+ mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
+ xor_(reg_tile_count, reg_tile_count);
+ L(loop_mb);
+ {
+ L(loop_jtiles);
+ {
+ L(loop_itiles);
+ {
+ load_src();
+
+ trans_W_3x3_4x4();
+
+ add(reg_tile_count, 1);
+ cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
+ jge(tiles_done);
+
+ add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
+ add(reg_ti, 1);
+ cmp(reg_ti, jcp.itiles);
+ jl(loop_itiles);
+ }
+ xor_(reg_ti, reg_ti);
+ add(reg_tj, 1);
+ cmp(reg_tj, jcp.jtiles);
+ jl(loop_jtiles);
+ }
+ xor_(reg_tj, reg_tj);
+ add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize);
+ jmp(loop_mb);
+ }
+
+ L(tiles_done);
+ mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
+ add(reg_dst, simd_w * typesize);
+ mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
+ add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
+
+ if (with_bias) add(reg_bias, simd_w * typesize);
+ add(reg_oc_ur, 1);
+ cmp(reg_oc_ur, jcp.oc_reg_block);
+ jl(loop_oc_ur);
+ }
+ };
+
+ auto compute_transform = [=]() {
+ mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
+ mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
+ if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
+
+ mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
+ mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
+ imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
+ add(reg_dst, reg_temp);
+
+ xor_(reg_zero, reg_zero);
+ xor_(reg_oc_ur, reg_oc_ur);
+ Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile;
+
+ L(loop_oc_ur);
+ {
+ xor_(reg_ti, reg_ti);
+ xor_(reg_tj, reg_tj);
+
+ L(loop_jtiles);
+ {
+ L(loop_itiles);
+ {
+ load_src();
+
+ trans_W_3x3_4x4();
+
+ add(reg_tile_count, 1);
+ cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
+ jge(next_tile_block);
+ add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
+ jmp(next_tile);
+
+ L(next_tile_block);
+ sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
+ * jcp.oc_reg_block * simd_w * typesize);
+ int tblk_off = alpha * alpha * (jcp.oc/jcp.nb_oc)
+ * (jcp.ntiles/jcp.tile_block) * typesize;
+ add(reg_dst, tblk_off);
+ xor_(reg_tile_count, reg_tile_count);
+
+ L(next_tile);
+ add(reg_ti, 1);
+ cmp(reg_ti, jcp.itiles);
+ jl(loop_itiles);
+ }
+ xor_(reg_ti, reg_ti);
+ add(reg_tj, 1);
+ cmp(reg_tj, jcp.jtiles);
+ jl(loop_jtiles);
+ }
+
+ mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
+ mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
+ imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
+ add(reg_dst, reg_temp);
+ add(reg_dst, simd_w * typesize);
+ mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
+ add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
+
+ if (with_bias) add(reg_bias, simd_w * typesize);
+ add(reg_oc_ur, 1);
+ cmp(reg_oc_ur, jcp.oc_reg_block);
+ jl(loop_oc_ur);
+ }
+ };
+
+ preamble();
+ if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
+ compute_transform_SDGtWo();
+ } else {
+ compute_transform();
+ }
+ postamble();
+}
+
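+/* Backward-weights output transform: reduces the accumulated 6x6 GEMM result
+ * to the 3x3 diff_weights tile; first_tile selects overwrite vs. accumulate
+ * into the destination. */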
+void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
+diff_weights_transform_generate(bool first_tile) {
+ int G_size = 4;
+
+ auto zmm_G = [](int i) {
+ return Xbyak::Zmm(i);
+ };
+
+ auto init_G = [=]() {
+ mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
+ for (int i = 0; i < G_size; i++)
+ vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
+ };
+
+ auto zmm_src = [=](int i) {
+ return Xbyak::Zmm(G_size + i);
+ };
+
+ auto load_src = [=](int i) {
+ for (int j = 0; j < alpha; j++) {
+ size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block
+ * jcp.ic_block * simd_w * simd_w * typesize;
+ size_t src_off = (j * alpha + i) * alpha_offset;
+ vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off));
+ }
+ };
+
+ auto zmm_t = [=](int i) {
+ return Xbyak::Zmm(G_size + 6 + i);
+ };
+
+ auto zmm_T = [=](int j, int i) {
+ return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i);
+ };
+
+ auto zmm_dst = [=](int i) {
+ return Xbyak::Zmm(G_size + i);
+ };
+
+ auto zmm_temp = Xbyak::Zmm(31);
+
+ auto store_dst = [=](int j) {
+ for (int i = 0; i < jcp.kw; i++) {
+ size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize;
+
+ if (!first_tile) {
+ vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off));
+ vaddps(zmm_dst(i), zmm_dst(i), zmm_temp);
+ }
+ vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i));
+ }
+ };
+
+ auto compute_transform = [=] () {
+ mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
+ mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
+
+ xor_(reg_ic_simd, reg_ic_simd);
+ Label loop_ic_simd;
+ L(loop_ic_simd);
+ {
+ for (int i = 0; i < alpha; i++) {
+ load_src(i);
+
+ vaddps(zmm_t(0), zmm_src(1), zmm_src(2));
+ vaddps(zmm_t(1), zmm_src(3), zmm_src(4));
+ vmovups(zmm_t(2), zmm_src(5));
+ vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));
+
+ vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0));
+ vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1));
+ vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2));
+ vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1));
+ vsubps(zmm_temp, zmm_src(3), zmm_src(4));
+ vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2));
+ vmovups(zmm_T(2, i), zmm_t(2));
+ vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3));
+ }
+
+ for (int j = 0; j < jcp.kh; j++) {
+ vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2));
+ vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4));
+ vmovups(zmm_t(2), zmm_T(j, 5));
+ vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));
+
+ vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0));
+ vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1));
+ vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2));
+ vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1));
+ vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4));
+ vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2));
+ vmovups(zmm_dst(2), zmm_t(2));
+ vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3));
+
+ store_dst(j);
+ }
+
+ add(reg_src, jcp.oc_reg_block * simd_w * typesize);
+ add(reg_dst, simd_w * typesize);
+ add(reg_ic_simd, 1);
+ cmp(reg_ic_simd, simd_w);
+ jl(loop_ic_simd);
+ }
+ };
+ preamble();
+ push(reg_EVEX_max_8b_offt);
+ mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
+ init_G();
+ compute_transform();
+ pop(reg_EVEX_max_8b_offt);
+ postamble();
+}
+
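+/* Weight-gradient GEMM: broadcasts dimN_bcast_ur elements of srcB, streams
+ * srcA vectors and keeps dimM_reg_block x dimN_bcast_ur zmm accumulators;
+ * is_first_tile decides whether dstC is overwritten or accumulated. */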
+void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate(
+ bool is_first_tile)
+{
+ auto zmm_srcA = [=]() {
+ return Xbyak::Zmm(0);
+ };
+
+ auto zmm_srcB = [=] (size_t N_ur){
+ return Xbyak::Zmm(N_ur + 1);
+ };
+
+ auto broadcastB = [=](size_t K_ur) {
+ for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
+ size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast)
+ * sizeof(float);
+ vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off));
+ }
+ };
+
+ auto load_srcA = [=] (size_t K_ur, int M_ur) {
+ size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
+ + M_ur * jcp.dimM_simd_block) * sizeof(float);
+ vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off));
+ };
+
+ auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){
+ size_t idx = 1 // zmm_srcA
+ + jcp.dimN_bcast_ur // zmm_srcB
+ + M_reg_ur * jcp.dimN_bcast_ur + N_bcast;
+ assert(idx < 32);
+ return Xbyak::Zmm(idx);
+ };
+ auto prepare_accumm = [=](){
+ for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
+ for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
+ Zmm zmm = zmm_dstC(M_reg_ur, N_bcast);
+ vpxord(zmm, zmm, zmm);
+ }
+ }
+ };
+
+ auto store_dstC = [=](){
+ /******** Write C back to memory *******/
+ for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) {
+ for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) {
+ Zmm zmm = zmm_dstC(M_reg, N_ur);
+ size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
+ + M_reg * jcp.dimM_simd_block) * sizeof(float);
+ if (!is_first_tile) {
+ vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off));
+ vaddps(zmm, zmm, Xbyak::Zmm(0));
+ }
+ vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm);
+ }
+ }
+ };
+
+ auto inner_loops = [=]() {
+ Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur;
+
+ mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
+ L(dimM_block_loop);
+ { /************* OC_block (M) loop ***********/
+ mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
+ L(dimN_block_loop);
+ { /*************** IC_block (N) loop *********/
+
+ mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur);
+ L(dimN_bcast_ur);
+ {
+ prepare_accumm();
+
+ mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
+ L(dimK_block_loop);
+ {
+ /************* nb_tile_ur(K) loop ********/
+ for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) {
+
+ broadcastB(K_ur);
+
+ for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
+ load_srcA(K_ur, M_reg_ur);
+ for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) {
+ vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(),
+ zmm_srcB(N_bcast));
+ }
+ }
+ }
+ add(reg_srcA, jcp.dimK_reg_block
+ * jcp.dimM_reg_block * jcp.dimM_simd_block
+ * sizeof(float));
+ add(reg_srcB, jcp.dimK_reg_block
+ * jcp.dimN_reg_block
+ * sizeof(float));
+ sub(reg_dimK_block_loop_cnt, 1);
+ jnz(dimK_block_loop);
+ }
+
+ store_dstC();
+
+ sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
+ * jcp.dimM_reg_block * jcp.dimM_simd_block
+ * sizeof(float));
+ sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block
+ * jcp.dimN_reg_block
+ * sizeof(float));
+ add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float));
+ add(reg_dstC, jcp.dimN_bcast_ur
+ * jcp.dimM_reg_block * jcp.dimM_simd_block
+ * sizeof(float));
+ sub(reg_nb_dimN_bcast_ur, 1);
+ jnz(dimN_bcast_ur);
+ }
+
+ sub(reg_srcB, jcp.dimN_reg_block * sizeof(float));
+ add(reg_srcB, jcp.dimK_block
+ * jcp.dimK_reg_block
+ * jcp.dimN_reg_block * sizeof(float));
+ sub(reg_dimN_block_loop_cnt, 1);
+ jnz(dimN_block_loop);
+ }
+
+ sub(reg_srcB, jcp.dimN_block
+ * jcp.dimK_block * jcp.dimK_reg_block
+ * jcp.dimN_reg_block
+ * sizeof(float));
+ add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
+ * jcp.dimM_reg_block * jcp.dimM_simd_block
+ * sizeof(float));
+ sub(reg_dimM_block_loop_cnt, 1);
+ jnz(dimM_block_loop);
+ }
+ };
+
+ /* Preamble */
+ preamble();
+
+ inner_loops();
+
+ /* Postamble */
+ postamble();
+ ret();
+}
+
+namespace {
+
+void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) {
+/*M params*/
+ jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
+ / jcp.dimM_simd_block;
+ jcp.oc_reg_block = jcp.dimM_reg_block;
+ jcp.oc_block = jcp.dimM_block;
+ jcp.nb_oc = jcp.dimM_nb_block;
+ /*N params*/
+ jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
+ jcp.ic_block = jcp.dimN_block;
+ jcp.nb_ic = jcp.dimN_nb_block;
+
+ /*K params*/
+ jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
+ jcp.tile_block_ur = jcp.dimK_reg_block;
+ jcp.nb_tile_block_ur = jcp.dimK_block;
+ jcp.tile_block = jcp.dimK_nb_block;
+}
+
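+/* SDGtWo schedule: splits the tile (K) dimension across threads so each
+ * thread transforms, multiplies and reduces its own block; only chosen when
+ * the full transformed buffers would be too large to share (see
+ * test_MV_large_enough). */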
+status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
+
+ size_t K_blk_ur, N_blk, M_blk;
+ /* IS this strategy feasible? */
+ auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
+ size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
+ size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
+ size_t nthreads = mkldnn_get_max_threads();
+ return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size)
+ && (jcp.dimK / nthreads >= 1.0);
+ };
+
+ auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
+ int max_block=1) {
+ size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float);
+ size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
+ size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
+ size_t nthreads = mkldnn_get_max_threads();
+ bool load_balance = true;
+ if (!(jcp.dimK % nthreads)) {
+ load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0);
+ }
+ return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
+ && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
+ && load_balance
+ && (M_L2_block < L2_cache_size);
+ };
+
+ auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
+ int useless_arg=0) {
+ return (dimK_ur >= 2) && (dimK_ur <= 8);
+ };
+
+ auto blocking_ok = [&](){
+ size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
+ * K_blk_ur * sizeof(float);
+ size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
+ * K_blk_ur * sizeof(float);
+ size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
+ * N_blk * jcp.dimN_reg_block * sizeof(float);
+ size_t L2_block = M_L2_block + V_L2_block + U_L2_block;
+ /* TODO: replace the hard-coded upper factor with the actual L2+L3 cache size */
+ return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size);
+ };
+
+ if (test_MV_large_enough(jcp)) {
+ if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
+ jcp.dimM_reg_block = 2;
+ } else {
+ jcp.dimM_reg_block = 1;
+ }
+ jcp.dimM_simd_block = jcp.oc_simd_block;
+ jcp.dimN_reg_block = jcp.ic_simd_block;
+ jcp.dimN_bcast_ur = 8;
+ /*dimK_block and dimK_ur*/
+ size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1);
+
+ jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block;
+ jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block;
+ for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) {
+ if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) {
+ for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
+ if (!(jcp.dimN_block % N_blk)) {
+ for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
+ if (!(jcp.dimM_block % M_blk) && blocking_ok()) {
+ jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
+ if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented;
+ jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
+ jcp.dimN_block = N_blk;
+ jcp.dimM_block = M_blk;
+ jcp.sched_policy = WSCHED_WEI_SDGtWo;
+ set_jcp_WEI_params(jcp);
+ jcp.nthr = nstl::min(mkldnn_get_max_threads(),
+ jcp.tile_block);
+ return status::success;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ return status::unimplemented;
+}
+
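+/* Fallback S_D_Giot_W schedule: global src and diff_dst transforms followed
+ * by a blocked GEMM and the weights transform; always returns success, using
+ * trivial blocking as a last resort. */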
+status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
+ if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
+ jcp.dimM_reg_block = 2;
+ } else {
+ jcp.dimM_reg_block = 1;
+ }
+ jcp.dimN_bcast_ur = 8;
+ jcp.dimN_reg_block = jcp.ic_simd_block;
+ jcp.dimM_simd_block = jcp.oc_simd_block;
+ jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
+ jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
+ float C1 = 0.0, C2 = 0.0;
+ float C1_max = 0.5, C2_max = 1.4;
+ int N_blk, M_blk, K_blk_ur;
+
+ auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
+ int useless_arg=0) {
+ return (dimK_ur >= 2) && (dimK_ur <= 8);
+ };
+
+ auto blocking_ok = [&]() -> bool {
+ size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float);
+ size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float);
+ bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size)
+ && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size);
+
+ size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block;
+ size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block;
+ size_t nb_K_blk = jcp.dimK / K_blk_ur;
+ size_t nthreads = mkldnn_get_max_threads();
+ bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads;
+ if (!(nb_K_blk % nthreads)) {
+ load_balance = load_balance && (nb_K_blk % nthreads == 0);
+ }
+
+ size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float);
+
+ size_t L2_block = V_L2_block;
+ /* TODO: replace the hard-coded upper factor with the actual L2+L3 cache size */
+ bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size);
+ return L1_cond && load_balance && L2_cond;
+ };
+
+ for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) {
+ if (jcp.dimK % K_blk_ur == 0) {
+ for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
+ if (jcp.dimN_block % N_blk == 0) {
+ for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
+ if (jcp.dimM_block % M_blk == 0) {
+ if (blocking_ok()) {
+ jcp.dimN_block = N_blk;
+ jcp.dimM_block = M_blk;
+ jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
+ jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
+ jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
+ set_jcp_WEI_params(jcp);
+ return status::success;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ jcp.dimK_reg_block = 1;
+ jcp.dimK_block = 1;
+ jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
+ set_jcp_WEI_params(jcp);
+ return status::success;
+}
+} // namespace
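+
+/* Backward-weights configuration: dimM = oc, dimN = ic, dimK = ntiles;
+ * tries the SDGtWo schedule first and falls back to S_D_Giot_W. */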
+status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf(
+ jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
+ const memory_desc_wrapper &diff_weights_d) {
+ if (!mayiuse(avx512_core))
+ return status::unimplemented;
+ else
+ jcp.ver = ver_avx512_core;
+
+ jcp.nthr = mkldnn_get_max_threads();
+
+ jcp.prop_kind = cd.prop_kind;
+ const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
+ jcp.mb = src_d.dims()[0];
+ jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
+ jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.ih = src_d.dims()[2];
+ jcp.iw = src_d.dims()[3];
+ jcp.oh = diff_dst_d.dims()[2];
+ jcp.ow = diff_dst_d.dims()[3];
+ jcp.kh = diff_weights_d.dims()[with_groups + 2];
+ jcp.kw = diff_weights_d.dims()[with_groups + 3];
+ jcp.t_pad = cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][1];
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+ jcp.r_pad = nstl::max(
+ 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
+ jcp.b_pad = nstl::max(
+ 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
+ jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
+ jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
+ jcp.ohp = jcp.oh;
+ jcp.owp = jcp.ow;
+ jcp.with_bias = (cd.diff_bias_desc.format_kind != format_kind::undef);
+ jcp.dilate_h = cd.dilates[0];
+ jcp.dilate_w = cd.dilates[1];
+
+ bool ok_to_pad_channels = jcp.ngroups == 1;
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+ }
+
+ // Winograd specific initialization
+ jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
+ jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
+ jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
+
+ // Winograd kernel works only for 3x3 convolution with stride 1
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
+ if (jcp.ngroups != 1)
+ return status::unimplemented;
+ if ((jcp.kh != 3) || (jcp.kw != 3))
+ return status::unimplemented;
+ if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
+ return status::unimplemented;
+ if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
+ return status::unimplemented;
+ if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
+ return status::unimplemented;
+
+ format_tag_t dat_tag = nChw16c;
+ format_tag_t wei_tag = with_groups ? gOIhw16i16o : OIhw16i16o;
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
+
+ if (jcp.src_tag != dat_tag) return status::unimplemented;
+ if (jcp.wei_tag != wei_tag) return status::unimplemented;
+ if (jcp.dst_tag != dat_tag) return status::unimplemented;
+
+ bool layout_consistency = true
+ && jcp.ic <= src_d.padded_dims()[1]
+ && jcp.oc <= diff_dst_d.padded_dims()[1]
+ && jcp.ic <= diff_weights_d.padded_dims()[with_groups + 1]
+ && jcp.oc <= diff_weights_d.padded_dims()[with_groups + 0];
+ if (!layout_consistency) return status::unimplemented;
+
+ /******************Kernel blocking Parameters ***********/
+ jcp.ic_simd_block = simd_w;
+ jcp.oc_simd_block = simd_w;
+
+ jcp.dimK = jcp.ntiles;
+ jcp.dimN = jcp.ic;
+ jcp.dimM = jcp.oc;
+ jcp.dimM_simd_block = jcp.oc_simd_block;
+ jcp.dimN_reg_block = jcp.ic_simd_block;
+ jcp.sched_policy = WSCHED_INVALID;
+ status_t res = set_wsched_WEI_SDGtWo(jcp);
+ if (res == status::unimplemented) {
+ res = set_wsched_WEI_S_D_Giot_W(jcp);
+ assert(res == status::success);
+ }
+ return res;
+}
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp
new file mode 100644
index 0000000000..025a554d92
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp
@@ -0,0 +1,291 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP
+#define JIT_AVX512_CORE_FP32_WINO_CONV_4x3_KERNEL_HPP
+
+#include "c_types_map.hpp"
+
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+
+#include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
+ : public jit_generator {
+ _jit_avx512_core_fp32_wino_conv_4x3_data_kernel(
+ jit_conv_winograd_conf_t ajcp)
+ : jcp(ajcp) {
+ {
+ this->weights_transform_data_ker_generate();
+ weights_transform_data_ker
+ = (decltype(weights_transform_data_ker)) this->getCode();
+ }
+ {
+ align();
+ const Xbyak::uint8 *addr = getCurr();
+ this->input_transform_data_ker_generate();
+ input_transform_data_ker = (decltype(input_transform_data_ker))addr;
+ }
+ {
+ align();
+ const Xbyak::uint8 *addr = getCurr();
+ this->output_transform_data_ker_generate();
+ output_transform_data_ker
+ = (decltype(output_transform_data_ker))addr;
+ }
+ {
+ align();
+ const Xbyak::uint8 *addr = getCurr();
+ this->gemm_loop_generate();
+ gemm_loop_ker = (decltype(gemm_loop_ker))addr;
+ }
+ }
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_fp32_wino_conv_4x3_data_kernel)
+
+ static status_t init_conf_common(jit_conv_winograd_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d);
+
+ static status_t init_conf_kernel(
+ jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK);
+
+ jit_conv_winograd_conf_t jcp;
+ void (*gemm_loop_ker)(float *, const float *, const float *, const int);
+ void (*input_transform_data_ker)(jit_wino_transform_call_s *);
+ void (*output_transform_data_ker)(jit_wino_transform_call_s *);
+ void (*weights_transform_data_ker)(jit_wino_transform_call_s *);
+
+protected:
+ using reg64_t = const Xbyak::Reg64;
+ using reg32_t = const Xbyak::Reg32;
+ enum { typesize = sizeof(float) };
+
+ void gemm_loop_generate();
+ void input_transform_data_ker_generate();
+ void output_transform_data_ker_generate();
+ void weights_transform_data_ker_generate();
+
+ /* registers used for GEMM */
+ reg64_t reg_dstC = abi_param1;
+ reg64_t reg_srcA = abi_param2;
+ reg64_t reg_srcB = abi_param3;
+ reg64_t reg_is_beta_zero = abi_param4;
+
+ reg64_t reg_dimM_block_loop_cnt = r10;
+ reg64_t reg_dimK_block_loop_cnt = r11;
+
+    /* registers used for transforms */
+ reg64_t param = abi_param1;
+
+ /* registers used for output_transform_data_ker */
+ reg64_t oreg_temp = abi_not_param1;
+ reg64_t oreg_Ow = r9;
+ reg64_t oreg_src = r11;
+ reg64_t oreg_tile_block = r12;
+ reg64_t oreg_tile_block_ur = r13;
+ reg64_t oreg_nb_tile_block_ur = r14;
+ reg64_t oreg_O = r8;
+ reg64_t oreg_T = r10;
+ reg64_t oreg_dst = r11;
+ reg64_t oreg_ydim = r14;
+ reg64_t oreg_xdim = r15;
+ reg64_t oreg_out_j = r12;
+ reg64_t oreg_bias = rbx;
+ reg64_t imm_addr64 = rax;
+
+ /* registers used for input_transform_data_ker */
+ reg64_t ireg_temp = abi_not_param1;
+ reg64_t ireg_jtiles = rax;
+ reg64_t ireg_itiles = rbx;
+ reg64_t ireg_I = r8;
+ reg64_t ireg_src = r13;
+ reg64_t ireg_ydim = r14;
+ reg64_t ireg_xdim = r15;
+ reg64_t ireg_inp_j = r12;
+ reg64_t ireg_inp_i = rdx;
+ reg64_t ireg_mask_j = r11;
+ reg64_t ireg_mask = rsi;
+ reg32_t ireg_mask_32 = esi;
+ reg64_t ireg_zero = r9;
+ reg64_t ireg_Iw = r9;
+ reg64_t ireg_T = r10;
+ reg64_t ireg_tile_block = r12;
+ reg64_t ireg_tile_block_ur = r13;
+ reg64_t ireg_nb_tile_block_ur = r14;
+ reg64_t ireg_output = r15;
+
+ /* registers used for wei transform */
+ reg64_t wreg_temp = abi_not_param1;
+ reg64_t wreg_F = r8;
+ reg64_t wreg_src = r9;
+ reg64_t wreg_MT = r15;
+ reg64_t wreg_M = r14;
+ reg64_t wreg_dst = r10;
+ reg64_t wreg_dst_aux = r9;
+ reg64_t wreg_dst_idx = r8;
+ reg64_t wreg_Fw = r11;
+ reg64_t wreg_T = r12;
+ reg64_t wreg_cnt_j = rdx;
+ reg64_t wreg_F_aux = r14;
+ reg64_t wreg_Fw_aux = r15;
+};
+
+struct jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel
+ : _jit_avx512_core_fp32_wino_conv_4x3_data_kernel {
+ using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::
+ _jit_avx512_core_fp32_wino_conv_4x3_data_kernel;
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp, const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_conv_winograd_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_t &src_md,
+ memory_desc_t &weights_md, const memory_desc_t &dst_md,
+ const primitive_attr_t &attr);
+};
+
+struct jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel
+ : public _jit_avx512_core_fp32_wino_conv_4x3_data_kernel {
+ using _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::
+ _jit_avx512_core_fp32_wino_conv_4x3_data_kernel;
+
+ static status_t init_conf(jit_conv_winograd_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d);
+};
+
+struct jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel
+ : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(
+ _jit_avx512_core_conv_winograd_bwd_weights_kernel_f32)
+
+ jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel(
+ jit_conv_winograd_conf_t ajcp)
+ : jcp(ajcp)
+ {
+ //******************* First iter kernel ********************//
+ this->gemm_loop_generate(true);
+ gemm_loop_ker_first_iter = (decltype(gemm_loop_ker_first_iter))this->getCode();
+
+ align();
+ const Xbyak::uint8 *addr = getCurr();
+ this->src_transform_generate();
+ src_transform = (decltype(src_transform))addr;
+
+ if (jcp.with_bias) {
+ align();
+ addr = getCurr();
+ this->diff_dst_transform_generate(true);
+ diff_dst_transform_wbias = (decltype(diff_dst_transform_wbias))addr;
+ }
+
+ align();
+ addr = getCurr();
+ this->diff_dst_transform_generate(false);
+ diff_dst_transform = (decltype(diff_dst_transform))addr;
+
+ if (jcp.sched_policy != WSCHED_WEI_SDGtWo && jcp.tile_block > 1) {
+ align();
+ addr = getCurr();
+ this->gemm_loop_generate(false);
+ gemm_loop_ker = (decltype(gemm_loop_ker))addr;
+ }
+
+ align();
+ addr = getCurr();
+ this->diff_weights_transform_generate(true);
+ diff_weights_transform = (decltype(diff_weights_transform))addr;
+
+ if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
+ align();
+ addr = getCurr();
+ this->diff_weights_transform_generate(false);
+ diff_weights_transform_accum =
+ (decltype(diff_weights_transform_accum))addr;
+        }
+ }
+
+ static status_t init_conf(jit_conv_winograd_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &diff_dst_d,
+ const memory_desc_wrapper &diff_weights_d);
+
+ jit_conv_winograd_conf_t jcp;
+ void (*gemm_loop_ker)(float *, const float *, const float *);
+ void (*gemm_loop_ker_first_iter)(float *, const float *, const float *);
+ void (*src_transform)(jit_wino_transform_call_s *);
+ void (*diff_dst_transform)(jit_wino_transform_call_s *);
+ void (*diff_dst_transform_wbias)(jit_wino_transform_call_s *);
+ void (*diff_weights_transform)(jit_wino_transform_call_s *);
+ void (*diff_weights_transform_accum)(jit_wino_transform_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ using reg32_t = const Xbyak::Reg32;
+ enum { typesize = sizeof(float) };
+
+ void src_transform_generate();
+ void diff_dst_transform_generate(bool with_bias);
+ void diff_weights_transform_generate(bool first_tile);
+
+    /* registers common to transforms */
+ reg64_t reg_transp = abi_param1;
+ reg64_t reg_ti = rbx;
+ reg64_t reg_tj = abi_not_param1;
+ reg64_t reg_src = r8;
+ reg64_t reg_dst = r9;
+ reg64_t reg_G = rsi; /*TODO: check if this is ok*/
+ reg64_t reg_temp = rsi;
+
+    /* registers common to src/diff_dst transform */
+ reg64_t reg_I = r10;
+ reg64_t reg_ydim = r11;
+ reg64_t reg_xdim = r12;
+ reg64_t reg_src_offset = r13;
+ reg64_t reg_zero = r14;
+ reg64_t reg_tile_count = r15;
+ reg64_t reg_maski = rsi;
+ reg32_t reg_maski_32 = esi;
+ reg64_t reg_maskj = rdx;
+
+ reg64_t reg_T = rax;
+ reg64_t reg_oc_ur = rax;
+ reg64_t reg_ic_simd = r14;
+ reg64_t reg_bias = r10;
+
+ void gemm_loop_generate(bool is_first_tile);
+
+ reg64_t reg_dstC = abi_param1;
+ reg64_t reg_srcA = abi_param2;
+ reg64_t reg_srcB = abi_param3;
+
+ reg64_t reg_dimM_block_loop_cnt = r9;
+ reg64_t reg_dimN_block_loop_cnt = r10;
+ reg64_t reg_nb_dimN_bcast_ur = r11;
+ reg64_t reg_dimK_block_loop_cnt = r12;
+};
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp
new file mode 100644
index 0000000000..002010ffa2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.cpp
@@ -0,0 +1,1284 @@
+/*******************************************************************************
+ * Copyright 2018 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_avx512_core_u8s8s32x_wino_convolution.hpp"
+#include "jit_generator.hpp"
+
+#include <string.h>
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+using namespace Xbyak;
+
+namespace {
+    // The scales below are applied to the source and weights data
+    // because this Winograd implementation transforms the source,
+    // which may increase values by up to 4x, and transforms the
+    // weights, which may increase values by up to 9/4x.
+ const float adj_src_scale = 1.f / 4.f;
+ const float adj_wei_scale = 4.f / 9.f;
+ // Winograd transforms need ic and oc to be multiples of 16
+ const int load_block = 16;
+}
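+
+// A sketch of where the 4x and 9/4x bounds come from, assuming the standard
+// F(2x2, 3x3) Winograd transform matrices (alpha = 4):
+//     B^T = [ 1  0 -1  0 ]        G = [  1    0    0  ]
+//           [ 0  1  1  0 ]            [ 1/2  1/2  1/2 ]
+//           [ 0 -1  1  0 ]            [ 1/2 -1/2  1/2 ]
+//           [ 0  1  0 -1 ]            [  0    0    1  ]
+// The 2D input transform V = B^T * d * B can grow a value by at most
+// 2 * 2 = 4x (the maximum absolute row sum of B^T, squared), and the weight
+// transform U = G * g * G^T by at most (3/2)^2 = 9/4x, matching the
+// adjustment scales above.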
+
+/// SRC TRANSFORMS /////////////////////////////////////////////////////////////
+struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(
+ jit_avx512_core_u8s8s32x_wino_conv_src_trans_t)
+
+ jit_conv_conf_2x3_wino_t jcp;
+ const primitive_attr_t &attr_;
+
+ struct call_params_t {
+ const void *src;
+ const void *wino_src;
+ const void *v_y_masks;
+ const void *v_x_masks;
+ };
+ void (*ker_)(const call_params_t *);
+
+ jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
+ jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) {
+ generate();
+ ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
+ }
+ void generate();
+
+ int reg_inp_ind(int i) {
+ assert(i < jcp.alpha * jcp.alpha);
+ return (31 - i);
+ }
+
+ Xmm vreg_inp(int i) {
+ return Xmm(reg_inp_ind(i));
+ }
+
+ Zmm zmm_inp(int i) {
+ return Zmm(reg_inp_ind(i));
+ }
+
+ Xmm vreg_tmp(int i) {
+ assert(i < jcp.alpha * jcp.alpha);
+ return Xmm(15 - i);
+ }
+ Xmm vreg_out(int i) {
+ assert(i < jcp.alpha * jcp.alpha);
+ return Xmm(31 - i);
+ }
+
+ Opmask y_mask = Opmask(1);
+ Opmask r_mask = Opmask(2);
+ Opmask x_mask(int id) {
+ assert(id < 4);
+ return Opmask(3 + id);
+ }
+
+ Reg64 reg_ptr_src = r14;
+ Reg64 reg_ptr_dst = r13;
+
+ Reg64 reg_ptr_v_y_masks = r12;
+ Reg64 reg_ptr_v_x_masks = r11;
+
+ Reg64 reg_aux_ptr_src = r10;
+ Reg64 reg_aux_ptr_dst = r9;
+
+ Reg64 reg_ic_block = r8;
+
+ int unsign_val_in_wino_domain;
+
+ Reg64 reg_scratch_src_alpha = rdx;
+ Xmm xmm_src_alpha = Xmm(0);
+ Zmm zmm_src_alpha = Zmm(0);
+
+ Reg64 reg_shift = rax;
+ Xmm xmm_shift = Xmm(1);
+ Xmm xmm_zero = Xmm(0);
+
+ Reg64 reg_maskx = rbx;
+ Reg64 reg_masky = rsi;
+ Reg64 reg_nomask = reg_maskx;
+};
+
+void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
+ Label ic_block_label;
+ Label end_label;
+ Label mask_label;
+ Label nomask_label;
+
+ auto load_src = [=](bool mask) {
+ for (int y = 0; y < jcp.alpha; y++) {
+ if (mask)
+ kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
+ for (int x = 0; x < jcp.alpha; x++) {
+ Zmm zmm_i = zmm_inp(y * jcp.alpha + x);
+ Xmm vreg_i = vreg_inp(y * jcp.alpha + x);
+ int inp_offset = sizeof(uint8_t)
+ * ((-jcp.t_pad + y) * jcp.iw * jcp.ic
+ + (-jcp.l_pad + x) * jcp.ic);
+ if (mask) {
+ kandw(r_mask, y_mask, x_mask(x));
+ vmovdqu8(vreg_i | r_mask | T_z,
+ EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
+ } else {
+ vmovdqu8(vreg_i,
+ EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
+ }
+ vpmovzxbd(zmm_i, vreg_i); // to int32
+ vcvtdq2ps(zmm_i, zmm_i); // to fp32
+ vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
+ vcvtps2dq(zmm_i, zmm_i); // to int32
+ vpmovusdb(vreg_i, zmm_i); // to u8
+ }
+ }
+ };
+
+ preamble();
+
+# define READ_PARAM(reg, field) \
+ mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
+ READ_PARAM(reg_ptr_src, src);
+ READ_PARAM(reg_ptr_dst, wino_src);
+ READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
+ READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
+# undef READ_PARAM
+
+ mov(reg_maskx, ptr[reg_ptr_v_x_masks]);
+ mov(reg_masky, ptr[reg_ptr_v_y_masks]);
+ test(reg_maskx, reg_maskx);
+ jz(end_label, T_NEAR); // skip kernel if x mask is all 0's
+ test(reg_masky, reg_masky);
+ jz(end_label, T_NEAR); // skip kernel if y mask is all 0's
+ and_(reg_maskx, reg_masky);
+ mov(reg_nomask, reg_maskx);
+ not_(reg_nomask); // zero if x and y masks are all 1's
+
+ xor_(reg_shift, reg_shift);
+ mov(reg_shift.cvt8(), (int8_t)-128);
+
+ mov(reg_aux_ptr_src, reg_ptr_src);
+ mov(reg_aux_ptr_dst, reg_ptr_dst);
+
+ for (int i = 0; i < jcp.alpha; i++) {
+ kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
+ }
+
+ mov(reg_scratch_src_alpha, float2int(adj_src_scale));
+
+ mov(reg_ic_block, jcp.ic / load_block);
+ L(ic_block_label);
+ {
+ vmovq(xmm_src_alpha, reg_scratch_src_alpha);
+ vbroadcastss(zmm_src_alpha, xmm_src_alpha);
+
+ test(reg_nomask, reg_nomask);
+ jz(nomask_label, T_NEAR);
+ load_src(true);
+ jmp(mask_label, T_NEAR);
+ L(nomask_label);
+ load_src(false);
+ L(mask_label);
+
+ for(int y = 0; y < 4; y++) {
+ vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2));
+ vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2));
+ vpsubb(vreg_tmp(y*4+2), vreg_inp(y*4+2), vreg_inp(y*4+1));
+ vpsubb(vreg_tmp(y*4+3), vreg_inp(y*4+1), vreg_inp(y*4+3));
+ }
+ for(int x = 0;x < 4; x++) {
+ vpsubb(vreg_out(x+0*4), vreg_tmp(x+4*0), vreg_tmp(x+4*2));
+ vpaddb(vreg_out(x+1*4), vreg_tmp(x+4*1), vreg_tmp(x+4*2));
+ vpsubb(vreg_out(x+2*4), vreg_tmp(x+4*2), vreg_tmp(x+4*1));
+ vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3));
+ }
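+        // The two loops above appear to apply the F(2x2, 3x3) input transform
+        // V = B^T * d * B, first down the columns (y) and then across the
+        // rows (x): the rows of B^T are [1 0 -1 0], [0 1 1 0], [0 -1 1 0],
+        // [0 1 0 -1], which is exactly the sub/add/sub/sub pattern used here.
+        // The byte adds/subs wrap rather than saturate, which is presumably
+        // why the input is pre-scaled by adj_src_scale (1/4) in load_src.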
+
+ vmovd(xmm_shift, reg_shift.cvt32());
+ vpxor(xmm_zero, xmm_zero, xmm_zero);
+ vpshufb(xmm_shift, xmm_shift, xmm_zero);
+
+ for (int i = 0; i < 16; i++) {
+ int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
+ if (i != unsign_val_in_wino_domain)
+ vpsubb(vreg_out(i), vreg_out(i), Xmm(1));
+ vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i));
+ }
+
+ add(reg_aux_ptr_src, sizeof(uint8_t) * load_block);
+ add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block);
+ }
+ dec(reg_ic_block);
+ jnz(ic_block_label, T_NEAR);
+
+ L(end_label);
+ postamble();
+}
+
+/// DST TRANSFORMS /////////////////////////////////////////////////////////////
+struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(
+ jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t)
+
+ jit_conv_conf_2x3_wino_t jcp;
+ const primitive_attr_t &attr_;
+
+ struct call_params_t {
+ const void *wino_dst;
+ const void *dst;
+ const void *v_y_masks;
+ const void *v_x_masks;
+
+ const void *bias;
+ const void *scales;
+ };
+ void (*ker_)(const call_params_t *);
+
+ jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
+ jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr) {
+ generate();
+ ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
+ }
+
+ void generate();
+ bool maybe_relu(int position);
+
+ Zmm vreg_inp(int i) { // 16
+ assert(i < jcp.alpha * jcp.alpha);
+ return Zmm(31 - i);
+ }
+ Zmm vreg_stg(int id) { // 8
+ const int id_reg_stg = jcp.alpha * jcp.alpha + id;
+ assert(id < 8);
+ return Zmm(31 - id_reg_stg);
+ }
+ Zmm vreg_out(int id) { // 4
+ const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
+ assert(id < 4);
+ return Zmm(31 - id_reg_out);
+ }
+ Xmm xmm_out(int id) { // 4
+ const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
+ assert(id < 4);
+ return Xmm(31 - id_reg_out);
+ }
+ Zmm vreg_tmp(int id) { // 2
+ const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
+ assert(id < 2);
+ return Zmm(31 - id_reg_tmp);
+ }
+
+ Zmm vreg_zero = Zmm(0);
+ Zmm vreg_bias = Zmm(1);
+ Zmm vreg_prev_dst = Zmm(2);
+ Zmm zmm_bias_alpha = Zmm(2);
+ Xmm xmm_bias_alpha = Xmm(2);
+
+ Opmask y_mask = Opmask(1);
+ Opmask r_mask = Opmask(2);
+ Opmask x_mask(int id) {
+ assert(id < 4);
+ return Opmask(3 + id);
+ }
+
+ Reg64 reg_scratch_bias_alpha = r15;
+
+ Reg64 reg_ptr_src = r14;
+ Reg64 reg_ptr_dst = r13;
+
+ Reg64 reg_ptr_v_y_masks = r12;
+ Reg64 reg_ptr_v_x_masks = r11;
+
+ Reg64 reg_aux_ptr_src = r10;
+ Reg64 reg_aux_ptr_dst = r9;
+
+ Reg64 reg_oc_block = r8;
+
+ Reg64 reg_ptr_bias = rbx;
+ Reg64 reg_ptr_scales = abi_not_param1;
+ Reg64 reg_ptr_sum_scale = rdx;
+};
+
+bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) {
+ using namespace primitive_kind;
+ const auto &p = attr_.post_ops_;
+
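+    // (A hedged reading of the u8 special cases below: when the destination
+    // type is u8, the final down-conversion saturates negative values to zero
+    // anyway, so a relu can be folded in for free: before the sum when no sum
+    // post-op is present, and after the sum otherwise.)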
+ if (position == 0) {
+ /* relu before sum */
+ return false
+ || p.contain(eltwise, 0)
+ || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
+ } else if (position == 1) {
+ /* relu after sum */
+ const int sum_idx = p.contain(sum, 0)
+ ? 0 : (p.contain(sum, 1) ? 1 : -1);
+ if (sum_idx == -1)
+ return false;
+
+ return false
+ || p.contain(eltwise, sum_idx + 1)
+ || jcp.dst_dt == data_type::u8;
+ }
+
+ return false;
+}
+
+void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
+ Label oc_block_label;
+
+ auto loop_body = [=]() {
+ const auto &p = attr_.post_ops_;
+ const int sum_idx = p.find(primitive_kind::sum);
+ const float *p_sum_scale = (sum_idx != -1)
+ ? &p.entry_[sum_idx].sum.scale
+ : nullptr;
+ if (p_sum_scale && *p_sum_scale != 1.f)
+ mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
+
+ for(int i = 0; i < 16; i++) {
+ int internal_offset = sizeof(int32_t) * jcp.out_stride * i;
+ vmovups(vreg_inp(i),
+ EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
+ }
+ for(int y = 0; y < jcp.alpha; y++) {
+ vpaddd(vreg_tmp(0), vreg_inp(y*4 + 0), vreg_inp(y*4 + 1));
+ vpaddd(vreg_stg(y*2), vreg_tmp(0), vreg_inp(y*4 + 2));
+
+ vpsubd(vreg_tmp(1), vreg_inp(y*4 + 1), vreg_inp(y*4 + 2));
+ vpsubd(vreg_stg(y*2+1), vreg_tmp(1), vreg_inp(y*4 + 3));
+ }
+ for(int x = 0; x < jcp.m; x++) {
+ vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2*1));
+ vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x+2*2));
+
+ vpsubd(vreg_tmp(1), vreg_stg(x+2*1), vreg_stg(x+2*2));
+ vpsubd(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2*3));
+ }
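+        // The two loops above appear to implement the F(2x2, 3x3) output
+        // transform Y = A^T * M * A with
+        //     A^T = [ 1  1  1  0 ]
+        //           [ 0  1 -1 -1 ],
+        // applied first down the columns (y) and then across the rows (x),
+        // in 32-bit integer arithmetic.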
+
+
+ if (jcp.with_bias) {
+ vmovq(xmm_bias_alpha, reg_scratch_bias_alpha);
+ vbroadcastss(zmm_bias_alpha, xmm_bias_alpha);
+
+ auto bias_addr = ptr [ reg_ptr_bias ];
+ switch (jcp.bia_dt) {
+ case data_type::f32:
+ case data_type::s32: vmovups(vreg_bias, bias_addr); break;
+ case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break;
+ case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break;
+ default: assert(!"unsupported dst data type");
+ }
+ if (jcp.bia_dt != data_type::f32)
+ vcvtdq2ps(vreg_bias, vreg_bias);
+ vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha
+ }
+ for(int y = 0; y < jcp.m; y++) {
+ kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]);
+ for(int x = 0; x < jcp.m; x++) {
+ kandw(r_mask, y_mask, x_mask(x));
+
+ int i = y * jcp.m + x;
+ int offset = jcp.typesize_out *
+ (y * jcp.ow * jcp.oc + x * jcp.oc);
+ Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);
+
+ Zmm zmm = vreg_out(i);
+ Xmm xmm = xmm_out(i);
+ vcvtdq2ps(zmm, zmm);
+ if (jcp.with_bias)
+ vaddps(zmm, zmm, vreg_bias);
+ vmulps(zmm, zmm, ptr [reg_ptr_scales]);
+ if (maybe_relu(0))
+ vmaxps(zmm, vreg_zero, zmm);
+ if (p_sum_scale) { // post_op: sum
+ vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
+ switch (jcp.dst_dt) {
+ case data_type::f32:
+ case data_type::s32:
+ vmovups(vreg_prev_dst | r_mask, addr); break;
+ case data_type::s8:
+ vpmovsxbd(vreg_prev_dst | r_mask, addr); break;
+ case data_type::u8:
+ vpmovzxbd(vreg_prev_dst | r_mask, addr); break;
+ default: assert(!"unknown dst_dt");
+ }
+ if (jcp.dst_dt != data_type::f32)
+ vcvtdq2ps(vreg_prev_dst, vreg_prev_dst);
+ if (*p_sum_scale == 1.f)
+ vaddps(zmm, vreg_prev_dst);
+ else
+ vfmadd231ps(zmm, vreg_prev_dst,
+ zword_b[reg_ptr_sum_scale]);
+ }
+ if (maybe_relu(1))
+ vmaxps(zmm, vreg_zero, zmm);
+ if (jcp.dst_dt != data_type::f32)
+ vcvtps2dq(zmm, zmm);
+ switch (jcp.dst_dt) {
+ case data_type::f32:
+ case data_type::s32:
+ vmovups(addr, zmm | r_mask); break;
+ case data_type::s8:
+ vpmovsdb(xmm, zmm); vmovups(addr, xmm | r_mask); break;
+ case data_type::u8:
+ vpmovusdb(xmm, zmm); vmovups(addr, xmm | r_mask); break;
+ default: assert(!"unknown dst_dt");
+ }
+ }
+ }
+ };
+
+ preamble();
+
+# define READ_PARAM(reg, field) \
+ mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
+ READ_PARAM(reg_ptr_src, wino_dst);
+ READ_PARAM(reg_ptr_dst, dst);
+ READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
+ READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
+ READ_PARAM(reg_ptr_bias, bias);
+ READ_PARAM(reg_ptr_scales, scales);
+# undef READ_PARAM
+
+ if (jcp.with_bias)
+ mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale));
+
+ mov(reg_aux_ptr_src, reg_ptr_src);
+ mov(reg_aux_ptr_dst, reg_ptr_dst);
+
+ vpxord(vreg_zero, vreg_zero, vreg_zero);
+
+ for (int i = 0; i < jcp.m; i++)
+ kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
+
+ int oc_blocks = jcp.oc / load_block;
+ mov(reg_oc_block, oc_blocks);
+ L(oc_block_label); {
+ loop_body();
+ add(reg_aux_ptr_src, sizeof(int32_t) * load_block);
+ add(reg_aux_ptr_dst, jcp.typesize_out * load_block);
+
+ add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
+        add(reg_ptr_bias, jcp.typesize_bia * load_block);
+ }
+ dec(reg_oc_block);
+ jnz(oc_block_label, T_NEAR);
+
+ postamble();
+
+}
+
+/// GEMM kernel ////////////////////////////////////////////////////////////////
+struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t)
+ jit_conv_conf_2x3_wino_t jcp;
+ const primitive_attr_t &attr_;
+
+ struct call_params_t {
+ const void *src;
+ const void *dst;
+ const void *wei;
+ const void *dst_b;
+ };
+ void (*ker_)(const call_params_t *);
+
+ void generate();
+ static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp,
+ const primitive_attr_t &attr);
+
+ jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
+ jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr)
+ {
+ generate();
+ ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
+ }
+
+ static status_t init_conf(
+ jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
+ memory_desc_t &src_md, memory_desc_t &weights_md,
+ memory_desc_t &dst_md, memory_desc_t &bias_md,
+ const primitive_attr_t &attr);
+
+ Zmm vreg_out(int n, int m) {
+ const int id_reg_out = n * jcp.m_block + m;
+ assert(id_reg_out < jcp.n2_block * jcp.m_block);
+ return Zmm(31 - id_reg_out);
+ }
+ Zmm vreg_wei(int i) {
+ assert(31 - jcp.n2_block * jcp.m_block - i
+ > (jcp.ver == ver_vnni ? 0 : 2));
+ return Zmm(31 - jcp.n2_block * jcp.m_block - i);
+ }
+
+ Zmm vreg_src = Zmm(0);
+ Zmm vreg_one = Zmm(1);
+ Zmm vreg_tmp = Zmm(2);
+
+ Reg64 reg_ptr_src = r15;
+
+ Reg64 reg_aux_dst_b = r13;
+ Reg64 reg_aux_dst = r12;
+ Reg64 reg_aux_dst2 = r11;
+ Reg64 reg_aux_wei = r10;
+ Reg64 reg_aux_wei2 = r9;
+ Reg64 reg_aux_src = r8;
+ Reg64 reg_aux_src2 = rax;
+ Reg64 reg_mb = rbx;
+ Reg64 reg_nnb = abi_not_param1;
+ Reg64 reg_scratch = rdx;
+ Reg64 reg_K = rsi;
+};
+
+bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
+ jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
+ using namespace primitive_kind;
+ const auto &p = attr.post_ops_;
+
+ auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
+
+ switch (p.len_) {
+ case 0: return true;
+ case 1: return is_relu(0) || p.contain(sum, 0);
+ case 2: return (p.contain(sum, 0) && is_relu(1)) ||
+ (p.contain(sum, 1) && is_relu(0));
+ case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
+ default: return false;
+ }
+
+ return false;
+}
+
+void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() {
+ Label nnb_loop_label, K_loop_label, mb_loop_label;
+
+ auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
+ if (jcp.ver == ver_vnni) {
+ vpdpbusd(vreg_acc, vreg_src, vreg_wei);
+ } else {
+ vpmaddubsw(vreg_tmp, vreg_src, vreg_wei);
+ vpmaddwd(vreg_tmp, vreg_tmp, vreg_one);
+ vpaddd(vreg_acc, vreg_acc, vreg_tmp);
+ }
+ };
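+
+    // A note on the two paths above: with VNNI, vpdpbusd multiplies four
+    // unsigned bytes by four signed bytes per 32-bit lane and accumulates
+    // into int32 in a single instruction. Without VNNI the same u8*s8 dot
+    // product is emulated: vpmaddubsw forms 16-bit sums of adjacent byte
+    // products, vpmaddwd against a vector of ones pairs and widens them to
+    // int32, and vpaddd accumulates. The emulation can saturate in the 16-bit
+    // intermediate, which the single-instruction path does not.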
+
+ preamble();
+# define READ_PARAM(reg, field) \
+ mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
+ READ_PARAM(reg_ptr_src, src);
+ READ_PARAM(reg_aux_dst, dst);
+ READ_PARAM(reg_aux_wei, wei);
+ READ_PARAM(reg_aux_dst_b, dst_b);
+# undef READ_PARAM
+
+ if (jcp.ver != ver_vnni) {
+ xor_(reg_scratch, reg_scratch);
+ Reg16 _t = reg_scratch.cvt16();
+ mov(_t, 0x1);
+ vpbroadcastw(vreg_one, _t);
+ }
+
+ if (!jcp.small_mb) {
+ mov(reg_nnb, jcp.n_chunks);
+ L(nnb_loop_label);
+ }
+ mov(reg_aux_dst2, reg_aux_dst);
+ mov(reg_aux_src, reg_ptr_src);
+ mov(reg_mb, jcp.M / jcp.m_block);
+ L(mb_loop_label);
+ {
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
+ for (int m = 0; m < jcp.m_block; m++) {
+ int offset = jcp.typesize_acc * nb2 * jcp.n_block;
+ vmovups(vreg_out(nb2, m),
+ EVEX_compress_addr(reg_aux_dst_b, offset));
+ }
+ }
+ mov(reg_aux_src2, reg_aux_src);
+ mov(reg_aux_wei2, reg_aux_wei);
+ mov(reg_K, jcp.k_chunks);
+ L(K_loop_label);
+ {
+ for (int k = 0; k < jcp.k2_block; k += 4) {
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
+ int wei_offset
+ = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K);
+ vmovups(vreg_wei(nb2),
+ EVEX_compress_addr(reg_aux_wei2, wei_offset));
+ }
+ for (int m = 0; m < jcp.m_block; m++) {
+ int inp_offset = jcp.typesize_in * m * jcp.K;
+ vpbroadcastd(vreg_src,
+ EVEX_compress_addr(reg_aux_src2, inp_offset));
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
+ compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src);
+ }
+ add(reg_aux_src2, jcp.typesize_in * 4);
+ add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block);
+ }
+ }
+ dec(reg_K);
+ jnz(K_loop_label, T_NEAR);
+
+ for (int m = 0; m < jcp.m_block; m++) {
+ for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
+ int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block);
+ vmovups(EVEX_compress_addr(reg_aux_dst2, offset),
+ vreg_out(nb2, m));
+ }
+ }
+ add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K);
+ add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N);
+ }
+ dec(reg_mb);
+ jnz(mb_loop_label, T_NEAR);
+
+ if (!jcp.small_mb) {
+ add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
+ add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
+ add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K);
+
+ dec(reg_nnb);
+ jnz(nnb_loop_label, T_NEAR);
+ }
+
+ postamble();
+}
+namespace {
+bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
+ if (jcp.ver == ver_vnni) {
+ return (jcp.mb <= mkldnn_get_max_threads()
+ && (jcp.mb > 4
+ && jcp.ic > 64
+ && !(jcp.oc > 128 && jcp.ih < 14)))
+ || jcp.mb > mkldnn_get_max_threads();
+ }
+ return true;
+}
+}
+
+status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
+::init_conf(jit_conv_conf_2x3_wino_t &jcp,
+ const convolution_desc_t &cd, memory_desc_t &src_md,
+ memory_desc_t &wei_md, memory_desc_t &dst_md,
+ memory_desc_t &bias_md, const primitive_attr_t &attr) {
+ const memory_desc_wrapper src_d(&src_md);
+ const memory_desc_wrapper wei_d(&wei_md);
+ const memory_desc_wrapper dst_d(&dst_md);
+ const memory_desc_wrapper bias_d(&bias_md);
+
+ const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
+
+ jcp.nthr = mkldnn_get_max_threads();
+
+ jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.ih = src_d.dims()[2];
+ jcp.iw = src_d.dims()[3];
+ jcp.oh = dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[3];
+ jcp.kh = wei_d.dims()[with_groups + 2];
+ jcp.kw = wei_d.dims()[with_groups + 3];
+ jcp.t_pad = cd.padding[0][0];
+ jcp.b_pad = cd.padding[1][0];
+ jcp.l_pad = cd.padding[0][1];
+ jcp.r_pad = cd.padding[1][1];
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+ jcp.dilate_h = cd.dilates[0];
+ jcp.dilate_w = cd.dilates[1];
+
+ jcp.ver = ver_avx512_core;
+ if (!(mayiuse(avx512_core) &&
+ src_d.data_type() == data_type::u8
+ && wei_d.data_type() == data_type::s8
+ && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
+ data_type::s8, data_type::u8)))
+ return status::unimplemented;
+ if (mayiuse(avx512_core_vnni))
+ jcp.ver = ver_vnni;
+
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
+ // block sizes needed for GEMM kernel
+ jcp.ic_block = 4;
+ jcp.oc_block = 16;
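+    // (These values appear to be dictated by the int8 GEMM kernel: the
+    // dot-product instructions consume 4 consecutive int8 values along the
+    // reduction (ic) dimension per 32-bit lane, and one zmm register holds
+    // 16 int32 accumulators, i.e. one oc block.)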
+
+ bool ok = true
+ && jcp.ngroups == 1
+ && jcp.oc % load_block == 0 && jcp.ic % load_block == 0
+ && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
+ && everyone_is(3, jcp.kh, jcp.kw)
+ && everyone_is(1, jcp.stride_h, jcp.stride_w)
+ && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
+ && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad
+ && one_of(jcp.t_pad, 0, 1)
+ && one_of(jcp.l_pad, 0, 1);
+ if (!ok) return status::unimplemented;
+
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+ jcp.dst_dt = cd.dst_desc.data_type;
+
+ jcp.typesize_in = types::data_type_size(src_d.data_type());
+ jcp.typesize_out = types::data_type_size(dst_d.data_type());
+ jcp.typesize_acc = sizeof(int32_t);
+ jcp.typesize_bia = jcp.with_bias
+ ? types::data_type_size(bias_d.data_type())
+ : 0;
+
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+
+ jcp.m = 2;
+ jcp.r = 3;
+ jcp.alpha = jcp.m + jcp.r - 1;
+
+ int aa = jcp.alpha * jcp.alpha;
+ int L1_cap = get_cache_size(1, true);
+ int L2_cap = get_cache_size(2, true);
+ // need 1 extra reg for bcast, and 2 tmp regs for non-vnni
+ int free_regs = jcp.ver == ver_vnni ? 31 : 29;
+
+ auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) {
+ float thr_eff;
+ float Z = (float)jcp.ic + jcp.oc;
+ float Y = (float)jcp.ic * jcp.oc;
+ if (small_mb == 0) { // outer par
+ int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix);
+ thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
+ } else { // inner par
+ int tranw = iy * ix / jcp.alpha;
+ int gemmw = aa * (jcp.nb_oc / n2_b);
+ int tranw_r = rnd_up(tranw, jcp.nthr);
+ int gemmw_r = rnd_up(gemmw, jcp.nthr);
+ thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y);
+ }
+ return thr_eff;
+ };
+
+ auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) {
+ float mem_eff, req_mem;
+ int M = ix * iy / jcp.alpha;
+ if (small_mb == 0) { // outer parallelization strategy
+ // memory for wino transforms (other memory has poor reuse)
+ req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc);
+ mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f;
+ } else { // inner parallelization strategy
+ // memory used during gemm
+ int N = jcp.oc_block * n2_b;
+ req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N;
+ mem_eff = nstl::min(1.f, L2_cap / req_mem);
+ // memory used during wino transforms
+ int M_per_thr = div_up(M, jcp.nthr);
+ req_mem = (float)aa * M_per_thr
+ * (jcp.ic + jcp.typesize_acc * jcp.oc);
+ if (req_mem > L2_cap)
+ mem_eff = 0.1f;
+ }
+ return mem_eff;
+ };
+
+ auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff,
+ float mem_eff, float reg_eff) {
+ // these coefficients are chosen empirically
+ float mem_fac = 0.1f, reg_fac = 0.2f;
+ // normalized overhead relative to memory and register components
+ float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff;
+ // thread and work components affect all others
+ tot_eff *= thr_eff * work_eff;
+ return tot_eff;
+ };
+
+ auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff,
+ int &m_block, int &n2_block, float &tot_eff) {
+ int M = (ix * iy) / jcp.alpha;
+ int max_m_block = nstl::min(M, free_regs);
+ int max_n2_block = nstl::min(jcp.nb_oc, free_regs);
+ tot_eff = 0.f;
+ for (int im = max_m_block; im > 0; im--) {
+ if (M % im)
+ continue;
+ for (int in2 = max_n2_block; in2 > 0; in2--) {
+ int used_regs = (im + 1) * in2;
+ float mem_eff = get_mem_eff(small_mb, ix, iy, in2);
+ float reg_eff = (float)(im * in2) / (im + in2);
+ float thr_eff = get_thr_eff(small_mb, ix, iy, in2);
+ float cur_tot_eff = get_tot_eff(
+ small_mb, thr_eff, work_eff, mem_eff, reg_eff);
+ if (jcp.nb_oc % in2 || used_regs > free_regs
+ || cur_tot_eff <= tot_eff)
+ continue;
+ tot_eff = cur_tot_eff;
+ m_block = im;
+ n2_block = in2;
+ }
+ }
+ };
+
+ /* Selecting xb and yb blocking */
+ int min_yb = jcp.m;
+ int min_xb = jcp.m;
+ int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2));
+ int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2));
+ float best_eff = 0.f;
+ for (int ix = min_xb; ix <= max_xb; ix += 2) {
+ assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2);
+ for (int iy = max_yb; iy >= min_yb; iy -= 2) {
+ assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2);
+
+ int m_b[2];
+ int n2_b[2];
+ bool small_mb;
+ float inner_eff, outer_eff, work_eff;
+
+ int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix);
+ work_eff = (float)jcp.oh * jcp.ow / tiled_area;
+ if (best_eff > 0.f && work_eff < 4.f / 9.f)
+ continue; // no gain from Winograd transformation
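+            // (4/9 is presumably the reciprocal of the 9/4x multiply
+            // reduction of F(2x2, 3x3): 2*2*3*3 direct multiplies per tile
+            // versus 4*4 in the Winograd domain. Below that tile utilization
+            // the padding overhead cancels the arithmetic gain.)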
+
+ /* outer parallelization */
+ find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff);
+
+ /* inner parallelization */
+ find_m_n2_blocks(1, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff);
+
+ small_mb = inner_eff > outer_eff;
+ float eff = small_mb ? inner_eff : outer_eff;
+ if (eff > best_eff) {
+ best_eff = eff;
+ jcp.yb = iy;
+ jcp.xb = ix;
+ jcp.m_block = m_b[small_mb];
+ jcp.n2_block = n2_b[small_mb];
+ jcp.small_mb = small_mb;
+ }
+ }
+ }
+
+ assert((jcp.m_block + 1) * jcp.n2_block <= free_regs);
+ assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
+
+ jcp.mb_block = 1;
+ if (jcp.small_mb) {
+        // For the small-mb harness, set mb_block as large as possible subject
+        // to the constraint that the Winograd activations fit into the
+        // available L3 cache.
+ int L3_cap = get_cache_size(3, true);
+ int M = jcp.xb * jcp.yb / 4;
+ int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in;
+ int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc;
+ int max_mb_block = nstl::min(
+ jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size));
+ for (int i = max_mb_block; i > 1; i--) {
+ if (jcp.mb % i == 0) {
+ jcp.mb_block = i;
+ break;
+ }
+ }
+ }
+ jcp.nb_mb = jcp.mb / jcp.mb_block;
+
+ jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4;
+ jcp.N = jcp.oc;
+ jcp.K = jcp.ic;
+
+ jcp.inp_stride = jcp.M * jcp.ic;
+ jcp.out_stride = jcp.M * jcp.oc;
+ jcp.wei_stride = jcp.ic * jcp.oc;
+ jcp.bia_stride = jcp.oc;
+
+ jcp.n_block = jcp.oc_block;
+ jcp.k_block = jcp.ic_block;
+
+ jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block;
+
+ // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4
+ // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is
+ // a multiple of load_block = 16, we just use that for now.
+ jcp.k2_block = load_block;
+ jcp.k_chunks = jcp.K / jcp.k2_block;
+
+ const auto &oscales = attr.output_scales_;
+ jcp.is_oc_scale = oscales.mask_ == 1 << 1;
+ assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
+
+ /* re-create weights primitive descriptor
+ and set weights wino_blocking */
+ memory_desc_t expect_wei_md = wei_md;
+
+ expect_wei_md.format_kind = format_kind::wino;
+ expect_wei_md.data_type = data_type::s8;
+ mkldnn_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
+ wd.wino_format = mkldnn_wino_wei_aaOIoi;
+ wd.r = jcp.r;
+ wd.alpha = jcp.alpha;
+ wd.ic = jcp.ic;
+ wd.oc = jcp.oc;
+ wd.ic_block = jcp.ic_block;
+ wd.oc_block = jcp.oc_block;
+ wd.oc2_block = jcp.n2_block;
+ wd.ic2_block = 1;
+ wd.adj_scale = adj_wei_scale;
+
+ size_t max_size = types::data_type_size(data_type::s8) *
+ jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
+ max_size += types::data_type_size(data_type::s32) *
+ jcp.alpha * jcp.alpha * jcp.oc;
+ wd.size = max_size;
+
+ if (wei_md.format_kind == format_kind::any)
+ wei_md = expect_wei_md;
+ if (wei_md != expect_wei_md)
+ return status::unimplemented;
+
+ const int tilesize = jcp.alpha * jcp.alpha;
+ const int numtiles = jcp.M;
+ const int alltiles = numtiles * tilesize;
+
+ jcp.size_wino_src
+ = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K)
+ / jcp.typesize_in;
+ jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic;
+ jcp.size_wino_dst = alltiles * jcp.oc;
+
+ return status::success;
+}
+////////////////////////////////////////////////////////////////////////////////
+
+template <data_type_t dst_data_type>
+status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+ pd_t::jit_conf() {
+ return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(
+ jcp_, *this->desc(), this->src_md_, this->weights_md_,
+        this->dst_md_, this->bias_md_, *this->attr());
+}
+
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::pd_t::
+init_scratchpad() {
+ auto scratchpad = this->scratchpad_registry().registrar();
+
+ int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr;
+ scratchpad.book(key_wino_V,
+ sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K);
+ scratchpad.book(key_wino_M,
+ sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K);
+
+ dim_t scale_count = attr()->output_scales_.count_;
+ scratchpad.book(key_conv_adjusted_scales,
+ sizeof(float) * nstl::max<dim_t>(scale_count, 16));
+}
+
+template <data_type_t dst_data_type>
+jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+ jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+{
+ kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
+ pd()->jcp_, *pd()->attr());
+ src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
+ pd()->jcp_, *pd()->attr());
+ dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
+ pd()->jcp_, *pd()->attr());
+}
+
+template <data_type_t dst_data_type>
+jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+ ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() {
+ delete kernel_;
+ delete src_trans_;
+ delete dst_trans_;
+}
+
+template <data_type_t dst_data_type>
+const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+adjust_oscales(const memory_tracking::grantor_t &scratchpad) const {
+ const float *oscales = pd()->attr()->output_scales_.scales_;
+ auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
+ size_t count = pd()->attr()->output_scales_.count_;
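+    // adj_src_scale * adj_wei_scale = (1/4) * (4/9) = 1/9, so the factor
+    // below is 9; it undoes the pre-scaling applied in the source and weight
+    // transforms by folding the correction into the output scales.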
+ float factor = 1.f / (adj_src_scale * adj_wei_scale);
+ if (count == 1)
+ utils::array_set(loc_scales, oscales[0] * factor, 16);
+ else
+ for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor;
+ return loc_scales;
+}
+
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+execute_forward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ const auto &jcp = kernel_->jcp;
+ if (jcp.small_mb)
+ execute_forward_small_mb(src, weights, bias, dst, this->scratchpad(ctx));
+ else
+ execute_forward_mbN(src, weights, bias, dst, this->scratchpad(ctx));
+}
+
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+execute_forward_mbN(const src_data_t *src, const wei_data_t *wei,
+ const char *bia, dst_data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const auto &jcp = kernel_->jcp;
+ const float *oscales = adjust_oscales(scratchpad);
+
+ auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
+ auto wino_src_base = scratchpad.template get<src_data_t>(key_wino_V);
+ auto wino_dst_base = scratchpad.template get<acc_data_t>(key_wino_M);
+
+ parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb),
+ [&](int mb, int tile_y_b, int tile_x_b) {
+
+ int tile_y = tile_y_b * jcp.yb;
+ int tile_x = tile_x_b * jcp.xb;
+
+ int ithr = mkldnn_get_thread_num();
+ auto wino_src = wino_src_base + jcp.size_wino_src * ithr;
+ auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr;
+
+ auto src_trans_p =
+ jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
+ auto dst_trans_p =
+ jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
+ auto gemm_p =
+ jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t();
+
+ /* transformation of input tensor to winograd domain */
+ for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
+ for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
+ uint16_t v_y_masks[4], v_x_masks[4];
+
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+
+ int v_ys = nstl::max(0, jcp.t_pad - y);
+ int v_ye = nstl::min(jcp.alpha,
+ nstl::max(0, jcp.ih + jcp.t_pad - y));
+
+ int v_xs = nstl::max(0, jcp.l_pad - x);
+ int v_xe = nstl::min(jcp.alpha,
+ nstl::max(0, jcp.iw + jcp.l_pad - x));
+
+#pragma unroll(4)
+ for (int i = 0; i < jcp.alpha; i++) {
+ v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
+ v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
+ }
+ auto local_s = src
+ + mb * jcp.ih * jcp.iw * jcp.ic
+ + y * jcp.iw * jcp.ic + x * jcp.ic;
+ auto local_w = wino_src + m * jcp.ic;
+
+ src_trans_p.src = local_s;
+ src_trans_p.wino_src = local_w;
+ src_trans_p.v_y_masks = v_y_masks;
+ src_trans_p.v_x_masks = v_x_masks;
+
+ src_trans_->ker_(&src_trans_p);
+ }
+ }
+ /* gemms */
+ for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
+ // start threads at different GEMMs to help bring weights into LLC
+ int offset = (tile_ij + ithr) % 16;
+ gemm_p.src = wino_src + jcp.inp_stride * offset;
+ gemm_p.dst = wino_dst + jcp.out_stride * offset;
+ gemm_p.wei = wei + jcp.wei_stride * offset;
+ gemm_p.dst_b = dst_bias + jcp.bia_stride * offset;
+
+ kernel_->ker_(&gemm_p);
+ }
+
+ /* transformation from winograd domain to output tensor */
+ for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
+ for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
+ uint16_t v_y_masks[2], v_x_masks[2];
+
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+
+#pragma unroll(2)
+ for (int i = 0; i < jcp.m; i++) {
+ v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
+ v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
+ }
+ auto local_d = dst
+ + mb * jcp.oh * jcp.ow * jcp.oc
+ + y * jcp.ow * jcp.oc + x * jcp.oc;
+ auto local_w = wino_dst + m * jcp.oc;
+
+ auto scales = oscales;
+ dst_trans_p.dst = local_d;
+ dst_trans_p.wino_dst = local_w;
+ dst_trans_p.v_y_masks = v_y_masks;
+ dst_trans_p.v_x_masks = v_x_masks;
+
+ dst_trans_p.scales = scales;
+ dst_trans_p.bias = bia;
+
+ dst_trans_->ker_(&dst_trans_p);
+ }
+ }
+ });
+}
+
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei,
+ const char *bia, dst_data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const auto &jcp = kernel_->jcp;
+ const float *oscales = adjust_oscales(scratchpad);
+
+ auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
+ auto wino_src = scratchpad.template get<src_data_t>(key_wino_V);
+ auto wino_dst = scratchpad.template get<acc_data_t>(key_wino_M);
+
+ for (int mbb = 0; mbb < jcp.nb_mb; mbb++) {
+ for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
+ for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
+ /* transformation of input tensor to winograd domain */
+ parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
+ [&](int y_in_block_b, int x_in_block_b, int mb) {
+ int y_in_block = y_in_block_b * 2;
+ int x_in_block = x_in_block_b * 2;
+
+ auto src_trans_p =
+ jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
+
+ uint16_t v_y_masks[4], v_x_masks[4];
+
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
+ + (x_in_block / 2);
+
+ int v_ys = nstl::max(0, jcp.t_pad - y);
+ int v_ye = nstl::min(
+ jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));
+
+ int v_xs = nstl::max(0, jcp.l_pad - x);
+ int v_xe = nstl::min(
+ jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));
+
+#pragma unroll(4)
+ for (int i = 0; i < jcp.alpha; i++) {
+ v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
+ v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
+ }
+ auto local_s = src
+ + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic
+ + y * jcp.iw * jcp.ic + x * jcp.ic;
+ auto local_w = wino_src + m * jcp.ic;
+
+ src_trans_p.src = local_s;
+ src_trans_p.wino_src = local_w;
+ src_trans_p.v_y_masks = v_y_masks;
+ src_trans_p.v_x_masks = v_x_masks;
+
+ src_trans_->ker_(&src_trans_p);
+ });
+
+ /* gemms */
+ parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
+ auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::
+ call_params_t();
+
+ gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
+ gemm_p.dst = wino_dst + jcp.out_stride * tile_ij
+ + nnb * jcp.n2_block * jcp.n_block;
+ gemm_p.wei = wei + jcp.wei_stride * tile_ij
+ + nnb * jcp.n2_block * jcp.n_block * jcp.K;
+ gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij
+ + nnb * jcp.n2_block * jcp.n_block;
+
+ kernel_->ker_(&gemm_p);
+ });
+
+ /* transformation from winograd domain to output tensor */
+ parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
+ [&](int y_in_block_b, int x_in_block_b, int mb) {
+ int y_in_block = y_in_block_b * 2;
+ int x_in_block = x_in_block_b * 2;
+
+ auto dst_trans_p =
+ jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
+
+ uint16_t v_y_masks[2], v_x_masks[2];
+
+ int y = y_in_block + tile_y;
+ int x = x_in_block + tile_x;
+ int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
+ + (x_in_block / 2);
+
+#pragma unroll(2)
+ for (int i = 0; i < jcp.m; i++) {
+ v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
+ v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
+ }
+ auto local_d = dst
+ + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc
+ + y * jcp.ow * jcp.oc + x * jcp.oc;
+ auto local_w = wino_dst + m * jcp.oc;
+
+ auto scales = oscales;
+ dst_trans_p.dst = local_d;
+ dst_trans_p.wino_dst = local_w;
+ dst_trans_p.v_y_masks = v_y_masks;
+ dst_trans_p.v_x_masks = v_x_masks;
+
+ dst_trans_p.scales = scales;
+ dst_trans_p.bias = bia;
+
+ dst_trans_->ker_(&dst_trans_p);
+ });
+ }}}
+}
+
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s8>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::u8>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s32>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::f32>;
+
+} // namespace cpu
+} // namespace impl
+} // namespace mkldnn
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp
new file mode 100644
index 0000000000..9e6e57b051
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_u8s8s32x_wino_convolution.hpp
@@ -0,0 +1,128 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP
+#define CPU_JIT_AVX512_CORE_U8S8S32X_WINO_CONVOLUTION_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_primitive_conf.hpp"
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t;
+struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t;
+struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t;
+
+template <data_type_t dst_data_type>
+struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_()
+ {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_int8_wino:", avx512_core, ""),
+ jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::convolution_auto,
+ alg_kind::convolution_winograd)
+ && expect_data_types(data_type::u8, data_type::s8,
+ data_type::undef, dst_data_type, data_type::s32)
+ && IMPLICATION(with_bias(), utils::one_of(
+ desc()->bias_desc.data_type, data_type::f32,
+ data_type::s32, data_type::s8, data_type::u8))
+ && !has_zero_dim_memory()
+ && set_default_formats();
+
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_conf();
+ if (status != status::success) return status;
+ set_default_alg_kind(alg_kind::convolution_winograd);
+
+ init_scratchpad();
+
+ return status;
+ }
+
+ jit_conv_conf_2x3_wino_t jcp_;
+
+ protected:
+ status_t jit_conf();
+ void init_scratchpad();
+
+ bool set_default_formats() {
+ using namespace format_tag;
+ return set_default_formats_common(nhwc, any, nhwc);
+ }
+ };
+
+ typedef typename prec_traits<data_type::u8>::type src_data_t;
+ typedef typename prec_traits<data_type::s8>::type wei_data_t;
+ typedef typename prec_traits<data_type::s32>::type acc_data_t;
+ typedef typename prec_traits<dst_data_type>::type dst_data_t;
+
+ jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd);
+ ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t();
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ const float *adjust_oscales(const memory_tracking::grantor_t &scratchpad)
+ const;
+ void execute_forward(const exec_ctx_t &ctx) const;
+ void execute_forward_small_mb(const src_data_t *src, const wei_data_t *wei,
+ const char *bia, dst_data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const;
+ void execute_forward_mbN(const src_data_t *src, const wei_data_t *wei,
+ const char *bia, dst_data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t *kernel_;
+ jit_avx512_core_u8s8s32x_wino_conv_src_trans_t *src_trans_;
+ jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t *dst_trans_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
new file mode 100644
index 0000000000..f4ec29ab00
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
@@ -0,0 +1,820 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_memory.hpp"
+
+#include "jit_uni_1x1_conv_utils.hpp"
+#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp"
+
+#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+
+using namespace Xbyak;
+
+bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_eltwise(int position)
+{
+ using namespace primitive_kind;
+ const auto &p = attr_.post_ops_;
+
+ if (position == 0) {
+ /* eltwise before sum */
+ return p.contain(eltwise, 0);
+ } else if (position == 1) {
+ /* eltwise after sum */
+ return p.contain(sum, 0) && p.contain(eltwise, 1);
+ }
+
+ return false;
+}
+
+void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk)
+{
+ mov(aux1_reg_bcast_data, reg_bcast_data);
+ mov(aux_reg_bcast_data, reg_bcast_data);
+
+ mov(aux_reg_output_data, reg_output_data);
+ mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off));
+
+ Label bcast_loop;
+ Label bcast_loop_tail;
+
+ cmp(bcast_loop_iter, jcp.ur);
+ jl(bcast_loop_tail, T_NEAR);
+
+ L(bcast_loop); {
+ assert(jcp.bcast_block % jcp.ur == 0);
+ int num_substeps = jcp.bcast_block / jcp.ur;
+ assert(num_substeps > 0 && num_substeps < 10);
+ for (int i = 0; i < num_substeps; i++) {
+ reduce_loop(load_loop_blk, jcp.ur, i, false);
+ if (i < num_substeps - 1) {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
+ add(aux_reg_output_data, jcp.bcast_loop_output_substep);
+ }
+ else {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
+ - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
+ int output_offset = jcp.bcast_loop_output_step
+ - (num_substeps - 1) * jcp.bcast_loop_output_substep;
+
+ add(aux_reg_output_data, output_offset);
+ }
+ }
+ sub(bcast_loop_iter, jcp.bcast_block);
+ cmp(bcast_loop_iter, jcp.bcast_block);
+ jge(bcast_loop, T_NEAR);
+ }
+
+ L(bcast_loop_tail);
+ if (jcp.ur_tail) {
+ Label bcast_loop_tail_out;
+ cmp(bcast_loop_iter, 0);
+ jz(bcast_loop_tail_out, T_NEAR);
+ reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
+ L(bcast_loop_tail_out);
+ }
+}
+
+void jit_avx512_core_x8s8s32x_1x1_conv_kernel::cvt2ps(data_type_t type_in,
+ zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) {
+ zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
+ switch (type_in) {
+ case data_type::f32:
+ case data_type::s32: vmovups(zmm, op); break;
+ case data_type::s8: vpmovsxbd(zmm, op); break;
+ case data_type::u8: vpmovzxbd(zmm, op); break;
+ default: assert(!"unsupported data type");
+ }
+ if (type_in != data_type::f32)
+ vcvtdq2ps(zmm_in, zmm_in);
+}
+
+void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
+ int ur, int substep, bool wraparound)
+{
+ auto vreg_load = [=](int i_load) {
+ return Zmm(ur * load_loop_blk + i_load);
+ };
+
+ auto vreg_accum = [=](int i_load, int i_ur) {
+ return Zmm(i_ur * load_loop_blk + i_load);
+ };
+
+ auto zmm_bias_alpha = [=]() {
+ return Zmm(ur * load_loop_blk);
+ };
+
+ auto xmm_bias_alpha = [=]() {
+ return Xmm(ur * load_loop_blk);
+ };
+ auto bias_ptr = [=](int i_load) {
+ return EVEX_compress_addr(reg_bias_data,
+ jcp.typesize_bia * jcp.oc_block * i_load);
+ };
+
+ auto comp_ptr = [=](int i_load) {
+ return EVEX_compress_addr(reg_comp_data,
+ sizeof(int32_t) * jcp.oc_block * i_load);
+ };
+
+ auto scale_ptr = [=](int i_load) {
+ return EVEX_compress_addr(reg_ptr_scales,
+ jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load));
+ };
+
+ auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
+ assert(i_ur < jcp.ur);
+ assert(i_reduce <= jcp.reduce_loop_unroll);
+ assert(jcp.reduce_loop_unroll == jcp.reduce_block);
+
+ int offt = (jcp.ic_without_padding * i_ur + i_reduce);
+
+ return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
+ bcast);
+ };
+
+ auto load_ptr = [=](int i_reduce, int i_load) {
+ int u0 = i_reduce % jcp.reduce_loop_unroll;
+ int u1 = i_reduce / jcp.reduce_loop_unroll;
+
+ int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;
+
+ return EVEX_compress_addr(aux_reg_load_data,
+ u1 * jcp.reduce_loop_load_step
+ + jcp.typesize_in * offt);
+ };
+
+ auto output_ptr = [=](int i_load, int i_ur) {
+ return EVEX_compress_addr(aux_reg_output_data,
+ jcp.typesize_out * (jcp.oc_without_padding * i_ur
+ + i_load * jcp.load_block));
+ };
+
+ auto init = [=]() {
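+ // Zero the accumulators; for s8 sources broadcast -128 into zmm_shift
+ // so the inputs can be shifted into u8 range in fma_block (the
+ // compensation term added in store() corrects for this shift).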
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load)
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ auto r = vreg_accum(i_load, i_ur);
+ vpxord(r, r, r);
+ }
+ if (jcp.signed_input) {
+ xor_(reg_scratch, reg_scratch);
+ Reg8 _t8 = reg_scratch.cvt8();
+ mov(_t8, (int8_t)-128);
+ vpbroadcastb(zmm_shift, _t8);
+ }
+ };
+
+ auto store = [=](const bool mask_flag_in) {
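+ /* Finalize the accumulators: convert to f32, add the signed-input
+ * compensation and the bias, apply the output scales, run the
+ * eltwise / sum / eltwise post-ops as configured, then saturate and
+ * store in the destination data type. */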
+ const auto &p = attr_.post_ops_;
+ const int sum_idx = p.find(primitive_kind::sum);
+ const float *p_sum_scale = (sum_idx != -1)
+ ? &p.entry_[sum_idx].sum.scale
+ : nullptr;
+ mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
+ mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
+ if (p_sum_scale && *p_sum_scale != 1.f) {
+ mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data);
+ mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
+ }
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ mov(reg_scratch, float2int(jcp.wei_adj_scale));
+ vmovq(xmm_bias_alpha(), reg_scratch);
+ vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
+ }
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
+ auto zmm_bias = zmm_tmp;
+ auto zmm_comp = zmm_bcast;
+ if (jcp.with_bias) {
+ if (jcp.signed_input)
+ mov(reg_bias_data,
+ EVEX_compress_addr(rsp,reg_bias_data_off));
+ cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag);
+ if (jcp.signed_input && jcp.ver != ver_vnni)
+ vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
+ }
+ if (jcp.signed_input) {
+ mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
+ cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag);
+ }
+
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ auto r = vreg_accum(i_load, i_ur);
+ vcvtdq2ps(r, r);
+ if (jcp.signed_input)
+ vaddps(r, r, zmm_comp);
+ if (jcp.with_bias)
+ vaddps(r, r, zmm_bias);
+
+ zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r;
+ vmulps(mask_zmm, r, scale_ptr(i_load));
+ }
+ }
+
+ if (maybe_eltwise(0))
+ eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
+
+ if (p_sum_scale) { // post_op: sum
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ const bool mask_flag = mask_flag_in &&
+ i_load == load_loop_blk - 1;
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+ auto zmm_prev_dst = zmm_zero;
+
+ auto r = vreg_accum(i_load, i_ur);
+ cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur),
+ mask_flag);
+
+ if (*p_sum_scale == 1.f)
+ vaddps(r, zmm_prev_dst);
+ else
+ vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
+ }
+ }
+ }
+
+ if (maybe_eltwise(1))
+ eltwise_injector_->compute_vector_range(0, ur * load_loop_blk);
+
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ const bool mask_flag = mask_flag_in &&
+ i_load == load_loop_blk - 1;
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ auto r = vreg_accum(i_load, i_ur);
+ if (jcp.dst_dt == data_type::u8) {
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+ vmaxps(r, zmm_zero, r);
+ }
+ if (jcp.dst_dt != data_type::f32)
+ vcvtps2dq(r, r);
+ }
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ auto r = vreg_accum(i_load, i_ur);
+ zmm_t r_zmm = mask_flag ? r | ktail_mask : r;
+
+ switch (jcp.dst_dt) {
+ case data_type::f32:
+ case data_type::s32:
+ vmovups(output_ptr(i_load, i_ur), r_zmm); break;
+ case data_type::s8:
+ vpmovsdb(output_ptr(i_load, i_ur), r_zmm); break;
+ case data_type::u8:
+ vpmovusdb(output_ptr(i_load, i_ur), r_zmm); break;
+ default: assert(!"unknown dst_dt");
+ }
+ }
+ }
+ mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
+ if (p_sum_scale && *p_sum_scale != 1.f)
+ mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off));
+ };
+
+ auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
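+ // VNNI path: a single vpdpbusd. Fallback: emulate the u8*s8 dot
+ // product with vpmaddubsw + vpmaddwd against zmm_one (all ones).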
+ if (jcp.ver == ver_vnni) {
+ vpdpbusd(vreg_acc, vreg_src, vreg_wei);
+ } else {
+ vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
+ vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
+ vpaddd(vreg_acc, vreg_acc, zmm_tmp);
+ }
+ };
+
+ auto fma_block = [=](bool last_block) {
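+ /* Multiply-accumulate over the reduce (IC) dimension, four input
+ * channels per step; the last block fills a non-multiple-of-4 tail
+ * byte by byte, and s8 inputs are shifted by 128 before the dot
+ * product. */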
+ int reduce_step = 4;
+ int tail_size = jcp.ic_without_padding % reduce_step;
+ int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding
+ ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step)
+ : jcp.reduce_loop_unroll;
+ for (int i_reduce = 0; i_reduce < loop_unroll;
+ i_reduce += reduce_step) {
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load)
+ vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
+ for (int i_ur = 0; i_ur < ur; ++i_ur) {
+ if (last_block && tail_size != 0
+ && i_reduce == loop_unroll - reduce_step) {
+ Xmm xmm_bcast = Xmm(zmm_bcast.getIdx());
+ for (int r = 0; r < tail_size; ++r)
+ vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
+ + jcp.ic_without_padding * i_ur + i_reduce + r], r);
+ vpbroadcastd(zmm_bcast, xmm_bcast);
+ } else {
+ vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
+ }
+ if (jcp.signed_input)
+ vpsubb(zmm_bcast, zmm_bcast, zmm_shift);
+ for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+ compute(vreg_accum(i_load, i_ur),
+ vreg_load(i_load), zmm_bcast);
+ }
+ }
+ }
+ };
+
+ Label reduce_loop;
+ Label reduce_loop_tail;
+
+ mov(aux_reg_load_data, reg_load_data);
+
+ mov(aux_reg_bcast_data, aux1_reg_bcast_data);
+ init();
+
+ mov(reduce_loop_iter, reg_reduce_loop_work);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jle(reduce_loop_tail, T_NEAR);
+
+ L(reduce_loop); {
+ fma_block(false);
+ add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
+ add(aux_reg_load_data, jcp.reduce_loop_load_step);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jg(reduce_loop, T_NEAR);
+ }
+
+ L(reduce_loop_tail);
+ if (jcp.ic != jcp.ic_without_padding) {
+ fma_block(true);
+ } else {
+ fma_block(false);
+ }
+
+ if (jcp.oc_without_padding != jcp.oc) {
+ Label end_store, common_store;
+ mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
+
+ /* Check if it is the last load_loop_blk */
+ sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+ cmp(reg_load_loop_work, 0);
+ jg(common_store, T_NEAR);
+
+ /* Check if it is the last ocb */
+ test(reg_reduce_pos_flag, FLAG_OC_LAST);
+ jz(common_store, T_NEAR);
+
+ store(true);
+ jmp(end_store, T_NEAR);
+
+ L(common_store);
+ store(false);
+
+ L(end_store);
+
+ add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+ } else {
+ store(false);
+ }
+}
+
+void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
+{
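+ /* Kernel entry: build the vector of ones used by the non-VNNI dot
+ * product, reserve stack slots for spilled pointers, set up the
+ * output-channel tail mask, load the arguments from
+ * jit_1x1_conv_call_s, and dispatch to the load-loop body unrolled
+ * for the matching block size. */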
+ preamble();
+
+ xor_(reg_scratch, reg_scratch);
+ Reg16 _t = reg_scratch.cvt16();
+ mov(_t, 0x1);
+ vpbroadcastw(zmm_one, _t);
+
+ sub(rsp, stack_space_needed);
+
+ if (jcp.oc_without_padding != jcp.oc) {
+ int tail_size = jcp.oc_without_padding % jcp.oc_block;
+ int mask = (1 << tail_size) - 1;
+ Reg32 regw_tmp = reg_last_load.cvt32();
+ mov(regw_tmp, mask);
+ kmovw(ktail_mask, regw_tmp);
+ }
+
+ if (jcp.with_bias)
+ mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
+ if (jcp.signed_input) {
+ mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
+ mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
+ mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
+ }
+ mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
+ mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
+ mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
+ mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
+ mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
+
+ mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
+ mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
+ mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work);
+ mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
+ mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
+
+
+ auto load_loop_body = [=](int load_loop_blk) {
+ bcast_loop(load_loop_blk);
+ add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
+ if (jcp.with_bias) {
+ if (jcp.signed_input)
+ mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off));
+ add(reg_bias_data,
+ load_loop_blk * jcp.load_block * jcp.typesize_bia);
+ if (jcp.signed_input)
+ mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
+ }
+ if (jcp.signed_input) {
+ mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
+ add(reg_comp_data,
+ load_loop_blk * jcp.load_block * sizeof(int32_t));
+ mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
+ }
+ mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
+ mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
+ add(reg_ptr_scales,
+ jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float));
+ mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
+ mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
+ add(reg_output_data,
+ load_loop_blk * jcp.load_block * jcp.typesize_out);
+ sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+ };
+
+ const int simd_w = 16;
+
+ Label load_loop_blk[7];
+
+ static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
+ const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast);
+ const int *ur_cases_fma = ur_cases_fma_expl_bcast;
+ const int *ur_cases = ur_cases_fma;
+ const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);
+
+ for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
+ int label_idx = num_ur_cases - ur_idx - 1;
+ if (jcp.ur <= ur_cases[ur_idx]) {
+ cmp(reg_load_loop_work, simd_w * (label_idx + 1));
+ jle(load_loop_blk[label_idx], T_NEAR);
+ }
+ }
+
+ for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
+ if (jcp.ur <= ur_cases[ur_idx]) {
+ int label_idx = num_ur_cases - ur_idx - 1;
+ L(load_loop_blk[label_idx]);
+ {
+ if (label_idx == 0) {
+ cmp(reg_load_loop_work, 0);
+ je(load_loop_blk[num_ur_cases], T_NEAR);
+ }
+
+ for (int _i = 1; _i <= label_idx + 1; _i++) {
+ prefetcht0(ptr [ reg_load_data + _i * jcp.ic * jcp.oc_block ]);
+ prefetcht1(ptr [ reg_output_data + _i * jcp.oc_block ]);
+ }
+
+ load_loop_body(label_idx + 1);
+ if (label_idx - 1 > 0) {
+ cmp(reg_load_loop_work, 2 * label_idx * simd_w);
+ je(load_loop_blk[label_idx - 1], T_NEAR);
+ }
+ cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
+ jge(load_loop_blk[label_idx]);
+ }
+ for (int idx = label_idx - 1; idx > 0; --idx) {
+ cmp(reg_load_loop_work, simd_w * (idx + 1));
+ je(load_loop_blk[idx], T_NEAR);
+ }
+ if (ur_idx < num_ur_cases - 2) {
+ cmp(reg_load_loop_work, simd_w);
+ jle(load_loop_blk[0], T_NEAR);
+ }
+ }
+ }
+ L(load_loop_blk[num_ur_cases]);
+
+ add(rsp, stack_space_needed);
+
+ postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+}
+
+bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok(
+ jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ using namespace primitive_kind;
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+
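+ /* Supported post-op chains: none, a single eltwise or sum, or a
+ * combination of one sum and one eltwise in either order. */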
+ switch (p.len_) {
+ case 0: return true;
+ case 1: return is_eltwise(0) || p.contain(sum, 0);
+ case 2: return (p.contain(sum, 0) && is_eltwise(1))
+ || (p.contain(sum, 1) && is_eltwise(0));
+ default: return false;
+ }
+
+ return false;
+}
+
+status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
+ jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
+ const primitive_attr_t &attr, int nthreads, bool reduce_src) {
+ if (!mayiuse(avx512_core)) return status::unimplemented;
+
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ if (!one_of(src_d.data_type(), data_type::u8, data_type::s8)
+ || weights_d.data_type() != data_type::s8
+ || !one_of(dst_d.data_type(),
+ data_type::f32, data_type::s32, data_type::s8, data_type::u8))
+ return status::unimplemented;
+ jcp.ver = ver_avx512_core;
+ if (mayiuse(avx512_core_vnni))
+ jcp.ver = ver_vnni;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.ic_without_padding = jcp.ic;
+ jcp.ih = src_d.dims()[2];
+ jcp.iw = src_d.dims()[3];
+ jcp.oh = dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[3];
+ jcp.kh = weights_d.dims()[with_groups + 2];
+ jcp.kw = weights_d.dims()[with_groups + 3];
+ jcp.t_pad = cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][1];
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ jcp.signed_input = (src_d.data_type() == data_type::s8);
+
+ jcp.os = jcp.oh * jcp.ow;
+ jcp.is = jcp.ih * jcp.iw;
+ jcp.tr_is = rnd_up(jcp.is, 4);
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise)
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+
+ format_tag_t dat_tag = format_tag::nhwc;
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+
+ bool args_ok = true
+ && jcp.ngroups == 1
+ && jcp.src_tag == dat_tag
+ && jcp.dst_tag == dat_tag;
+ if (!args_ok) return status::unimplemented;
+
+ const int simd_w = 16;
+
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+ jcp.ic = rnd_up(jcp.ic, simd_w);
+
+ args_ok = true
+ && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
+ && jcp.t_pad == 0 && jcp.l_pad == 0
+ && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
+ && jcp.kh == 1 && jcp.kw == 1;
+ if (!args_ok) return status::unimplemented;
+
+ jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+ jcp.dst_dt = cd.dst_desc.data_type;
+
+ jcp.ic_block = jcp.oc_block = simd_w;
+
+ jcp.typesize_in = types::data_type_size(src_d.data_type());
+ jcp.typesize_out = types::data_type_size(dst_d.data_type());
+ jcp.typesize_bia = jcp.with_bias
+ ? types::data_type_size(bias_d.data_type())
+ : 0;
+
+ const int SMALL_SPATIAL = 7 * 7;
+ const int BIG_REDUCE_DIM = 1024;
+
+ int load_blocking = 0;
+ int load_blocking_max = 0;
+ int bcast_blocking = 0;
+ int bcast_blocking_max = 0;
+ int reduce_blocking = 0;
+ int reduce_blocking_max = 0;
+ jcp.load_grp_count = 1;
+ jcp.use_vmovntps = false;
+
+ const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in);
+ const int L2_capacity = (L2_size * 3) / 4;
+
+ int size_treshold = 28;
+ int max_regs = 0;
+ int min_regs = 6;
+ if (jcp.ver == ver_vnni)
+ max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold)
+ && (jcp.oc < 128 || jcp.ic < 128)) ? min_regs : 9;
+ else
+ max_regs = 8;
+ jcp.expl_bcast = true;
+
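+ /* Choose the spatial unroll factor jcp.ur: prefer a register count
+ * that divides the spatial size evenly, otherwise fall back to the
+ * candidate that leaves the largest (or zero) tail. */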
+ if (jcp.mb == 1 && jcp.ic > 128
+ && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) {
+ if (jcp.os <= SMALL_SPATIAL && jcp.oc * jcp.ic < L2_size)
+ max_regs = min_regs; // mobilenet_v2 performance improvement
+ jcp.ur = nstl::min(max_regs, jcp.os);
+ } else {
+ const int spatial = jcp.oh;
+ jcp.ur = 1;
+ for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
+ if ((spatial >= size_treshold && spatial % ur_w == 0)
+ || (spatial < size_treshold && jcp.os % ur_w == 0)) {
+ jcp.ur = ur_w;
+ break;
+ }
+ }
+ if (jcp.ur == 1) {
+ jcp.ur = nstl::min(max_regs, jcp.os);
+ int os_tail = jcp.os % max_regs;
+ for (int i = max_regs; i >= min_regs; i--) {
+ int i_tail = jcp.os % i;
+ if (i_tail > os_tail || i_tail == 0) {
+ jcp.ur = i;
+ os_tail = i_tail;
+ if (i_tail == 0)
+ break;
+ }
+ }
+ }
+ }
+
+ jcp.reduce_dim = jcp.ic;
+ jcp.reduce_block = jcp.ic_block;
+
+ jcp.load_dim = jcp.oc;
+ jcp.load_block = jcp.oc_block;
+
+ jcp.bcast_dim = jcp.is;
+
+ jcp.bcast_block = jcp.ur;
+
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.typesize_in;
+
+ jcp.reduce_loop_load_step
+ = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
+
+ jcp.bcast_loop_output_step = jcp.ur * jcp.oc_without_padding * jcp.typesize_out;
+ jcp.bcast_loop_output_substep = -1; // unused
+ jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_without_padding * jcp.typesize_in;
+ jcp.bcast_loop_bcast_substep = -1; // unused
+
+ jcp.load_loop_load_step
+ = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
+
+ jcp.load_loop_iter_step = jcp.load_block;
+
+ jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
+
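+ /* Blocking heuristics: split the reduce (IC), load (OC) and bcast
+ * (spatial) dimensions into blocks so the working set stays within
+ * the L2 budget and the work divides evenly across threads. */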
+ int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+ int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+
+ reduce_blocking = nb_reduce;
+ if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
+ reduce_blocking = 64;
+ else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
+ reduce_blocking = 16;
+ reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
+ reduce_blocking *= jcp.reduce_block;
+
+ bool cmp_reduce = reduce_blocking <= jcp.reduce_dim;
+ if (cmp_reduce)
+ jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
+ load_blocking = jcp.load_dim;
+
+ jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
+ jcp.load_grp_count = best_divider(
+ nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
+
+ if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.load_dim * jcp.reduce_dim >= L2_size) {
+ jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
+ } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= nthreads
+ && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
+ jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
+ load_blocking = jcp.load_block;
+ }
+
+ bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
+ div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block;
+ bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
+ bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
+
+ int space_for_bcast
+ = (L2_capacity - /* kernel_size - */
+ 2 * jcp.load_block * reduce_blocking
+ - jcp.ur * reduce_blocking - 3 * 1024);
+ if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
+ space_for_bcast /= 2;
+
+ int bcast_in_cache
+ = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
+ bcast_blocking = nstl::min(
+ bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
+
+ load_blocking_max = load_blocking;
+ bcast_blocking_max = bcast_blocking * 3 / 2;
+ reduce_blocking_max = reduce_blocking;
+
+ assert(load_blocking);
+ assert(load_blocking_max);
+ assert(bcast_blocking);
+ assert(bcast_blocking_max);
+ assert(reduce_blocking);
+ assert(reduce_blocking_max);
+ assert(load_blocking % jcp.load_block == 0);
+ assert(reduce_blocking % jcp.reduce_block == 0);
+ assert(load_blocking_max % jcp.load_block == 0);
+ assert(reduce_blocking_max % jcp.reduce_block == 0);
+
+ assert(jcp.reduce_loop_unroll % 4 == 0);
+ assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
+
+ assert(jcp.bcast_block % jcp.ur == 0);
+ assert(jcp.reduce_dim % jcp.reduce_block == 0);
+
+ jcp.ur_tail = jcp.bcast_dim % jcp.ur;
+
+ jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
+ jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
+ jcp.nb_load_blocking = load_blocking / jcp.load_block;
+ jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
+ jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
+ jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
+
+ jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+ jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
+ jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+
+ // minimum size of load dim chunk for work distribution within threads
+ jcp.nb_load_chunk = 1;
+ // performance improvement for googlenet_v3, mb=1;
+ // TODO: generalize this condition and rewrite it in an appropriate manner
+ if (jcp.mb == 1 && jcp.nb_load % 4 == 0 && jcp.ic / jcp.oc >= 4
+ && jcp.ic * jcp.oc <= L2_size) {
+ jcp.nb_load_chunk = 4;
+ jcp.load_grp_count = nstl::max(jcp.nb_load / 4, jcp.load_grp_count);
+ }
+
+ const auto &oscales = attr.output_scales_;
+ jcp.is_oc_scale = oscales.mask_ == 1 << 1;
+ assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
+
+ jcp.wei_adj_scale =
+ (weights_d.extra().flags & memory_extra_flags::scale_adjust)
+ ? weights_d.extra().scale_adjust : 1.f;
+
+ return status::success;
+}
+
+void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ using namespace mkldnn::impl::memory_tracking::names;
+
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ dim_t count = nstl::max<dim_t>(attr.output_scales_.count_, 16);
+ scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
+ }
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp
new file mode 100644
index 0000000000..22e9732a1f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp
@@ -0,0 +1,131 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP
+#define JIT_AVX512_CORE_X8S8S32X_1X1_CONV_KERNEL_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_avx512_core_x8s8s32x_1x1_conv_kernel: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_1x1_conv_fwd_ker_t)
+ jit_avx512_core_x8s8s32x_1x1_conv_kernel(jit_1x1_conv_conf_t ajcp,
+ const primitive_attr_t &attr) : jcp(ajcp), attr_(attr),
+ eltwise_injector_(nullptr)
+ {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
+ this, jcp.eltwise);
+
+ this->generate();
+ jit_ker = (void (*)(jit_1x1_conv_call_s *)) this->getCode();
+ }
+
+ ~jit_avx512_core_x8s8s32x_1x1_conv_kernel() {
+ delete eltwise_injector_;
+ }
+
+ static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_1x1_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d,
+ const memory_desc_wrapper &bias_d,
+ const primitive_attr_t &attr,
+ int nthreads, bool reduce_src);
+
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr);
+
+ bool maybe_eltwise(int position);
+
+ jit_1x1_conv_conf_t jcp;
+ const primitive_attr_t &attr_;
+ void (*jit_ker)(jit_1x1_conv_call_s *);
+
+ private:
+ jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
+
+ using reg64_t = const Xbyak::Reg64;
+ using zmm_t = const Xbyak::Zmm;
+ using mask_t = const Xbyak::Opmask;
+
+ reg64_t reg_bcast_data = r8;
+ reg64_t reg_ptr_scales = r8;
+ reg64_t reg_output_data = r9;
+ reg64_t reg_load_data = r10;
+ reg64_t reg_ptr_sum_scale = r10;
+ reg64_t reg_reduce_loop_work = r11;
+ reg64_t reg_bias_data = r12;
+ reg64_t reg_comp_data = r12;
+ reg64_t reg_scratch = r13;
+ reg64_t aux_reg_bcast_data = r14;
+ reg64_t aux_reg_load_data = r15;
+ reg64_t imm_addr64 = r15;
+ reg64_t reg_reduce_pos_flag = rax;
+ reg64_t aux1_reg_bcast_data = rbx;
+ reg64_t reg_bcast_loop_work = rbx;
+ reg64_t bcast_loop_iter = rdx; // Note: Fix me
+ reg64_t reg_load_loop_work = rsi;
+ reg64_t aux_reg_output_data = abi_not_param1;
+ reg64_t reduce_loop_iter = abi_param1;
+
+ reg64_t reg_last_load = r8;
+ mask_t ktail_mask = k6;
+
+ mask_t vmask = k7;
+
+ Xbyak::Zmm zmm_tmp = Xbyak::Zmm(28);
+ Xbyak::Zmm zmm_one = Xbyak::Zmm(29);
+ Xbyak::Zmm zmm_zero = Xbyak::Zmm(30);
+ Xbyak::Zmm zmm_bcast = Xbyak::Zmm(31);
+ Xbyak::Zmm zmm_shift = Xbyak::Zmm(30);
+
+ Xbyak::Zmm zmm_bias_alpha = Xbyak::Zmm(31);
+ Xbyak::Xmm xmm_bias_alpha = Xbyak::Xmm(31);
+
+ int bcast_loop_work_off = 0;
+ int reg_bias_data_off = 8;
+ int reg_bcast_data_off = 16;
+ int reg_load_data_off = 24;
+ int reg_ptr_sum_scale_off = 32;
+ int reg_comp_data_off = 40;
+ int stack_space_needed = 48;
+
+ void bcast_loop(int load_loop_blk);
+ void reduce_loop(int load_loop_blk, int ur, int substep, bool wraparound);
+
+ void generate();
+ void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
+ bool mask_flag);
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
new file mode 100644
index 0000000000..0bf09fc677
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.cpp
@@ -0,0 +1,292 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_generator.hpp"
+
+#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+namespace {
+template <typename T, typename U>
+void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end,
+ T nx, T &nx_start, T &nx_end, T nx_divider)
+{
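+ /* Split nthr threads into roughly nx_divider groups: each group is
+ * assigned a slice of the nx dimension and the threads inside a
+ * group share the ny dimension. */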
+ const T grp_size = utils::div_up(nthr, nx_divider);
+ const T grp_count = utils::div_up(nthr, grp_size);
+
+ T grp = ithr / grp_size;
+ T grp_ithr = ithr % grp_size;
+ T grp_nthr = grp_size;
+ T first_grps = nthr % grp_count;
+ if (first_grps > 0 && grp >= first_grps) {
+ ithr -= first_grps * grp_size;
+ grp_nthr--;
+ grp = ithr / grp_nthr + first_grps;
+ grp_ithr = ithr % grp_nthr;
+ }
+ balance211(nx, grp_count, grp, nx_start, nx_end);
+ balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
+}
+}
+
+/* convolution forward */
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>::
+execute_forward(const exec_ctx_t &ctx) const
+{
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ auto scratchpad = this->scratchpad(ctx);
+
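+ /* On pre-VNNI hardware the weights are stored pre-scaled by
+ * wei_adj_scale, so fold 1 / wei_adj_scale into a local copy of the
+ * output scales kept in the scratchpad. */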
+ if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
+ auto local_scales = scratchpad.template get<float>(
+ key_conv_adjusted_scales);
+ auto scales = pd()->attr()->output_scales_.scales_;
+ size_t count = pd()->attr()->output_scales_.count_;
+ float factor = 1.f / pd()->jcp_.wei_adj_scale;
+ if (count == 1) {
+ utils::array_set(local_scales, scales[0] * factor, 16);
+ } else {
+ for (size_t c = 0; c < count; c++)
+ local_scales[c] = scales[c] * factor;
+ }
+ }
+
+ parallel(kernel_->jcp.nthr, [&](const int ithr, const int nthr) {
+ execute_forward_thr(ithr, nthr, src, weights, bias, dst, scratchpad);
+ });
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type, dst_type>
+::execute_forward_thr(const int ithr, const int nthr, const src_data_t *src,
+ const wei_data_t *weights, const char *bias, dst_data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const {
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const size_t bia_dt_size = pd()->with_bias()
+ ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
+
+ const auto &jcp = kernel_->jcp;
+ auto rtus_space = scratchpad.get<src_data_t>(key_conv_rtus_space);
+ auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);
+
+ const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
+
+ const int stride_h = pd()->desc()->strides[0];
+ const int stride_w = pd()->desc()->strides[1];
+ const int pad_t = pd()->desc()->padding[0][0];
+ const int pad_l = pd()->desc()->padding[0][1];
+
+ const auto &oscales = pd()->attr()->output_scales_;
+
+ int offset = jcp.ngroups * (jcp.oc / jcp.oc_block) * (jcp.ic / jcp.ic_block)
+ * jcp.oc_block * jcp.ic_block;
+ wei_data_t *w = const_cast<wei_data_t *>(weights);
+ int32_t* compensation = (jcp.signed_input)
+ ? reinterpret_cast<int32_t *>(w + offset) : 0;
+
+ auto step = [](int default_step, int remaining, int tail_step) {
+ assert(default_step <= tail_step);
+ return remaining < tail_step ? remaining : default_step;
+ };
+
+ auto p = jit_1x1_conv_call_s();
+
+ auto rp = rtus_driver_t<avx512_common>::call_params_t();
+ const int nb_oc = jcp.nb_load;
+ const int os_block = jcp.bcast_block;
+
+ int bcast_start{0}, bcast_end{0}, ocb_start{0}, ocb_end{0};
+ balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
+ jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
+ jcp.load_grp_count);
+ if (jcp.nb_load_chunk > 1) {
+ ocb_start *= jcp.nb_load_chunk;
+ ocb_end *= jcp.nb_load_chunk;
+ }
+
+ auto init_bcast = [&](int iwork, int &n, int &g, int &bcast_step,
+ int &oh, int &ow, int &ih, int &iw)
+ {
+ int osb{0};
+ nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
+ jcp.nb_bcast);
+ bcast_step = step(jcp.nb_bcast_blocking, jcp.nb_bcast - osb,
+ jcp.nb_bcast_blocking_max);
+ bcast_step = nstl::min(bcast_step, bcast_end - iwork);
+
+ const int os = osb * os_block;
+ oh = os / jcp.ow;
+ ow = os % jcp.ow;
+
+ ih = nstl::max(oh * stride_h - pad_t, 0);
+ iw = nstl::max(ow * stride_w - pad_l, 0);
+ rp.iw_start = iw;
+
+ p.bcast_dim = this_block_size(os, jcp.os,
+ bcast_step * os_block);
+ rp.os = p.bcast_dim;
+ };
+
+ auto init_load = [&](int ocb, int &load_step)
+ {
+ load_step = step(jcp.nb_load_blocking, ocb_end - ocb,
+ jcp.nb_load_blocking_max);
+ p.load_dim = this_block_size(ocb * jcp.oc_block,
+ ocb_end * jcp.oc_block, load_step * jcp.oc_block);
+
+ if (ocb + load_step >= nb_oc)
+ p.first_last_flag |= FLAG_OC_LAST;
+ else
+ p.first_last_flag &= ~FLAG_OC_LAST;
+
+ };
+
+ auto init_reduce = [&]()
+ {
+ p.reduce_dim = this_block_size(0, jcp.ic, jcp.ic);
+ rp.icb = p.reduce_dim / jcp.reduce_block;
+ };
+
+ auto inner_ker = [&](int ocb, int n, int g, int oh, int ow,
+ int ih, int iw)
+ {
+ const int icb = 0; // Start from the first IC block
+ const int _ocb = g * nb_oc + ocb;
+ const int _icb = g;
+
+ const size_t dst_off = dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow);
+
+ p.output_data = &dst[dst_off];
+ p.load_data = &weights[pd()->with_groups()
+ ? weights_d.blk_off(g, ocb, icb)
+ : weights_d.blk_off(ocb, icb)];
+ p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
+ p.compensation = (jcp.signed_input)
+ ? &compensation[_ocb * jcp.oc_block] : 0;
+ p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
+ ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
+ : &oscales.scales_[jcp.is_oc_scale * _ocb * jcp.oc_block];
+ if (pd()->rtus_.reduce_src_) {
+ rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
+ + _icb * jcp.is * jcp.ic_block;
+ if (ocb == ocb_start) {
+ rp.src = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
+ rtus_driver_->ker_(&rp);
+ }
+ p.bcast_data = rp.ws;
+ } else
+ p.bcast_data = src + src_d.blk_off(n, _icb * jcp.ic_block, ih, iw);
+
+ kernel_->jit_ker(&p);
+ };
+
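+ /* The four loop orders differ only in how the reduce (r), load (l)
+ * and bcast (b) loops are nested; all of them drive the same
+ * init_reduce / init_load / init_bcast / inner_ker helpers. */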
+ if (jcp.loop_order == loop_rlb) {
+ init_reduce();
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ inner_ker(ocb, n, g, oh, ow, ih, iw);
+ iwork += bcast_step;
+ }
+ ocb += load_step;
+ }
+ } else if (jcp.loop_order == loop_lbr) {
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ init_reduce();
+ inner_ker(ocb, n, g, oh, ow, ih, iw);
+ iwork += bcast_step;
+ }
+ ocb += load_step;
+ }
+ } else if (jcp.loop_order == loop_rbl) {
+ init_reduce();
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ inner_ker(ocb, n, g, oh, ow, ih, iw);
+ ocb += load_step;
+ }
+ iwork += bcast_step;
+ }
+ } else if (jcp.loop_order == loop_blr) {
+ int iwork = bcast_start;
+ while (iwork < bcast_end) {
+ int n, g, bcast_step, oh, ow, ih, iw;
+ init_bcast(iwork, n, g, bcast_step, oh, ow, ih, iw);
+ int ocb = ocb_start;
+ while (ocb < ocb_end) {
+ int load_step;
+ init_load(ocb, load_step);
+ init_reduce();
+ inner_ker(ocb, n, g, oh, ow, ih, iw);
+ ocb += load_step;
+ }
+ iwork += bcast_step;
+ }
+ } else {
+ assert(!"unsupported loop order");
+ }
+}
+
+using namespace data_type;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
+template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp
new file mode 100644
index 0000000000..ad9027ac17
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_convolution.hpp
@@ -0,0 +1,159 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP
+#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp"
+#include "jit_uni_1x1_conv_utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template<impl::data_type_t src_type, impl::data_type_t dst_type>
+struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t : public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_(), rtus_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_int8_1x1:", avx512_core, ""),
+ jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<
+ src_type, dst_type>);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(src_type, data_type::s8, data_type::undef,
+ dst_type, data_type::s32)
+ && IMPLICATION(with_bias(), utils::one_of(
+ desc()->bias_desc.data_type, data_type::f32,
+ data_type::s32, data_type::s8, data_type::u8))
+ && !has_zero_dim_memory()
+ && set_default_formats_common(dat_tag(), format_tag::any,
+ dat_tag())
+ && set_or_check_wei_format();
+ if (!ok) return status::unimplemented;
+
+ const convolution_desc_t *conv_d = desc();
+ const memory_desc_t *src_d = src_md();
+ rtus_prepare(this, conv_d, src_d, dst_md());
+
+ status_t status = jit_avx512_core_x8s8s32x_1x1_conv_kernel::
+ init_conf(jcp_, *conv_d, *src_d, *weights_md(), *dst_md(),
+ *weights_md(1), *attr(), mkldnn_get_max_threads(),
+ rtus_.reduce_src_);
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad(
+ scratchpad, jcp_, *attr());
+
+ rtus_prepare_space_info(this, scratchpad);
+
+ return status::success;
+ }
+
+ jit_1x1_conv_conf_t jcp_;
+ reduce_to_unit_stride_t rtus_;
+
+ protected:
+ format_tag_t dat_tag() const { return format_tag::nhwc; }
+
+ bool set_or_check_wei_format() {
+ using namespace format_tag;
+
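+ // For s8 sources the weights carry an s8s8 compensation buffer and,
+ // on pre-VNNI hardware, a scale_adjust of 0.5 to keep the
+ // vpmaddubsw-based accumulation from saturating.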
+ const bool is_src_s8 = src_md_.data_type == data_type::s8;
+ format_tag_t wei_tag = with_groups() ? gOIhw4i16o4i : OIhw4i16o4i;
+
+ memory_desc_t want_wei_md = weights_md_;
+ memory_desc_init_by_tag(want_wei_md, wei_tag);
+ if (is_src_s8) {
+ want_wei_md.extra.flags = 0
+ | memory_extra_flags::compensation_conv_s8s8
+ | memory_extra_flags::scale_adjust;
+ want_wei_md.extra.compensation_mask = (1 << 0)
+ + (with_groups() ? (1 << 1) : 0);
+ want_wei_md.extra.scale_adjust =
+ mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
+ }
+
+ if (weights_md_.format_kind == format_kind::any) {
+ weights_md_ = want_wei_md;
+ return true;
+ }
+
+ return weights_md_ == want_wei_md;
+ }
+ };
+
+ template <cpu_isa_t isa, typename conv_t>
+ friend void init_rtus_driver(conv_t *self);
+
+ jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , kernel_(nullptr), rtus_driver_(nullptr)
+ {
+ kernel_ = new jit_avx512_core_x8s8s32x_1x1_conv_kernel(pd()->jcp_,
+ *pd()->attr());
+ init_rtus_driver<avx512_common>(this);
+ }
+
+ ~jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t() {
+ delete kernel_;
+ delete rtus_driver_;
+ }
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<data_type::s8>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+ typedef typename prec_traits<data_type::s32>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+ private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ void execute_forward_thr(const int ithr, const int nthr,
+ const src_data_t *src, const wei_data_t *weights,
+ const char *bias, dst_data_t *dst,
+ const memory_tracking::grantor_t &scratchpad) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx512_core_x8s8s32x_1x1_conv_kernel *kernel_;
+ rtus_driver_t<avx512_common> *rtus_driver_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp
new file mode 100644
index 0000000000..e89d068302
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_1x1_deconvolution.hpp
@@ -0,0 +1,140 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP
+#define CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+#include "type_helpers.hpp"
+#include "primitive_iterator.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_deconvolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_uni_1x1_conv_utils.hpp"
+#include "jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t src_type, impl::data_type_t dst_type>
+struct jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t
+ : public cpu_primitive_t {
+ struct pd_t : public cpu_deconvolution_fwd_pd_t {
+ pd_t(engine_t *engine, const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , conv_pd_(nullptr) {}
+
+ pd_t(const pd_t &other)
+ : cpu_deconvolution_fwd_pd_t(other)
+ , conv_pd_(other.conv_pd_->clone())
+ {}
+
+ ~pd_t() { delete conv_pd_; }
+
+ DECLARE_COMMON_PD_T(conv_pd_->name(),
+ jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t<src_type, dst_type>);
+
+ status_t init_convolution() {
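+ /* 1x1 deconvolution is executed as the equivalent 1x1 forward
+ * convolution: build a convolution descriptor from this
+ * deconvolution descriptor and create its primitive descriptor. */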
+ convolution_desc_t cd;
+ status_t status;
+
+ auto dd = desc();
+ status = conv_desc_init(&cd, prop_kind::forward_training,
+ alg_kind::convolution_direct, &(dd->src_desc),
+ &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc),
+ dd->strides, dd->dilates, dd->padding[0], dd->padding[1],
+ dd->padding_kind);
+
+ if (status == status::success) {
+ status = mkldnn_primitive_desc::create<conv_pd_t>(
+ &conv_pd_, (op_desc_t *)&cd, &attr_, engine_, nullptr);
+ }
+
+ if (status == status::success)
+ status = set_default_params();
+
+ return status;
+ }
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && desc()->alg_kind == alg_kind::deconvolution_direct
+ && !has_zero_dim_memory()
+ && desc()->src_desc.data_type == src_type
+ && desc()->dst_desc.data_type == dst_type
+ && desc()->weights_desc.data_type == data_type::s8
+ && IMPLICATION(with_bias(), utils::one_of(
+ desc()->bias_desc.data_type, data_type::f32,
+ data_type::s32, data_type::s8, data_type::u8))
+ && desc()->accum_data_type == data_type::s32;
+ if (!ok) return status::unimplemented;
+
+ CHECK(init_convolution());
+
+ return status::success;
+ }
+
+ virtual void init_scratchpad_md() override {
+ const auto conv_1x1_pd = static_cast<conv_pd_t *>(conv_pd_);
+ scratchpad_md_ = *conv_1x1_pd->scratchpad_md();
+ }
+
+ protected:
+ status_t set_default_params() {
+ auto conv_1x1_pd_ = static_cast<conv_pd_t *>(conv_pd_);
+ src_md_ = *conv_1x1_pd_->src_md();
+ dst_md_ = *conv_1x1_pd_->dst_md();
+ weights_md_ = *conv_1x1_pd_->weights_md();
+ if (with_bias())
+ bias_md_ = *conv_1x1_pd_->weights_md(1);
+ return status::success;
+ }
+
+ using conv_pd_t = typename jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t
+ <src_type, dst_type>::pd_t;
+ friend jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t;
+ primitive_desc_t *conv_pd_;
+ };
+
+ jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
+
+ ~jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t()
+ { delete conv_p_; }
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ return conv_p_->execute(ctx);
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ primitive_t *conv_p_;
+};
+
+}
+}
+}
+
+#endif /* CPU_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp
new file mode 100644
index 0000000000..10e98a00c4
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.cpp
@@ -0,0 +1,1182 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_memory.hpp"
+
+#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp"
+
+#define GET_OFF(field) offsetof(jit_conv_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+using namespace Xbyak;
+
+namespace {
+void pick_loop_order(jit_conv_conf_t &jcp, int nthr)
+{
+ jcp.loop_order = loop_cwgn;
+ if (jcp.ngroups > 1) {
+ jcp.loop_order = loop_ngcw;
+ if (jcp.mb < nthr)
+ jcp.loop_order = jcp.ndims == 3 ? loop_nwcg : loop_nhwcg;
+ }
+}
+}
+
+template<typename Vmm>
+bool _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::maybe_eltwise(int position)
+{
+ using namespace primitive_kind;
+ const auto &p = attr_.post_ops_;
+
+ if (position == 0) {
+ /* eltwise before sum */
+ return p.contain(eltwise, 0);
+ } else if (position == 1) {
+ /* eltwise after sum */
+ return p.contain(sum, 0) && p.contain(eltwise, 1);
+ }
+
+ return false;
+}
+
+template<typename Vmm>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::prepare_output(int ur_w)
+{
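+ /* Zero every accumulator register and, for signed input, broadcast
+ * the 128 shift constant into vmm_shift (a dword for the plain
+ * depthwise path, a byte otherwise). */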
+ int nb_oc_block
+ = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+ for (int k = 0; k < nb_oc_block; k++)
+ for (int j = 0; j < ur_w; j++) {
+ Vmm vmm = vmm_out(j, k);
+ vpxord(vmm, vmm, vmm);
+ }
+ if (jcp.signed_input) {
+ xor_(reg_scratch, reg_scratch);
+ if (jcp.is_depthwise && !jcp.is_fast_depthwise) {
+ Reg32 _t32 = reg_scratch.cvt32();
+ mov(_t32, (uint32_t)128);
+ vpbroadcastd(vmm_shift, _t32);
+ } else {
+ Reg8 _t8 = reg_scratch.cvt8();
+ mov(_t8, (int8_t)128);
+ vpbroadcastb(vmm_shift, _t8);
+ }
+ }
+}
+
+template<typename Vmm>
+const Vmm _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::
+ vmm_mask(const Vmm vmm_in, bool mask_flag, bool store) {
+ return vmm_in;
+}
+
+template<>
+const Zmm _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::
+ vmm_mask(const Zmm zmm_in, bool mask_flag, bool store) {
+ return mask_flag ? (store ? zmm_in | ktail_mask : zmm_in | ktail_mask | T_z)
+ : zmm_in;
+}
+
+
+template<typename Vmm>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::cvt2ps(data_type_t type_in,
+ const Vmm vmm_in, const Operand &op, bool mask_flag) {
+ //const Vmm vmm = mask_flag ? vmm_in | ktail_mask | T_z : vmm_in;
+ const Vmm vmm = vmm_mask(vmm_in, mask_flag);
+ switch (type_in) {
+ case data_type::f32:
+ case data_type::s32: vmovups(vmm, op); break;
+ case data_type::s8: vpmovsxbd(vmm, op); break;
+ case data_type::u8: vpmovzxbd(vmm, op); break;
+ default: assert(!"unsupported data type");
+ }
+ if (type_in != data_type::f32)
+ vcvtdq2ps(vmm_in, vmm_in);
+}
+
+template<typename Vmm>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_eltwise(int ur_w) {
+ int nb_oc_block
+ = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+ if (ur_w == jcp.ur_w)
+ eltwise_injector_->compute_vector_range(0, nb_oc_block * jcp.ur_w);
+ else
+ for (int k = 0; k < nb_oc_block; k++)
+ eltwise_injector_->compute_vector_range(k * jcp.ur_w,
+ k * jcp.ur_w + ur_w);
+}
+
+template<typename Vmm>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::store_output(
+ int ur_w, bool last_oc_block_flag) {
+ int nb_oc_block
+ = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+ int oc_block = jcp.is_depthwise ? jcp.ch_block : jcp.oc_block;
+
+ mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
+ mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
+ if (jcp.signed_input)
+ mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
+
+ const auto &p = attr_.post_ops_;
+ const int sum_idx = p.find(primitive_kind::sum);
+ const float *p_sum_scale = nullptr;
+ if (sum_idx != -1) {
+ const auto &p_entry = p.entry_[sum_idx];
+ p_sum_scale = &p_entry.sum.scale;
+ }
+
+ if (p_sum_scale && *p_sum_scale != 1.f)
+ mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
+
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ /* put 'wei_adj_scale = 0.5' for bias calculation */
+ mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
+ vmovq(xmm_bias_alpha(), reg_bias_alpha);
+ vbroadcastss(vmm_bias_alpha(), xmm_bias_alpha());
+ }
+
+ for (int k = 0; k < nb_oc_block; k++) {
+ const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
+ int scale_offset = jcp.is_oc_scale * (sizeof(float) * k * oc_block);
+ if (jcp.with_bias) {
+ int bias_offset = jcp.typesize_bia * k * oc_block;
+ auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
+
+ cvt2ps(jcp.bia_dt, vmm_bias, bias_addr, mask_flag);
+ if (jcp.signed_input && jcp.ver != ver_vnni)
+ /* bias *= 0.5 */
+ vmulps(vmm_bias, vmm_bias, vmm_bias_alpha());
+ }
+ if (jcp.signed_input) {
+ int comp_offset = sizeof(int32_t) * k * oc_block;
+ auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
+
+ cvt2ps(data_type::s32, vmm_comp, comp_addr, mask_flag);
+ }
+ /* add to zmm_accum: compensation, bias and permute */
+ for (int j = 0; j < ur_w; j++) {
+ Vmm vmm = vmm_out(j, k);
+ if (jcp.is_fast_depthwise)
+ vpermd(zmm_out(j, k), zmm_permute, zmm_out(j, k));
+ vcvtdq2ps(vmm, vmm);
+ if (jcp.signed_input)
+ vaddps(vmm, vmm, vmm_comp);
+ if (jcp.with_bias)
+ vaddps(vmm, vmm, vmm_bias);
+
+ const Vmm vmm_k = vmm_mask(vmm, mask_flag);
+ vmulps(vmm_k, vmm,
+ EVEX_compress_addr(reg_ptr_scales, scale_offset));
+ }
+ }
+
+ /* Do post-ops */
+ if (maybe_eltwise(0)) compute_eltwise(ur_w);
+ if (p_sum_scale) { // post_op: sum
+ for (int k = 0; k < nb_oc_block; k++) {
+ const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
+ for (int j = 0; j < ur_w; j++) {
+ int aux_output_offset
+ = jcp.typesize_out
+ * (k * oc_block
+ + j * jcp.oc_without_padding * jcp.ngroups);
+ auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
+ Vmm vmm = vmm_out(j, k);
+ cvt2ps(jcp.dst_dt, vmm_prev_dst, addr, mask_flag);
+ if (*p_sum_scale == 1.f)
+ vaddps(vmm, vmm_prev_dst);
+ else
+ vfmadd231ps(vmm, vmm_prev_dst, zword_b[reg_ptr_sum_scale]);
+ }
+ }
+ }
+ if (maybe_eltwise(1)) compute_eltwise(ur_w);
+
+ /* write out register to output_addr */
+ for (int k = 0; k < nb_oc_block; k++) {
+ const bool mask_flag = last_oc_block_flag && k == nb_oc_block - 1;
+ for (int j = 0; j < ur_w; j++) {
+ Vmm vmm = vmm_out(j, k);
+ if (jcp.dst_dt == data_type::u8) {
+ vpxord(vmm_zero, vmm_zero, vmm_zero);
+ vmaxps(vmm, vmm_zero, vmm);
+ }
+
+ if (jcp.dst_dt != data_type::f32) {
+ /* Note: using Zmm for rounding in Xmm/Ymm kernel
+ because there is no instruction to do rounding
+ from Xmm/Ymm -> Xmm/Ymm.
+ Embedded rounding is not supported for Xmm.
+ TODO: maybe avoid Zmm if it helps performance.*/
+ Zmm zmm = zmm_out(j, k);
+ vcvtps2dq(zmm, zmm);
+ }
+ }
+
+ for (int j = 0; j < ur_w; j++) {
+ int aux_output_offset = jcp.typesize_out
+ * (k * oc_block + j * jcp.oc_without_padding * jcp.ngroups);
+ auto addr = EVEX_compress_addr(reg_out, aux_output_offset);
+
+ Vmm vmm = vmm_out(j, k);
+ const Vmm r_vmm = vmm_mask(vmm, mask_flag, true);
+
+ switch (jcp.dst_dt) {
+ case data_type::f32:
+ case data_type::s32: vmovups(addr, r_vmm); break;
+ case data_type::s8: vpmovsdb(addr, r_vmm); break;
+ case data_type::u8: vpmovusdb(addr, r_vmm); break;
+ default: assert(!"unknown dst_dt");
+ }
+ }
+ }
+
+}
+
+template <typename Vmm>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker_dw(
+ int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
+ assert(!"invalid group blocking for depthwise convolution");
+}
+
+template <>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>::compute_ker_dw(
+ int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
+
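+ /* Depthwise variant: each output channel depends on one input
+ * channel. With is_resrc_depthwise the input row is loaded once and
+ * reused across the kw taps; padded positions reuse zmm_shifted_zero
+ * when the input is signed. */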
+ auto input_spatial_index = [=](int oi, int ki) {
+ return (ki * (jcp.dilate_w + 1) + oi * jcp.stride_w - pad_l);
+ };
+
+ auto input_offset2 = [=](int ii, int ci) {
+ return jcp.typesize_in * (ii * jcp.ngroups + ci * jcp.ch_block);
+ };
+
+ auto input_offset3 = [=](int oi, int ci, int ki) {
+ return jcp.typesize_in * input_offset2(input_spatial_index(oi, ki), ci);
+ };
+
+ auto kernel_offset = [=](int ci, int ki) {
+ return jcp.typesize_in * ((ci * jcp.kh * jcp.kw + ki) * jcp.ch_block);
+ };
+
+ auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
+ // okay for depthwise since src is zero-extended
+ if (jcp.ver == ver_vnni) {
+ vpdpbusd(vreg_acc, vreg_src, vreg_wei);
+ } else {
+ vpmaddwd(zmm_tmp, vreg_src, vreg_wei);
+ vpaddd(vreg_acc, vreg_acc, zmm_tmp);
+ }
+ };
+
+ int ii_start = 0;
+ int ii_end = -1;
+ if (jcp.is_resrc_depthwise && !h_padded) {
+ // find bounds of input spatial indices
+ bool first = true;
+ for (int ki = 0; ki < jcp.kw; ki++) {
+ int oi_start = get_ow_start(ki, pad_l);
+ int oi_end = get_ow_end(ur_w, ki, pad_r);
+ for (int oi = oi_start; oi < oi_end; oi++) {
+ int ii = input_spatial_index(oi, ki);
+ if (first || ii < ii_start)
+ ii_start = ii;
+ if (first || ii > ii_end)
+ ii_end = ii;
+ first = false;
+ }
+ }
+ }
+
+ if (jcp.signed_input) {
+ vpxord(zmm_shifted_zero, zmm_shifted_zero, zmm_shifted_zero);
+ vpaddb(zmm_shifted_zero, zmm_shifted_zero, vmm_shift);
+ }
+ for (int ci = 0; ci < jcp.nb_ch_blocking; ci++) {
+ const bool mask_flag = last_ic_block_flag != no_last_block
+ && ci == jcp.nb_ch_blocking - 1;
+ if (jcp.is_resrc_depthwise && !h_padded) {
+ // now we can load input once and reuse up to jcp.kw times
+ for (int ii = ii_start; ii <= ii_end; ii++) {
+ int aux_input_offset = input_offset2(ii, ci);
+ const Zmm zmm_inp_tmp = zmm_inp(ii, jcp.nb_ch_blocking);
+ const Zmm zmm_inp_msk = mask_flag
+ ? zmm_inp_tmp | ktail_mask | T_z
+ : zmm_inp_tmp;
+ if (jcp.is_fast_depthwise) {
+ assert(!mask_flag);
+ vbroadcasti32x4(zmm_inp_msk,
+ EVEX_compress_addr(aux_reg_inp, aux_input_offset));
+ } else {
+ vpmovzxbd(zmm_inp_msk,
+ EVEX_compress_addr(aux_reg_inp, aux_input_offset));
+ }
+ if (jcp.signed_input)
+ vpaddb(zmm_inp_tmp, zmm_inp_tmp, vmm_shift);
+ }
+ }
+ for (int ki = 0; ki < jcp.kw; ki++) {
+ int aux_kernel_offset = kernel_offset(ci, ki);
+ if (jcp.is_fast_depthwise) {
+ vbroadcasti32x4(zmm_wei,
+ EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
+ vmovdqu8(zmm_wei | kblend_mask | T_z, zmm_wei);
+ } else {
+ vpmovsxbd(zmm_wei,
+ EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
+ }
+ if (h_padded) {
+ assert(jcp.signed_input);
+ for (int oi = 0; oi < ur_w; oi++)
+ compute(zmm_out(oi, ci), zmm_wei, zmm_shifted_zero);
+ } else {
+ const Zmm r_zmm_src = mask_flag ? zmm_src | ktail_mask : zmm_src;
+ int oi_start = get_ow_start(ki, pad_l);
+ int oi_end = get_ow_end(ur_w, ki, pad_r);
+ int start_ = jcp.signed_input ? 0 : oi_start;
+ int end_ = jcp.signed_input ? ur_w : oi_end;
+ for (int oi = start_; oi < end_; oi++) {
+ if (oi >= oi_start && oi < oi_end) {
+ if (jcp.is_resrc_depthwise) {
+ int ii = input_spatial_index(oi, ki);
+ zmm_src = zmm_inp(ii, jcp.nb_ch_blocking);
+ } else {
+ int aux_input_offset = input_offset3(oi, ci, ki);
+ if (jcp.is_fast_depthwise) {
+ assert(!mask_flag);
+ vbroadcasti32x4(r_zmm_src,
+ EVEX_compress_addr(aux_reg_inp,
+ aux_input_offset));
+ } else {
+ vpmovzxbd(r_zmm_src,
+ EVEX_compress_addr(aux_reg_inp,
+ aux_input_offset));
+ }
+ if (jcp.signed_input)
+ vpaddb(zmm_src, zmm_src, vmm_shift);
+ }
+ } else if (jcp.signed_input) {
+ zmm_src = zmm_shifted_zero;
+ }
+ compute(zmm_out(oi, ci), zmm_wei, zmm_src);
+ }
+ }
+ }
+ }
+}
+
+template<typename Vmm>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::compute_ker(int ur_w, int pad_l,
+ int pad_r, ic_block_t last_ic_block_flag, bool h_padded) {
+ if (jcp.is_depthwise)
+ return compute_ker_dw(ur_w, pad_l, pad_r, last_ic_block_flag, h_padded);
+
+ int kw = jcp.kw;
+ int stride_w = jcp.stride_w;
+ int ic_block = jcp.ic_block;
+ int oc_block = jcp.oc_block;
+ int ch_block_all = jcp.ch_block * ic_block * oc_block;
+
+ int nb_oc_block = jcp.nb_oc_blocking;
+
+ auto input_offset = [=](int oi, int ic, int ki) {
+ return jcp.typesize_in
+ * ((ki * (jcp.dilate_w + 1) + oi * stride_w - pad_l)
+ * jcp.ic_without_padding * jcp.ngroups + 4 * ic);
+ };
+ auto kernel_offset = [=](int ii, int ic, int ki) {
+ return jcp.typesize_in
+ * ((ii * jcp.nb_ic * jcp.kh * jcp.kw + ki) * ch_block_all
+ + 4 * ic * oc_block);
+ };
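+    // Accumulation step: on VNNI targets a single vpdpbusd performs the
+    // u8*s8 dot-product; otherwise it is emulated with vpmaddubsw followed
+    // by vpmaddwd against a vector of ones and a 32-bit add.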
+ auto compute = [=](Vmm vreg_acc, Vmm vreg_wei, Vmm vreg_src) {
+ if (jcp.ver == ver_vnni) {
+ vpdpbusd(vreg_acc, vreg_src, vreg_wei);
+ } else {
+ vpmaddubsw(vmm_tmp, vreg_src, vreg_wei);
+ vpmaddwd(vmm_tmp, vmm_tmp, vmm_one);
+ vpaddd(vreg_acc, vreg_acc, vmm_tmp);
+ }
+ };
+
+ for (int ki = 0; ki < kw; ki++) {
+ int jj_start = get_ow_start(ki, pad_l);
+ int jj_end = get_ow_end(ur_w, ki, pad_r);
+ int tail_size = jcp.ic_without_padding % 4;
+ int _start = (jcp.signed_input) ? 0 : jj_start;
+ int _end = (jcp.signed_input) ? ur_w : jj_end;
+ /* Skip the last loads of input if (ic%16)/4 < ic_block/4 */
+ int icb = (last_ic_block_flag != no_last_block)
+ ? div_up((jcp.ic_without_padding % ic_block), 4)
+ : ic_block / 4;
+ for (int ic = 0; ic < icb; ic++) {
+ if (h_padded == true) {
+ /* fill padded area with shifted values */
+ Vmm inp = vmm_inp(0,nb_oc_block);
+ vpxord(inp, inp, inp);
+ vpaddb(inp, inp, vmm_shift);
+ } else {
+ for (int jj = _start; jj < _end; jj++) {
+ int aux_input_offset = input_offset(jj, ic, ki);
+ if (jj >= jj_start && jj < jj_end) {
+ if (last_ic_block_flag == last_sp_block
+ && tail_size != 0 && ic == icb - 1) {
+ Xmm xmm_tmp = Xmm(vmm_inp(jj, nb_oc_block).getIdx());
+ for (int r = 0; r < tail_size; ++r)
+ vpinsrb(xmm_tmp, xmm_tmp,
+ ptr[aux_reg_inp + aux_input_offset + r], r);
+ vpbroadcastd(vmm_inp(jj, nb_oc_block), xmm_tmp);
+ } else {
+ vpbroadcastd(vmm_inp(jj, nb_oc_block),
+ EVEX_compress_addr(
+ aux_reg_inp, aux_input_offset));
+ }
+ if (jcp.signed_input)
+ vpaddb(vmm_inp(jj, nb_oc_block),
+ vmm_inp(jj, nb_oc_block), vmm_shift);
+ } else {
+ /* fill padded area with shifted values */
+ if (jcp.signed_input) {
+ Vmm inp = vmm_inp(jj, nb_oc_block);
+ vpxord(inp, inp, inp);
+ vpaddb(inp, inp, vmm_shift);
+ }
+ }
+ }
+ }
+ for (int ii = 0; ii < nb_oc_block; ii++) {
+ int aux_kernel_offset = kernel_offset(ii, ic, ki);
+ vmovups(vmm_wei,
+ EVEX_compress_addr(aux_reg_ker, aux_kernel_offset));
+ for (int jj = _start; jj < _end; jj++) {
+ Vmm inp = (h_padded == true)
+ ? vmm_inp(0,nb_oc_block) : vmm_inp(jj, nb_oc_block);
+ compute(vmm_out(jj, ii), vmm_wei, inp);
+ }
+ }
+ }
+ }
+}
+
+template<typename Vmm>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::kh_loop(
+ int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag) {
+ Label kh_label, skip_kh_loop;
+ Label t_overflow_label, no_t_overflow_label,
+ b_overflow_label, no_b_overflow_label;
+
+ int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
+ int shift_kernel_ptr = jcp.typesize_in * jcp.kw * ch_block_all;
+ int shift_input_ptr = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
+ * jcp.ic_without_padding * jcp.ngroups;
+
+ mov(aux_reg_inp, reg_inp);
+ mov(aux_reg_ker, reg_ker);
+
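+    // With signed input, kernel rows that fall entirely into the top/bottom
+    // padding still iterate over the weights so the input-shift compensation
+    // is applied consistently (h_padded = true feeds shifted zeros).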
+ if (jcp.signed_input && jcp.ndims > 3) {
+ mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
+ cmp(reg_overflow, 0);
+ je(no_t_overflow_label, T_NEAR);
+ L(t_overflow_label); {
+ compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
+
+ add(aux_reg_ker, shift_kernel_ptr);
+ dec(reg_overflow);
+ cmp(reg_overflow, 0);
+ jg(t_overflow_label, T_NEAR);
+ }
+ L(no_t_overflow_label);
+ }
+ mov(reg_kj, ptr[param1 + GET_OFF(kh_padding)]);
+ if ((jcp.signed_input) || (!jcp.signed_input &&
+ (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
+ cmp(reg_kj, 0);
+ je(skip_kh_loop, T_NEAR);
+ }
+ L(kh_label); {
+ compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, false);
+
+ add(aux_reg_ker, shift_kernel_ptr);
+ add(aux_reg_inp, shift_input_ptr);
+ dec(reg_kj);
+ cmp(reg_kj, 0);
+ jg(kh_label, T_NEAR);
+ }
+ L(skip_kh_loop);
+ if (jcp.signed_input && jcp.ndims > 3) {
+ mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
+ cmp(reg_overflow, 0);
+ je(no_b_overflow_label, T_NEAR);
+ L(b_overflow_label); {
+ compute_ker(ur_w, pad_l, pad_r, last_ic_block_flag, true);
+
+ add(aux_reg_ker, shift_kernel_ptr);
+ dec(reg_overflow);
+ cmp(reg_overflow, 0);
+ jg(b_overflow_label, T_NEAR);
+ }
+ L(no_b_overflow_label);
+ }
+}
+
+template<typename Vmm>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::icb_loop(
+ int ur_w, int pad_l, int pad_r, bool is_last_sp_block)
+{
+ prepare_output(ur_w);
+
+ // IC loop
+ Label icb_label;
+ mov(reg_icb, jcp.nb_ic);
+ L(icb_label);
+ if (jcp.ngroups % jcp.ch_block != 0 || jcp.ic_without_padding != jcp.ic) {
+ Label common_ker, end_ker;
+
+ cmp(reg_icb, 1); // The last IC block
+ jne(common_ker, T_NEAR);
+
+ kh_loop(ur_w, pad_l, pad_r,
+ is_last_sp_block ? last_sp_block : last_ic_block);
+ jmp(end_ker, T_NEAR);
+
+ L(common_ker);
+ kh_loop(ur_w, pad_l, pad_r, no_last_block);
+
+ L(end_ker);
+ } else {
+ kh_loop(ur_w, pad_l, pad_r, no_last_block);
+ }
+ // End of IC Loop
+ int inp_step = jcp.ic_block;
+ int ker_step = jcp.kh * jcp.kw * jcp.oc_block * jcp.ic_block;
+ add(reg_inp, jcp.typesize_in * inp_step);
+ add(reg_ker, jcp.typesize_in * ker_step);
+
+ dec(reg_icb);
+ cmp(reg_icb, 0);
+ jg(icb_label, T_NEAR);
+
+ sub(reg_inp, jcp.typesize_in * inp_step * jcp.nb_ic);
+ sub(reg_ker, jcp.typesize_in * ker_step * jcp.nb_ic);
+
+ if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
+ Label common_store, end_store;
+
+ if (jcp.is_depthwise)
+ cmp(reg_oc_blocks, jcp.nb_ch - jcp.nb_ch_blocking);
+ else
+ cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
+
+ jne(common_store, T_NEAR);
+
+ store_output(ur_w, true); // last oc block
+ jmp(end_store, T_NEAR);
+
+ L(common_store);
+ store_output(ur_w, false);
+
+ L(end_store);
+ } else {
+ store_output(ur_w, false);
+ }
+}
+
+template<typename Vmm>
+void _jit_avx512_core_x8s8s32x_fwd_kernel<Vmm>::generate()
+{
+ Label permute_index_table;
+ int inp_shift_pad = jcp.typesize_in * (jcp.ur_w * jcp.stride_w - jcp.l_pad)
+ * jcp.ic_without_padding * jcp.ngroups;
+ int inp_shift_pad_second_block = -1 * jcp.typesize_in * jcp.l_pad
+ * jcp.ic_without_padding * jcp.ngroups;
+ int inp_shift = jcp.typesize_in *
+ (jcp.ur_w * jcp.stride_w * jcp.ic_without_padding
+ * jcp.ngroups);
+ int out_shift = jcp.typesize_out *
+ (jcp.ur_w * jcp.oc_without_padding * jcp.ngroups);
+ preamble();
+
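+    // Depthwise kernels place their extra registers (source copy, temporary,
+    // permute indices, signed-input shifted zero) right above the ur_w
+    // unrolling range, depending on the configuration.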
+ if (jcp.is_depthwise) {
+ int idx = jcp.max_regs_ur - 1;
+ if (!jcp.is_resrc_depthwise)
+ zmm_src = Zmm(++idx);
+ if (jcp.ver != ver_vnni)
+ zmm_tmp = Zmm(++idx);
+ if (jcp.is_fast_depthwise)
+ zmm_permute = Zmm(++idx);
+ if (jcp.signed_input) {
+ zmm_shifted_zero = Zmm(++idx);
+ ++idx; // due to extra register used for shifts and compensations
+ }
+ assert(idx == ker_dw_reg_base_idx);
+ }
+
+ if (!jcp.is_depthwise && jcp.ver != ver_vnni) {
+ xor_(reg_scratch, reg_scratch);
+ Reg16 _t16 = reg_scratch.cvt16();
+ mov(_t16, 0x1);
+ vpbroadcastw(vmm_one, _t16);
+ }
+
+ mov(reg_inp, ptr[param1 + GET_OFF(src)]);
+ mov(reg_out, ptr[param1 + GET_OFF(dst)]);
+ mov(reg_ker, ptr[param1 + GET_OFF(filt)]);
+
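+    // Build an opmask covering only the channel / output-channel tail so
+    // partial blocks can be handled with masked loads and stores.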
+ if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
+ int tail_size = jcp.is_depthwise
+ ? jcp.ngroups % jcp.ch_block
+ : jcp.oc_without_padding % jcp.oc_block;
+ int mask = (1 << tail_size) - 1;
+ mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
+ Reg32 regw_tmp = reg_oi.cvt32();
+ mov(regw_tmp, mask);
+ kmovw(ktail_mask, regw_tmp);
+ }
+ if (jcp.is_fast_depthwise) {
+ // prepare mask register for blending weights
+ mov(reg_scratch, 0x8888444422221111);
+ kmovq(kblend_mask, reg_scratch);
+ // load permute indices from data section
+ mov(reg_scratch, permute_index_table);
+ vmovdqu32(zmm_permute, ptr[reg_scratch]);
+ }
+
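+    // Right padding over the full output row (r_pad) and for the row up to
+    // the last full ur_w block (r_pad1); these decide which icb_loop calls
+    // need padded tails.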
+ int r_pad = nstl::max(0, (jcp.ow - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1)
+ - (jcp.iw + jcp.l_pad - 1));
+ int n_oi = jcp.ow / jcp.ur_w;
+ int r_pad1 = (jcp.ur_w * n_oi - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1);
+
+ if (jcp.nb_ow == 1) {
+ if (r_pad1 > 0 || jcp.ur_w_tail == 0)
+ n_oi--;
+
+ xor_(reg_oi, reg_oi);
+ if (jcp.ow == jcp.ur_w) {
+ icb_loop(jcp.ur_w, jcp.l_pad, r_pad, true);
+ } else {
+ if (n_oi == 0) {
+ icb_loop(jcp.ur_w, jcp.l_pad, r_pad1, jcp.ur_w_tail == 0);
+ add(reg_inp, inp_shift_pad);
+ add(reg_out, out_shift);
+ if (jcp.ur_w_tail != 0) {
+ icb_loop(jcp.ur_w_tail, 0, r_pad, true);
+ }
+ } else {
+ if (jcp.l_pad > 0) {
+ icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
+ add(reg_inp, inp_shift_pad);
+ add(reg_out, out_shift);
+
+ inc(reg_oi);
+ }
+ if ((jcp.l_pad <= 0 && n_oi > 0) || (jcp.l_pad > 0 && n_oi > 1))
+ {
+ Label ow_loop_label;
+ L(ow_loop_label); {
+ icb_loop(jcp.ur_w, 0, 0, false);
+ add(reg_inp, inp_shift);
+ add(reg_out, out_shift);
+
+ inc(reg_oi);
+ cmp(reg_oi, n_oi);
+ jl(ow_loop_label, T_NEAR);
+ }
+ }
+ if (r_pad1 > 0 || jcp.ur_w_tail == 0) {
+ icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
+ add(reg_inp, inp_shift);
+ add(reg_out, out_shift);
+ }
+ if (jcp.ur_w_tail != 0) {
+ icb_loop(jcp.ur_w_tail, 0, r_pad, true);
+ }
+ }
+ }
+ } else {
+        // Only a single ow block is processed per kernel invocation.
+        // The block index is passed in via the owb parameter,
+        // and padding handling depends on it.
+ Label end_label, last_oi_label, middle_ow_blocks_label, tail_label,
+ oi_loop_label, oi_loop_end_label;
+
+ assert(jcp.ow_block % jcp.ur_w == 0);
+ int n_oi_not_last_ow_block = jcp.ow_block / jcp.ur_w;
+        // to simplify the code (and general-purpose register usage),
+        // the size of an ow block must be >= 2 * ur_w
+ assert(n_oi_not_last_ow_block > 1);
+ int n_oi_next_last_ow_block = n_oi_not_last_ow_block;
+ int n_oi_first_ow_block = n_oi_not_last_ow_block;
+ int n_oi_last_ow_block
+ = (jcp.ow - jcp.ow_block * (jcp.nb_ow - 1)) / jcp.ur_w;
+ // prepare right padding
+ bool next_last_ow_block_padded = r_pad1 > 0 && n_oi_last_ow_block == 0;
+ bool first_ow_block_padded
+ = next_last_ow_block_padded && jcp.nb_ow == 2;
+ bool last_ow_block_padded
+ = (r_pad1 > 0 || jcp.ur_w_tail == 0) && n_oi_last_ow_block > 0;
+
+ if (last_ow_block_padded) n_oi_last_ow_block--;
+ else if (first_ow_block_padded) n_oi_first_ow_block--;
+ else if (next_last_ow_block_padded) n_oi_next_last_ow_block--;
+
+ mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
+ cmp(reg_owb, 0); // is that the first ow-block ?
+ jg(middle_ow_blocks_label, T_NEAR);
+
+ // the first ow block, compute left padding
+ mov(reg_oi, n_oi_first_ow_block);
+ if (jcp.l_pad > 0) {
+ icb_loop(jcp.ur_w, jcp.l_pad, 0, false);
+ add(reg_inp, inp_shift_pad);
+ add(reg_out, out_shift);
+
+ dec(reg_oi);
+ }
+ jmp(oi_loop_label, T_NEAR);
+
+ // middle or last ow block entry
+ L(middle_ow_blocks_label);
+
+ if (jcp.l_pad > 0) {
+            // only account for the left padding in the input pointer,
+            // no computation happens here
+ add(reg_inp, inp_shift_pad_second_block);
+ }
+
+ // set number of iteration for oi-loop
+ if (n_oi_last_ow_block != n_oi_not_last_ow_block) {
+ cmp(reg_owb, jcp.nb_ow - 1); // last ow-block ?
+ mov(reg_oi, n_oi_last_ow_block);
+ je(oi_loop_label, T_NEAR);
+ }
+
+ if (n_oi_next_last_ow_block != n_oi_not_last_ow_block) {
+ cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
+
+ mov(reg_oi, n_oi_next_last_ow_block);
+ je(oi_loop_label, T_NEAR);
+ }
+ mov(reg_oi, n_oi_not_last_ow_block); // other middle ow-blocks
+
+ // oi loop w/o padding
+ L(oi_loop_label); {
+ cmp(reg_oi, 0);
+ jle(oi_loop_end_label, T_NEAR);
+
+ icb_loop(jcp.ur_w, 0, 0, false);
+
+ add(reg_inp, inp_shift);
+ add(reg_out, out_shift);
+ dec(reg_oi);
+
+ jmp(oi_loop_label, T_NEAR);
+ }
+ L(oi_loop_end_label);
+
+ mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
+ cmp(reg_owb, 0); // first ow-block ?
+ if (first_ow_block_padded)
+ je(last_oi_label, T_NEAR);
+ else
+ je(end_label, T_NEAR);
+
+ cmp(reg_owb, jcp.nb_ow - 2); // next to last ow-block ?
+ jl(end_label, T_NEAR);
+ if (next_last_ow_block_padded)
+ je(last_oi_label, T_NEAR);
+ else
+ je(end_label, T_NEAR);
+
+ // that is last block
+ if (!last_ow_block_padded)
+ jmp(tail_label, T_NEAR);
+
+ // last oi block with right padding
+ L(last_oi_label);
+ icb_loop(jcp.ur_w, 0, r_pad1, jcp.ur_w_tail == 0);
+ add(reg_inp, inp_shift);
+ add(reg_out, out_shift);
+
+ mov(reg_owb, ptr[param1 + GET_OFF(owb)]);
+ cmp(reg_owb, jcp.nb_ow - 1); // last ow_block?
+ jl(end_label, T_NEAR);
+
+ // ur_w tail
+ L(tail_label);
+ if (jcp.ur_w_tail != 0) {
+ icb_loop(jcp.ur_w_tail, 0, r_pad, true);
+ }
+ L(end_label);
+ }
+ postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+
+ if (jcp.is_fast_depthwise) {
+ align(64);
+ L(permute_index_table);
+ const uint32_t _idx[]
+ = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 };
+ for (size_t i = 0; i < sizeof(_idx) / sizeof(_idx[0]); ++i)
+ dd(_idx[i]);
+ }
+}
+
+bool jit_avx512_core_x8s8s32x_fwd_kernel::post_ops_ok(
+ jit_conv_conf_t &jcp, const primitive_attr_t &attr)
+{
+ using namespace primitive_kind;
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+
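+    // Accept at most one sum and one eltwise post-op, in either order.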
+ switch (p.len_) {
+ case 0: return true;
+ case 1: return is_eltwise(0) || p.contain(sum, 0);
+ case 2: return (p.contain(sum, 0) && is_eltwise(1)) ||
+ (p.contain(sum, 1) && is_eltwise(0));
+ default: return false;
+ }
+
+ return false;
+}
+
+status_t jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, memory_desc_t &src_md,
+ memory_desc_t &weights_md, memory_desc_t &dst_md,
+ memory_desc_t &bias_md, const primitive_attr_t &attr,
+ int nthreads)
+{
+ using namespace prop_kind;
+
+ const memory_desc_wrapper src_d(&src_md);
+ const memory_desc_wrapper weights_d(&weights_md);
+ const memory_desc_wrapper dst_d(&dst_md);
+ const memory_desc_wrapper bias_d(&bias_md);
+
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ int ndims = src_d.ndims();
+ bool is_1d = ndims == 3;
+
+ if (!(mayiuse(avx512_core)
+ && one_of(src_d.data_type(), data_type::u8, data_type::s8)
+ && weights_d.data_type() == data_type::s8
+ && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
+ data_type::s8, data_type::u8)))
+ return status::unimplemented;
+
+ jcp = zero<decltype(jcp)>();
+ jcp.ndims = ndims;
+ jcp.prop_kind = cd.prop_kind;
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.ic_without_padding = jcp.ic;
+ jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
+ jcp.iw = src_d.dims()[ndims - 1];
+ jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
+ jcp.ow = dst_d.dims()[ndims - 1];
+ jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
+ jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+ jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
+ jcp.l_pad = cd.padding[0][ndims - 3];
+ jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
+ jcp.stride_w = cd.strides[ndims - 3];
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ jcp.ur_h = 1; /* no code-unrolling by h so far */
+
+ jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
+ jcp.dilate_w = cd.dilates[ndims - 3];
+
+ jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
+ jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.ic, jcp.oc);
+
+ if (jcp.is_depthwise) {
+ jcp.ch_block = 16;
+ jcp.ic_block = 1;
+ jcp.oc_block = 1;
+ } else {
+ jcp.ch_block = 1;
+ jcp.ic_block = 16;
+ jcp.oc_block = 16;
+
+ if (jcp.ngroups == 1) {
+ /* For non grouped convolutions, pad channels by 16 if needed */
+ jcp.oc = rnd_up(jcp.oc, jcp.oc_block);
+ jcp.ic = rnd_up(jcp.ic, jcp.ic_block);
+ } else if (!is_1d && jcp.ngroups != 1 && jcp.ic % jcp.ic_block != 0) {
+            /* For grouped convolutions, MKL-DNN doesn't support channel
+               padding. Use Ymm when the number of channels per group is a
+               multiple of 8, and Xmm when it is a multiple of 4. */
+ jcp.ic_block = jcp.ic % 8 == 0 ? 8 : 4;
+ jcp.oc_block = jcp.ic_block;
+ }
+ if (jcp.ic % jcp.ic_block !=0 || jcp.oc % jcp.oc_block != 0)
+ return status::unimplemented;
+ }
+
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+ - (jcp.ih + jcp.t_pad - 1);
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise)
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+
+ jcp.ver = mayiuse(avx512_core_vnni) ? ver_vnni : ver_avx512_core;
+    // the fast depthwise path requires the group count to be a multiple of
+    // 16; otherwise loads from src would need byte masking
+    jcp.is_fast_depthwise = true && jcp.is_depthwise && jcp.ver == ver_vnni
+            && jcp.ngroups % jcp.ch_block == 0;
+ jcp.is_resrc_depthwise = jcp.is_depthwise && jcp.stride_w < jcp.kw
+ && jcp.kw < 4 && jcp.dilate_w == 0;
+ if (jcp.is_depthwise) {
+ jcp.max_regs_ur = 31 - jcp.is_fast_depthwise - !jcp.is_resrc_depthwise
+ - 2 * jcp.signed_input - (jcp.ver != ver_vnni);
+ } else {
+ jcp.max_regs_ur = jcp.ver == ver_vnni ? 31 : 28;
+ }
+
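+    // Pick (or verify) the blocked int8 weights layout; for signed input the
+    // descriptor also carries the s8s8 compensation flag and, on pre-VNNI
+    // hardware, a 0.5 scale adjustment.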
+ auto set_or_check_wei_format = [&]() {
+ using namespace format_tag;
+ format_tag_t wei_tag;
+ if (jcp.ic_block == 16 || jcp.ch_block == 16) {
+ if (is_1d) {
+ wei_tag = with_groups
+ ? jcp.is_depthwise ? Goiw16g : gOIw4i16o4i
+ : OIw4i16o4i;
+ } else {
+ wei_tag = with_groups
+ ? jcp.is_depthwise ? Goihw16g : gOIhw4i16o4i
+ : OIhw4i16o4i;
+ }
+ } else if (with_groups && jcp.ic_block == 8) {
+ wei_tag = gOIhw2i8o4i;
+ } else
+ wei_tag = gOIhw4o4i;
+
+ memory_desc_t want_wei_md = weights_md;
+ memory_desc_init_by_tag(want_wei_md, wei_tag);
+ if (jcp.signed_input) {
+ want_wei_md.extra.flags = 0
+ | memory_extra_flags::compensation_conv_s8s8
+ | memory_extra_flags::scale_adjust;
+ want_wei_md.extra.compensation_mask = (1 << 0)
+ + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
+ want_wei_md.extra.scale_adjust =
+ mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
+ }
+
+ if (weights_md.format_kind == format_kind::any) {
+ weights_md = want_wei_md;
+ return true;
+ }
+
+ return weights_md == want_wei_md;
+ };
+
+ if (!set_or_check_wei_format())
+ return status::unimplemented;
+
+ format_tag_t dat_tag = utils::pick(ndims - 3,
+ format_tag::nwc, format_tag::nhwc);
+
+ if (src_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(src_md, dat_tag));
+ jcp.src_tag = dat_tag;
+ } else {
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ }
+ if (jcp.src_tag != dat_tag)
+ return status::unimplemented;
+
+ if (dst_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
+ jcp.dst_tag = dat_tag;
+ } else {
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+ }
+ if (jcp.dst_tag != dat_tag)
+ return status::unimplemented;
+
+ if (jcp.with_bias) {
+ if (bias_d.format_kind() == format_kind::any)
+ CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
+ }
+
+ jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+ jcp.dst_dt = cd.dst_desc.data_type;
+
+ jcp.typesize_in = types::data_type_size(src_d.data_type());
+ jcp.typesize_out = types::data_type_size(dst_d.data_type());
+ jcp.typesize_bia = jcp.with_bias
+ ? types::data_type_size(bias_d.data_type())
+ : 0;
+
+ jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+
+ // Try to use 4 channel-groups at a time to avoid false sharing (depthwise)
+ int nb_ch_blocking = 4;
+ for ( /* init above */ ; nb_ch_blocking > 1; nb_ch_blocking--)
+ if (jcp.nb_ch % nb_ch_blocking == 0)
+ break;
+ jcp.nb_ch_blocking = jcp.is_depthwise ? nb_ch_blocking : 1;
+
+ // If OC blocking is incommensurate with the number of OC blocks (general
+ // requirement for all convolutions), or if it results in an unrolling
+ // factor smaller than the left padding (special requirement for SSD:fc6),
+ // then search for a smaller OC blocking that satisfies both constraints.
+ auto is_oc_blocking_ok = [&](int block) {
+ int ur_w = nstl::min(jcp.ow, jcp.max_regs_ur / (block + 1));
+ return jcp.nb_oc % block == 0
+ && jcp.l_pad <= ur_w && jcp.ow % ur_w != 1;
+ };
+
+ // choose nb_oc work chunk size for distribution within threads
+ int max_threading_nb_oc_chunk = 4;
+    // Performance improvements for googlenet_v3 and resnet_50 with mb = 1;
+    // TODO: generalize this condition and rewrite it in an appropriate manner
+ if (jcp.ver == ver_vnni && jcp.mb == 1 && jcp.kh == 3 && jcp.kw == 3
+ && jcp.stride_w == 1 && jcp.ic % 64 == 0)
+ max_threading_nb_oc_chunk = 2;
+ jcp.nb_oc_blocking_thr_chunk =
+ nstl::min(max_threading_nb_oc_chunk, jcp.nb_oc);
+ for (; jcp.nb_oc_blocking_thr_chunk > 1; jcp.nb_oc_blocking_thr_chunk--) {
+ if (is_oc_blocking_ok(jcp.nb_oc_blocking_thr_chunk))
+ break;
+ }
+
+ // choose oc blocking for computational kernel
+ jcp.nb_oc_blocking = jcp.nb_oc_blocking_thr_chunk;
+    // Performance improvements for googlenet_v3 with mb = 1;
+    // TODO: generalize this condition and rewrite it in an appropriate manner
+    const int size_threshold_for_nb_oc_blocking_reduction = 17;
+    if (jcp.mb == 1 && jcp.ow <= size_threshold_for_nb_oc_blocking_reduction
+ && jcp.stride_w == 1
+ && !(jcp.kh == 1 && jcp.kw == 3)
+ && !(jcp.kh >= 7 && jcp.oc % 64 == 0)) {
+ const int max_nb_oc_blocking = 2;
+ jcp.nb_oc_blocking = nstl::min(max_nb_oc_blocking, jcp.nb_oc);
+ for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
+ if (jcp.nb_oc_blocking_thr_chunk % jcp.nb_oc_blocking == 0
+ && is_oc_blocking_ok(jcp.nb_oc_blocking))
+ break;
+ }
+
+ if (jcp.is_resrc_depthwise)
+ jcp.ur_w = (jcp.max_regs_ur - jcp.kw + jcp.stride_w)
+ / (jcp.nb_ch_blocking + jcp.stride_w);
+ else
+ jcp.ur_w
+ = jcp.max_regs_ur / (jcp.is_depthwise ? jcp.nb_ch_blocking
+ : jcp.nb_oc_blocking + 1);
+ if (jcp.ow < jcp.ur_w)
+ jcp.ur_w = jcp.ow;
+ jcp.ur_w_tail = jcp.ow % jcp.ur_w;
+
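+    // Search for an ow block size that improves load balance across threads;
+    // a smaller block is accepted only if it is at least 2 * ur_w wide and
+    // raises the estimated efficiency by more than 10%.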
+ jcp.ow_block = jcp.ow;
+ int base_work_amount = jcp.mb * jcp.nb_ch * jcp.oh
+ * (jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk);
+ float best_thr_eff
+ = (float)base_work_amount / rnd_up(base_work_amount, nthreads);
+ int max_nb_ow = div_up(jcp.ow, 2 * jcp.ur_w);
+ for (int nb_ow = 1; nb_ow <= max_nb_ow; nb_ow++) {
+ int ow_block
+ = nstl::min(rnd_up(div_up(jcp.ow, nb_ow), jcp.ur_w), jcp.ow);
+ if (ow_block < jcp.nb_oc_blocking_thr_chunk * jcp.oc_block
+ && best_thr_eff > 0.8f)
+ break;
+ if (div_up(jcp.ow, ow_block) != nb_ow)
+ continue;
+ auto work_amount = base_work_amount * nb_ow;
+ float thr_eff = (float)work_amount / rnd_up(work_amount, nthreads);
+ if (ow_block >= 2 * jcp.ur_w && thr_eff > 1.1f * best_thr_eff) {
+ jcp.ow_block = ow_block;
+ best_thr_eff = thr_eff;
+ }
+ if (best_thr_eff > 0.9f)
+ break;
+ }
+ jcp.nb_ow = div_up(jcp.ow, jcp.ow_block);
+
+ bool args_ok = true
+ && jcp.oc % jcp.oc_block == 0
+ && jcp.l_pad <= jcp.ur_w
+ && IMPLICATION(!jcp.is_1stconv, jcp.ic % jcp.ic_block == 0);
+ if (!args_ok)
+ return status::unimplemented;
+
+ int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1)
+ - (jcp.iw + jcp.l_pad - 1));
+ if (r_pad_no_tail > jcp.ur_w)
+ return status::unimplemented;
+
+ pick_loop_order(jcp, nthreads);
+
+ jcp.nb_ic_L2 = jcp.nb_ic;
+
+ const auto &oscales = attr.output_scales_;
+ jcp.is_oc_scale = oscales.mask_ == 1 << 1;
+
+ assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
+
+    jcp.wei_adj_scale =
+        (weights_d.extra().flags & memory_extra_flags::scale_adjust)
+        ? weights_d.extra().scale_adjust : 1.f;
+
+ return status::success;
+}
+
+void jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr) {
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ dim_t count = nstl::max(attr.output_scales_.count_, (dim_t)jcp.ic_block);
+ scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
+ }
+}
+
+template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Zmm>;
+template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Ymm>;
+template struct _jit_avx512_core_x8s8s32x_fwd_kernel<Xmm>;
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp
new file mode 100644
index 0000000000..d8a05ad53e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_conv_kernel.hpp
@@ -0,0 +1,239 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
+#define CPU_JIT_AVX512_CORE_X8S8S32X_CONV_KERNEL_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template<typename Vmm>
+struct _jit_avx512_core_x8s8s32x_fwd_kernel : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(_jit_avx512_core_x8s8s32x_conv_fwd_ker_t)
+
+ enum { STATE_FIRST_DST_LOAD = 0x1U };
+
+ _jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
+ const primitive_attr_t &attr) : jcp(ajcp), attr_(attr),
+ eltwise_injector_(nullptr)
+ {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
+ this, jcp.eltwise);
+
+ generate();
+ jit_ker_ = (void (*)(jit_conv_call_s *))getCode();
+ }
+
+ ~_jit_avx512_core_x8s8s32x_fwd_kernel() {
+ delete eltwise_injector_;
+ }
+
+ jit_conv_conf_t jcp;
+ const primitive_attr_t &attr_;
+ void (*jit_ker_)(jit_conv_call_s *);
+
+private:
+ jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
+
+ enum {
+ typesize = sizeof(float),
+ ker_reg_base_idx = 28,
+ ker_dw_reg_base_idx = 30,
+ };
+ typedef enum {
+ no_last_block,
+ last_ic_block,
+ last_sp_block,
+ } ic_block_t;
+
+ /* data regs */
+ const Xbyak::Reg64 reg_ptr_scales = rax;
+ const Xbyak::Reg64 reg_inp = r8;
+ const Xbyak::Reg64 reg_ker = r9;
+ const Xbyak::Reg64 reg_out = r10;
+ const Xbyak::Reg64 aux_reg_inp = r11;
+ const Xbyak::Reg64 reg_ptr_sum_scale = r11;
+ const Xbyak::Reg64 aux_reg_ker = r12;
+ const Xbyak::Reg64 reg_compensation = r14;
+ /* counter regs */
+ const Xbyak::Reg64 reg_bias_alpha = abi_not_param1;
+ const Xbyak::Reg64 reg_oi = rbx;
+ const Xbyak::Reg64 reg_bias = rdx;
+ const Xbyak::Reg64 reg_oc_blocks = rsi;
+ const Xbyak::Reg64 reg_owb = aux_reg_ker;
+ const Xbyak::Reg64 reg_scratch = reg_compensation;
+ const Xbyak::Reg64 reg_kj = reg_ptr_scales;
+ const Xbyak::Reg64 reg_overflow = reg_ptr_scales;
+ const Xbyak::Reg64 reg_icb = reg_bias;
+
+ const Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
+ const Xbyak::Opmask kblend_mask = Xbyak::Opmask(3);
+
+ const Vmm vmm_wei = Vmm(31);
+ /* used during bias section of store_output */
+ const Vmm vmm_comp = Vmm(30); // only for signed input
+ const Vmm vmm_bias = Vmm(31);
+ /* used during post_op sum section of store_output */
+ const Vmm vmm_prev_dst = Vmm(31);
+ /* used during write-out section of store_output */
+ const Vmm vmm_zero = Vmm(31);
+
+ /* used in compute_ker (but set during prepare_output) */
+ const Vmm vmm_shift = vmm_comp; // only for signed input
+ /* used in compute_ker (but only for pre-VNNI machines) */
+ const Vmm vmm_tmp = Vmm(28); // not used for depthwise
+ const Vmm vmm_one = Vmm(29); // set at start of kernel, not used for depthwise.
+
+    /* registers used only for depthwise convolutions;
+       groups are always blocked by 16 (padded if needed),
+       hence only Zmm registers are used */
+ const Xbyak::Zmm zmm_wei = Xbyak::Zmm(31);
+ Xbyak::Zmm zmm_tmp;
+ Xbyak::Zmm zmm_src;
+ Xbyak::Zmm zmm_shifted_zero;
+ Xbyak::Zmm zmm_permute;
+
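+    /* helpers mapping an (unroll, block) pair to an accumulator register;
+       accumulators always stay below the fixed weight/temporary registers */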
+ Vmm vmm_out(int i_ur, int i_oc) {
+ int idx = i_ur + i_oc * jcp.ur_w;
+ assert(idx < (jcp.is_depthwise
+ ? ker_dw_reg_base_idx : ker_reg_base_idx));
+ return Vmm(idx);
+ }
+ Xbyak::Zmm zmm_out(int i_ur, int i_oc) {
+ int idx = i_ur + i_oc * jcp.ur_w;
+ assert(idx < (jcp.is_depthwise
+ ? ker_dw_reg_base_idx : ker_reg_base_idx));
+ return Xbyak::Zmm(idx);
+ }
+ Vmm vmm_inp(int i_ic, int nb_x_blocking) {
+ int idx = i_ic + nb_x_blocking * jcp.ur_w;
+ assert(idx < 31);
+ return Vmm(idx);
+ }
+ Xbyak::Zmm zmm_inp(int i_ic, int nb_x_blocking) {
+ int idx = i_ic + nb_x_blocking * jcp.ur_w;
+ assert(idx < 31);
+ return Xbyak::Zmm(idx);
+ }
+ Vmm vmm_bias_alpha() {
+ int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+ return Vmm(nb_c_block * jcp.ur_w);
+ }
+ Xbyak::Xmm xmm_bias_alpha() {
+ int nb_c_block = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+ return Xbyak::Xmm(nb_c_block * jcp.ur_w);
+ }
+ int get_ow_start(int ki, int pad_l) {
+ return nstl::max(0,
+ utils::div_up(pad_l - ki * (jcp.dilate_w + 1), jcp.stride_w));
+ }
+ int get_ow_end(int ur_w, int ki, int pad_r) {
+ return ur_w - nstl::max(0, utils::div_up(pad_r
+ - (jcp.kw - 1 - ki)
+ * (jcp.dilate_w + 1),
+ jcp.stride_w));
+ }
+
+ bool maybe_eltwise(int position);
+ void prepare_output(int ur_w);
+ void store_output(int ur_w, bool last_oc_block_flag);
+ void compute_ker_dw(
+ int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag, bool h_padded);
+ void compute_ker(int ur_w, int pad_l, int pad_r,
+ ic_block_t last_ic_block_flag, bool h_padded = false);
+ void compute_eltwise(int ur_w);
+ void kh_loop(int ur_w, int pad_l, int pad_r, ic_block_t last_ic_block_flag);
+ void icb_loop(
+ int ur_w, int pad_l, int pad_r, bool is_last_spatial_block);
+ void generate();
+ void cvt2ps(data_type_t type_in, Vmm ymm_in, const Xbyak::Operand &op,
+ bool mask_flag);
+ const Vmm vmm_mask(const Vmm vmm_in, bool mask_flag, bool store = false);
+};
+
+struct jit_avx512_core_x8s8s32x_fwd_kernel {
+
+ jit_avx512_core_x8s8s32x_fwd_kernel(jit_conv_conf_t ajcp,
+ const primitive_attr_t &attr) :
+ jit_ker(nullptr),
+ zmm_kernel_(nullptr),
+ ymm_kernel_(nullptr),
+ xmm_kernel_(nullptr) {
+ int ch_block = ajcp.is_depthwise ? ajcp.ch_block : ajcp.ic_block;
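+        // dispatch on the channel blocking: 16 lanes use the Zmm kernel,
+        // 8 the Ymm kernel and 4 the Xmm kernel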
+ switch (ch_block) {
+ case 16:
+ zmm_kernel_ =
+ new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm>(
+ ajcp, attr);
+ jit_ker = zmm_kernel_->jit_ker_;
+ return;
+ case 8:
+ ymm_kernel_ =
+ new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm>(
+ ajcp, attr);
+ jit_ker = ymm_kernel_->jit_ker_;
+ return;
+ case 4:
+ xmm_kernel_ =
+ new _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm>(
+ ajcp, attr);
+ jit_ker = xmm_kernel_->jit_ker_;
+ return;
+ default:
+ assert(!"invalid channel blocking");
+ }
+ }
+
+ ~jit_avx512_core_x8s8s32x_fwd_kernel() {
+ delete xmm_kernel_;
+ delete ymm_kernel_;
+ delete zmm_kernel_;
+ }
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ memory_desc_t &src_pd,
+ memory_desc_t &weights_pd,
+ memory_desc_t &dst_pd,
+ memory_desc_t &bias_pd,
+ const primitive_attr_t &attr,
+ int nthreads);
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
+
+ void (*jit_ker)(jit_conv_call_s *);
+ _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Zmm> *zmm_kernel_;
+ _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Ymm> *ymm_kernel_;
+ _jit_avx512_core_x8s8s32x_fwd_kernel<Xbyak::Xmm> *xmm_kernel_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp
new file mode 100644
index 0000000000..cdbf333d5e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.cpp
@@ -0,0 +1,423 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_avx512_core_x8s8s32x_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+using namespace nstl;
+
+using jit_conv_ker_t = void (*)(jit_conv_call_s *);
+
+#define wht_blk_off(d, g, ...) \
+ (pd()->with_groups() \
+ ? (d).blk_off((g), __VA_ARGS__) \
+ : (d).blk_off(__VA_ARGS__))
+
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
+ dst_type>::execute_forward_1d(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const size_t bia_dt_size = pd()->with_bias()
+ ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
+
+ const auto &jcp = pd()->jcp_;
+ assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
+ assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
+
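+    // On pre-VNNI hardware signed input is computed with weights scaled by
+    // wei_adj_scale, so fold the inverse factor into a scratchpad copy of
+    // the output scales.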
+ const float *oscales = pd()->attr()->output_scales_.scales_;
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ auto local_scales = scratchpad(ctx).template get<float>(
+ key_conv_adjusted_scales);
+ size_t count = pd()->attr()->output_scales_.count_;
+ float factor = 1.f / pd()->jcp_.wei_adj_scale;
+ if (count == 1) {
+ utils::array_set(local_scales, oscales[0] * factor, 16);
+ } else {
+ for (size_t c = 0; c < count; c++)
+ local_scales[c] = oscales[c] * factor;
+ }
+ oscales = local_scales;
+ }
+
+ size_t offset = weights_d.size() - weights_d.additional_buffer_size();
+ auto w = const_cast<wei_data_t *>(weights);
+ int32_t* compensation = (jcp.signed_input)
+ ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
+ int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
+ int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
+ int group_block = jcp.ch_block;
+ int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.nb_ow;
+
+ parallel(0, [&](const int ithr, const int nthr) {
+
+ int start{ 0 }, end{ 0 };
+ balance211(work_amount, nthr, ithr, start, end);
+
+ auto p = jit_conv_call_s();
+
+ int n{ 0 }, gg{ 0 }, occ{ 0 }, owb{ 0 };
+ switch (jcp.loop_order) {
+ case loop_cwgn:
+ nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, gg,
+ nb_groups, n, jcp.mb);
+ break;
+ case loop_gncw:
+ nd_iterator_init(start, gg, nb_groups, n, jcp.mb, occ, oc_chunks,
+ owb, jcp.nb_ow);
+ break;
+ case loop_ngcw:
+ nd_iterator_init(start, n, jcp.mb, gg, nb_groups, occ, oc_chunks,
+ owb, jcp.nb_ow);
+ break;
+ case loop_nwcg:
+ nd_iterator_init(start, n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks,
+ gg, nb_groups);
+ break;
+ default: assert(!"unsupported loop order");
+ }
+ while (start < end) {
+ int ocb = occ * jcp.nb_oc_blocking;
+ int gb = gg * jcp.nb_ch_blocking;
+ int g = gb * group_block;
+ int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
+ int g_ic = g * jcp.nb_ic * jcp.ic_block;
+ int ow_s = owb * jcp.ow_block;
+ int iw_s = ow_s * jcp.stride_w;
+
+ p.bias = bias ? bias + (bias_d.blk_off(g_oc) * bia_dt_size) : 0;
+ p.compensation = (jcp.signed_input) ? compensation + g_oc : 0;
+ p.dst = dst + dst_d.blk_off(n, g_oc, ow_s);
+ p.src = src + src_d.blk_off(n, g_ic, iw_s);
+ p.filt = weights + wht_blk_off(weights_d, gb, ocb, 0);
+ p.scales = &oscales[jcp.is_oc_scale * g_oc];
+ p.oc_blocks = jcp.is_depthwise ? gb : ocb;
+ p.kh_padding = jcp.kh;
+ p.t_overflow = 0;
+ p.b_overflow = 0;
+ p.owb = owb;
+
+ kernel_->jit_ker(&p);
+
+ ++start;
+ switch (jcp.loop_order) {
+ case loop_cwgn:
+ nd_iterator_step(occ, oc_chunks, owb, jcp.nb_ow, gg, nb_groups,
+ n, jcp.mb);
+ break;
+ case loop_gncw:
+ nd_iterator_step(gg, nb_groups, n, jcp.mb, occ, oc_chunks, owb,
+ jcp.nb_ow);
+ break;
+ case loop_ngcw:
+ nd_iterator_step(n, jcp.mb, gg, nb_groups, occ, oc_chunks, owb,
+ jcp.nb_ow);
+ break;
+ case loop_nwcg:
+ nd_iterator_step(n, jcp.mb, owb, jcp.nb_ow, occ, oc_chunks, gg,
+ nb_groups);
+ break;
+ default: assert(!"unsupported loop order");
+ }
+ }
+ });
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
+ dst_type>::execute_forward_2d(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const size_t bia_dt_size = pd()->with_bias()
+ ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
+
+ const auto &jcp = pd()->jcp_;
+ assert(jcp.ch_block == 1);
+ assert(jcp.nb_ch_blocking == 1);
+ assert(jcp.nb_oc % jcp.nb_oc_blocking == 0);
+ assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
+
+ const float *oscales = pd()->attr()->output_scales_.scales_;
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ auto local_scales = scratchpad(ctx).template get<float>(
+ key_conv_adjusted_scales);
+ size_t count = pd()->attr()->output_scales_.count_;
+ float factor = 1.f / pd()->jcp_.wei_adj_scale;
+ if (count == 1) {
+ utils::array_set(local_scales, oscales[0] * factor, 16);
+ } else {
+ for (size_t c = 0; c < count; c++)
+ local_scales[c] = oscales[c] * factor;
+ }
+ oscales = local_scales;
+ }
+
+ size_t offset = weights_d.size() - weights_d.additional_buffer_size();
+ auto w = const_cast<wei_data_t *>(weights);
+ int32_t* compensation = (jcp.signed_input)
+ ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
+ int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking_thr_chunk;
+ int nb_groups = jcp.nb_ch;
+ int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh * jcp.nb_ow;
+
+ parallel(0, [&](const int ithr, const int nthr) {
+
+ int start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ auto p = jit_conv_call_s();
+
+ size_t src_h_stride = src_d.blk_off(0, 0, 1);
+ size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
+ size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
+
+ int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 }, owb{ 0 };
+ switch (jcp.loop_order) {
+ case loop_cwgn:
+ nd_iterator_init(start, occ, oc_chunks, owb, jcp.nb_ow, g,
+ nb_groups, n, jcp.mb, oh_s, jcp.oh);
+ break;
+ case loop_ngcw:
+ nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
+ owb, jcp.nb_ow, oh_s, jcp.oh);
+ break;
+ case loop_nhwcg:
+ nd_iterator_init(start, n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow,
+ occ, oc_chunks, g, nb_groups);
+ break;
+ default: assert(!"unsupported loop order");
+ }
+ while (start < end) {
+ for (int occ1 = 0; occ1 < jcp.nb_oc_blocking_thr_chunk;
+ occ1 += jcp.nb_oc_blocking) {
+ int ocb = occ * jcp.nb_oc_blocking_thr_chunk + occ1;
+ int g_oc = (g * jcp.nb_oc + ocb) * jcp.oc_block;
+
+ int g_ic = g * jcp.nb_ic * jcp.ic_block;
+
+ int work_rem = end - start;
+ int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
+ int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
+ if (jcp.loop_order == loop_nhwcg)
+ oh_e = oh_s + 1; // step instead
+ int ow_s = owb * jcp.ow_block;
+ int iw_s = ow_s * jcp.stride_w;
+
+ auto bias_w = bias
+ ? bias + (bias_d.blk_off(g_oc) * bia_dt_size)
+ : 0;
+ int32_t *compensation_w = (jcp.signed_input)
+ ? compensation + g_oc : 0;
+
+ auto dst_w = dst + dst_d.blk_off(n, g_oc, oh_s, ow_s);
+ auto src_w = src + src_d.blk_off(n, g_ic, ih_s, iw_s);
+ auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
+
+ auto scales = &oscales[jcp.is_oc_scale * g_oc];
+
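+                // Per output row, count the kernel rows falling into top and
+                // bottom padding; unsigned input skips them by advancing the
+                // weights pointer, signed input lets the kernel handle them.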
+ for (int oj = oh_s, ij = ih_s; oj < oh_e;
+ ++oj, ij += jcp.stride_h) {
+ int dilate_h = jcp.dilate_h + 1;
+ int i_t_overflow = nstl::min(jcp.kh,
+ div_up(max(0, -ij), dilate_h));
+ int i_b_overflow = nstl::min(jcp.kh, div_up(
+ max(0, ij - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
+ dilate_h));
+ int kh_padding = nstl::max(0,
+ jcp.kh - i_t_overflow - i_b_overflow);
+
+ size_t wei_stride = (!jcp.signed_input)
+ ? i_t_overflow * wht_h_stride : 0;
+ p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
+ p.dst = dst_w;
+ p.filt = wht_w + wei_stride;
+ p.bias = bias_w;
+ p.compensation = compensation_w;
+ p.oc_blocks = ocb;
+ p.kh_padding = kh_padding;
+ p.scales = scales;
+ p.t_overflow = i_t_overflow;
+ p.b_overflow = i_b_overflow;
+ p.owb = owb;
+
+ kernel_->jit_ker(&p);
+ src_w += src_h_stride * jcp.stride_h;
+ dst_w += dst_h_stride;
+ }
+ }
+ switch (jcp.loop_order) {
+ case loop_cwgn:
+ nd_iterator_jump(start, end, occ, oc_chunks, owb, jcp.nb_ow, g,
+ nb_groups, n, jcp.mb, oh_s, jcp.oh);
+ break;
+ case loop_ngcw:
+ nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
+ oc_chunks, owb, jcp.nb_ow, oh_s, jcp.oh);
+ break;
+ case loop_nhwcg:
+ ++start;
+ nd_iterator_step(n, jcp.mb, oh_s, jcp.oh, owb, jcp.nb_ow, occ,
+ oc_chunks, g, nb_groups);
+ break;
+ default: assert(!"unsupported loop order");
+ }
+ }
+ });
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type,
+ dst_type>::execute_forward_2d_dw(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const size_t bia_dt_size = pd()->with_bias()
+ ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
+
+ const auto &jcp = pd()->jcp_;
+ assert(jcp.ic_block == 1);
+ assert(jcp.oc_block == 1);
+ assert(jcp.nb_ic == 1);
+ assert(jcp.nb_oc == 1);
+ assert(jcp.nb_oc_blocking == 1);
+ assert(jcp.nb_ch % jcp.nb_ch_blocking == 0);
+
+ const float *oscales = pd()->attr()->output_scales_.scales_;
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ auto local_scales = scratchpad(ctx).template get<float>(
+ key_conv_adjusted_scales);
+ size_t count = pd()->attr()->output_scales_.count_;
+ float factor = 1.f / pd()->jcp_.wei_adj_scale;
+ if (count == 1) {
+ utils::array_set(local_scales, oscales[0] * factor, 16);
+ } else {
+ for (size_t c = 0; c < count; c++)
+ local_scales[c] = oscales[c] * factor;
+ }
+ oscales = local_scales;
+ }
+
+ size_t offset = weights_d.size() - weights_d.additional_buffer_size();
+ auto w = const_cast<wei_data_t *>(weights);
+ int32_t* compensation = (jcp.signed_input)
+ ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
+ int nb_groups = jcp.nb_ch / jcp.nb_ch_blocking;
+ int group_block = jcp.ch_block;
+
+ parallel_nd(jcp.mb, jcp.oh, jcp.nb_ow, nb_groups,
+ [&](int n, int oh_s, int owb, int gg) {
+
+ auto p = jit_conv_call_s();
+
+ size_t src_h_stride = src_d.blk_off(0, 0, 1);
+ size_t wht_h_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
+
+ int gb = gg * jcp.nb_ch_blocking;
+ int g = gb * group_block;
+
+ int ih_s = -jcp.t_pad + oh_s * jcp.stride_h;
+ int ow_s = owb * jcp.ow_block;
+ int iw_s = ow_s * jcp.stride_w;
+
+ auto bias_w = bias ? bias + (bias_d.blk_off(g) * bia_dt_size) : 0;
+ int32_t *compensation_w = jcp.signed_input ? compensation + g : 0;
+
+ auto dst_w = dst + dst_d.blk_off(n, g, oh_s, ow_s);
+ auto src_w = src + src_d.blk_off(n, g, ih_s, iw_s);
+ auto wht_w = weights + wht_blk_off(weights_d, gb, 0);
+
+ auto scales = &oscales[jcp.is_oc_scale * g];
+
+ int dilate_h = jcp.dilate_h + 1;
+ int i_t_overflow = nstl::min(jcp.kh, div_up(max(0, -ih_s), dilate_h));
+ int i_b_overflow = nstl::min(jcp.kh,
+ div_up(max(0, ih_s - jcp.ih + (jcp.kh - 1) * dilate_h + 1),
+ dilate_h));
+ int kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow);
+
+ size_t wei_stride = jcp.signed_input ? 0 : i_t_overflow * wht_h_stride;
+ p.src = src_w + i_t_overflow * dilate_h * src_h_stride;
+ p.dst = dst_w;
+ p.filt = wht_w + wei_stride;
+ p.bias = bias_w;
+ p.compensation = compensation_w;
+ p.oc_blocks = gb;
+ p.kh_padding = kh_padding;
+ p.scales = scales;
+ p.t_overflow = i_t_overflow;
+ p.b_overflow = i_b_overflow;
+ p.owb = owb;
+
+ kernel_->jit_ker(&p);
+ });
+}
+
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
+ data_type::s8, data_type::u8>;
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
+ data_type::u8, data_type::u8>;
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
+ data_type::s8, data_type::s8>;
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
+ data_type::u8, data_type::s8>;
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
+ data_type::s8, data_type::s32>;
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
+ data_type::u8, data_type::s32>;
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
+ data_type::s8, data_type::f32>;
+template struct jit_avx512_core_x8s8s32x_convolution_fwd_t<
+ data_type::u8, data_type::f32>;
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp
new file mode 100644
index 0000000000..203ebdf942
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_convolution.hpp
@@ -0,0 +1,115 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP
+#define CPU_JIT_AVX512_CORE_X8S8S32X_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_avx512_core_x8s8s32x_conv_kernel.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t src_type, impl::data_type_t dst_type>
+struct jit_avx512_core_x8s8s32x_convolution_fwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_()
+ {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_int8:", avx512_core, ""),
+ jit_avx512_core_x8s8s32x_convolution_fwd_t<src_type, dst_type>);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(src_type, data_type::s8, data_type::undef,
+ dst_type, data_type::s32)
+ && IMPLICATION(with_bias(), utils::one_of(bias_md_.data_type,
+ data_type::f32, data_type::s32, data_type::s8,
+ data_type::u8))
+ && !has_zero_dim_memory();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx512_core_x8s8s32x_fwd_kernel::init_conf(
+ jcp_, *desc(), src_md_, weights_md_, dst_md_, bias_md_,
+ *attr(), mkldnn_get_max_threads());
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx512_core_x8s8s32x_fwd_kernel::init_scratchpad(scratchpad,
+ jcp_, *attr());
+
+ return status;
+ }
+
+ jit_conv_conf_t jcp_;
+ };
+
+ jit_avx512_core_x8s8s32x_convolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ {
+ kernel_ = new jit_avx512_core_x8s8s32x_fwd_kernel(pd()->jcp_,
+ *pd()->attr());
+ }
+
+ ~jit_avx512_core_x8s8s32x_convolution_fwd_t() { delete kernel_; }
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<data_type::s8>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override
+ {
+ const auto &_pd = pd();
+ if (_pd->ndims() == 3)
+ execute_forward_1d(ctx);
+ else if (_pd->jcp_.is_depthwise)
+ execute_forward_2d_dw(ctx);
+ else
+ execute_forward_2d(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward_1d(const exec_ctx_t &ctx) const;
+ void execute_forward_2d(const exec_ctx_t &ctx) const;
+ void execute_forward_2d_dw(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_avx512_core_x8s8s32x_fwd_kernel *kernel_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp
new file mode 100644
index 0000000000..142af1f541
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.cpp
@@ -0,0 +1,1034 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_avx512_core_x8s8s32x_deconvolution.hpp"
+
+#define GET_OFF(field) offsetof(jit_deconv_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+using namespace Xbyak;
+
+using namespace nstl;
+
+#define wht_blk_off(d, g, ...) \
+ (pd()->with_groups() ? (d).blk_off((g), __VA_ARGS__) : \
+ (d).blk_off(__VA_ARGS__))
+
+status_t jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_conf(
+ jit_conv_conf_t &jcp, const deconvolution_desc_t &cd,
+ memory_desc_t &src_md, memory_desc_t &weights_md,
+ memory_desc_t &dst_md, const bool with_bias,
+ memory_desc_t &bias_md, const primitive_attr_t &attr) {
+ const memory_desc_wrapper src_d(&src_md);
+ const memory_desc_wrapper dst_d(&dst_md);
+ const memory_desc_wrapper weights_d(&weights_md);
+ const memory_desc_wrapper bias_d(&bias_md);
+
+ if (!(mayiuse(avx512_core)
+ && one_of(src_d.data_type(), data_type::u8, data_type::s8)
+ && weights_d.data_type() == data_type::s8
+ && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
+ data_type::s8, data_type::u8)))
+ return status::unimplemented;
+
+ jcp = zero<decltype(jcp)>();
+
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ jcp.signed_input = src_d.data_type() == data_type::s8;
+ const int ndims = jcp.ndims = dst_d.ndims();
+ const bool is_1d = ndims == 3;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+ jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
+ jcp.ic_without_padding = src_d.dims()[1] / jcp.ngroups;
+ jcp.is_depthwise = true && with_groups
+ && utils::everyone_is(1, jcp.ic_without_padding,
+ jcp.oc_without_padding);
+
+    /* TODO: future work, on hold until a specialized depthwise kernel is
+     * implemented. */
+ if (jcp.is_depthwise && jcp.signed_input)
+ return status::unimplemented;
+
+ format_tag_t dat_tag = utils::pick(ndims - 3,
+ format_tag::nwc, format_tag::nhwc);
+
+ if (src_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(src_md, dat_tag));
+ jcp.src_tag = dat_tag;
+ } else {
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ }
+ if (jcp.src_tag != dat_tag)
+ return status::unimplemented;
+
+ if (dst_d.format_kind() == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(dst_md, dat_tag));
+ jcp.dst_tag = dat_tag;
+ } else {
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+ }
+ if (jcp.dst_tag != dat_tag)
+ return status::unimplemented;
+
+ auto set_or_check_wei_format = [&]() {
+ using namespace format_tag;
+
+ format_tag_t wei_tag = is_1d
+ ? (jcp.is_depthwise
+ ? Goiw16g : (with_groups ? gOIw4i16o4i : OIw4i16o4i))
+ : (jcp.is_depthwise
+ ? Goihw16g : (with_groups ? gOIhw4i16o4i : OIhw4i16o4i));
+
+ memory_desc_t want_wei_md = weights_md;
+ memory_desc_init_by_tag(want_wei_md, wei_tag);
+ if (jcp.signed_input && !jcp.is_depthwise) {
+ want_wei_md.extra.flags = 0
+ | memory_extra_flags::compensation_conv_s8s8
+ | memory_extra_flags::scale_adjust;
+ want_wei_md.extra.compensation_mask = (1 << 0)
+ + (with_groups && !jcp.is_depthwise ? (1 << 1) : 0);
+ want_wei_md.extra.scale_adjust =
+ mayiuse(avx512_core_vnni) ? 1.f : 0.5f;
+ }
+
+ if (weights_md.format_kind == format_kind::any) {
+ weights_md = want_wei_md;
+ return true;
+ }
+
+ return weights_md == want_wei_md;
+ };
+
+ if (!set_or_check_wei_format())
+ return status::unimplemented;
+
+ jcp.with_bias = with_bias;
+ if (jcp.with_bias) {
+ if (bias_d.format_kind() == format_kind::any)
+ CHECK(memory_desc_init_by_tag(bias_md, format_tag::x));
+ }
+
+ jcp.prop_kind = cd.prop_kind;
+ jcp.mb = src_d.dims()[0];
+ jcp.ih = is_1d ? 1 : src_d.dims()[ndims - 2];
+ jcp.iw = src_d.dims()[ndims - 1];
+ jcp.oh = is_1d ? 1 : dst_d.dims()[ndims - 2];
+ jcp.ow = dst_d.dims()[ndims - 1];
+ jcp.kh = is_1d ? 1 : weights_d.dims()[with_groups + ndims - 2];
+ jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+ jcp.t_pad = is_1d ? 0 : cd.padding[0][ndims - 4];
+ jcp.l_pad = cd.padding[0][ndims - 3];
+ jcp.stride_h = is_1d ? 1 : cd.strides[ndims - 4];
+ jcp.stride_w = cd.strides[ndims - 3];
+
+ if (jcp.is_depthwise) {
+ jcp.ch_block = 16;
+ jcp.oc_block = 1;
+ jcp.ic_block = 1;
+ } else {
+ jcp.ch_block = 1;
+ jcp.oc_block = 16;
+ jcp.ic_block = 16;
+
+ if (jcp.ngroups == 1) {
+ jcp.oc = utils::rnd_up(jcp.oc_without_padding, jcp.oc_block);
+ jcp.ic = utils::rnd_up(jcp.ic_without_padding, jcp.ic_block);
+ }
+ if (jcp.ic % jcp.ic_block != 0)
+ return status::unimplemented;
+ }
+
+ jcp.dilate_h = is_1d ? 0 : cd.dilates[ndims - 4];
+ jcp.dilate_w = cd.dilates[ndims - 3];
+
+ if (!IMPLICATION(jcp.dilate_h, jcp.stride_h == 1)
+ || !IMPLICATION(jcp.dilate_w, jcp.stride_w == 1))
+ return status::unimplemented;
+
+ /* padding: bottom and right */
+ jcp.b_pad = (jcp.ih - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+ - (jcp.oh + jcp.t_pad - 1);
+ jcp.r_pad = (jcp.iw - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1)
+ - (jcp.ow + jcp.l_pad - 1);
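+ /* Both follow from the deconvolution shape relation
+ * oh = (ih - 1) * stride_h - (t_pad + b_pad) + (kh - 1) * (dilate_h + 1) + 1
+ * (and the analogous relation for the width). */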
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise)
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+
+ jcp.ver = ver_avx512_core;
+ if (mayiuse(avx512_core_vnni))
+ jcp.ver = ver_vnni;
+ const auto &oscales = attr.output_scales_;
+ jcp.is_oc_scale = oscales.mask_ == 1 << 1;
+
+ assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
+
+ jcp.dst_dt = dst_d.data_type();
+ jcp.bia_dt = jcp.with_bias ? bias_d.data_type() : data_type::undef;
+ jcp.typesize_bia
+ = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
+ jcp.typesize_in = types::data_type_size(src_d.data_type());
+ jcp.typesize_out = types::data_type_size(dst_d.data_type());
+
+ jcp.nb_ch = div_up(jcp.ngroups, jcp.ch_block);
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+
+ /* kernel blocking params */
+ const int regs = jcp.ver == ver_vnni ? 30 : 28;
+ jcp.nb_oc_blocking = nstl::min(4, jcp.nb_oc);
+ for (; jcp.nb_oc_blocking > 1; jcp.nb_oc_blocking--)
+ if (jcp.nb_oc % jcp.nb_oc_blocking == 0
+ && jcp.l_pad <= regs / (jcp.nb_oc_blocking + 1))
+ break;
+
+ jcp.ur_w = regs / (jcp.nb_oc_blocking + 1);
+ int l_overflow = max(
+ 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
+
+ if (jcp.ow < jcp.ur_w) {
+ jcp.ur_w = jcp.ow;
+ jcp.ur_w_tail = 0;
+ } else {
+ for (; jcp.ur_w >= 1; jcp.ur_w--) {
+ /* ur_w should be a multiple of stride_w in order
+ to simplify logic for get_ow_start and get_ow_end */
+ bool is_multiple_of_stride = jcp.ur_w % jcp.stride_w == 0;
+
+ /* boundary conditions:
+ These conditions ensure all elements close to the boundary
+ are computed in a single call of the compute loop */
+ bool left_boundary_covered = jcp.ur_w >= l_overflow * jcp.stride_w;
+ jcp.ur_w_tail = jcp.ow % jcp.ur_w;
+ int r_overflow_no_tail
+ = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
+ - max(0, jcp.r_pad) - jcp.ur_w_tail)
+ / jcp.stride_w);
+ bool right_boundary_covered
+ = jcp.ur_w >= r_overflow_no_tail * jcp.stride_w;
+
+ if (is_multiple_of_stride && left_boundary_covered
+ && right_boundary_covered)
+ break;
+ else if (jcp.ur_w == 1)
+ /* The boundary conditions above also keep the calls to
+ icb_loop simple; if they are not satisfied, special
+ cases would have to be added to pick the correct
+ l_overflow/r_overflow values when different iterations
+ of the compute loop work on locations close to the
+ boundary. To keep the code simple, return unimplemented
+ for the extreme case when a good ur_w cannot be found.
+ */
+ return status::unimplemented;
+ }
+ }
+
+ jcp.wei_adj_scale =
+ (weights_d.extra().flags & memory_extra_flags::scale_adjust)
+ ? weights_d.extra().scale_adjust : 1.f;
+
+ jcp.loop_order = jcp.ngroups > 1 ? loop_ngc : loop_cgn;
+ return status::success;
+}
+
+bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::maybe_eltwise(int position) {
+ using namespace primitive_kind;
+ const auto &p = attr_.post_ops_;
+
+ if (position == 0) {
+ /* eltwise before sum */
+ return p.contain(eltwise, 0);
+ } else if (position == 1) {
+ /* eltwise after sum */
+ return p.contain(sum, 0) && p.contain(eltwise, 1);
+ }
+ return false;
+}
+
+void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_eltwise(int ur_w) {
+ int nb_oc_block
+ = jcp.is_depthwise ? jcp.nb_ch_blocking : jcp.nb_oc_blocking;
+ eltwise_injector_->compute_vector_range(0, nb_oc_block * ur_w);
+}
+
+bool jit_avx512_core_x8s8s32x_deconv_fwd_kernel::post_ops_ok(
+ jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ using namespace primitive_kind;
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+
+ switch (p.len_) {
+ case 0: return true;
+ case 1: return is_eltwise(0) || p.contain(sum, 0);
+ case 2:
+ return (p.contain(sum, 0) && is_eltwise(1))
+ || (p.contain(sum, 1) && is_eltwise(0));
+ default: return false;
+ }
+
+ return false;
+}
+
+void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr) {
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ dim_t count = nstl::max<dim_t>(attr.output_scales_.count_, 16);
+ scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
+ }
+}
+
+void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::compute_ker(int ur_w,
+ int l_overflow, int r_overflow, ker_block_t last_ic_block_flag,
+ bool h_padded) {
+
+ const int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
+ const int ur_w_stride = jcp.signed_input ? 1 : jcp.stride_w;
+
+ auto src_offset = [=](int oj, int icb, int ki) {
+ return jcp.typesize_in
+ * (((oj + jcp.l_pad - ki * (jcp.dilate_w + 1)) / jcp.stride_w)
+ * jcp.ngroups * jcp.ic_without_padding
+ + icb * 4);
+ };
+
+ auto kernel_offset = [=](int ocb, int icb, int ki) {
+ return jcp.typesize_in
+ * (ocb * jcp.nb_ic * jcp.kh * jcp.kw * ch_block_all
+ + icb * jcp.oc_block * jcp.ic_block / 4
+ + ki * ch_block_all);
+ };
+
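+
+ /* Accumulation step: VNNI fuses u8*s8 + s32 with vpdpbusd; the depthwise
+ * path multiplies zero/sign-extended 32-bit values; otherwise vpmaddubsw
+ * produces int16 pair sums that vpmaddwd against zmm_one (a vector of
+ * 16-bit ones) widens and reduces into int32. */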
+ auto compute = [=](zmm_t vreg_acc, zmm_t vreg_wei, zmm_t vreg_src) {
+ if (jcp.ver == ver_vnni) {
+ vpdpbusd(vreg_acc, vreg_src, vreg_wei);
+ } else if (jcp.is_depthwise) {
+ vpmulld(zmm_tmp, vreg_src, vreg_wei);
+ vpaddd(vreg_acc, vreg_acc, zmm_tmp);
+ } else {
+ vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
+ vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
+ vpaddd(vreg_acc, vreg_acc, zmm_tmp);
+ }
+ };
+
+ for (int ki = 0; ki < jcp.kw; ki++) {
+
+ int jj_start = get_ow_start(ki, l_overflow);
+ int jj_end = get_ow_end(ur_w, ki, r_overflow);
+
+ int _start = (jcp.signed_input) ? 0 : jj_start;
+ int _end = (jcp.signed_input) ? ur_w : jj_end;
+
+ int tail_size = jcp.ic_without_padding % 4;
+ int n_ic_blocks = jcp.is_depthwise ?
+ 1 :
+ (last_ic_block_flag & ~no_last_block ?
+ div_up(jcp.ic_without_padding % jcp.ic_block,
+ 4) :
+ jcp.ic_block / 4);
+
+ for (int icb1 = 0; icb1 < n_ic_blocks; icb1++) {
+ if (h_padded == true) {
+ /* fill padded area with shifted values */
+ Zmm inp = zmm_inp(0, jcp.nb_oc_blocking);
+ vpxord(inp, inp, inp);
+ vpsubb(inp, inp, zmm_shift);
+ } else {
+
+ for (int jj = _start; jj < _end; jj += ur_w_stride) {
+
+ int aux_src_off = src_offset(jj, icb1, ki);
+
+ if (jj >= jj_start && jj < jj_end
+ && ((jj + jcp.l_pad - ki) % jcp.stride_w == 0)) {
+ if (jcp.is_depthwise) {
+ vpmovzxbd(zmm_inp(jj, jcp.nb_oc_blocking),
+ EVEX_compress_addr(
+ aux_reg_src, aux_src_off));
+ } else if ((last_ic_block_flag & last_sp_block)
+ && tail_size != 0 && icb1 == n_ic_blocks - 1) {
+ xmm_t xmm_tmp = xmm_t(
+ zmm_inp(jj, jcp.nb_oc_blocking).getIdx());
+ for (int r = 0; r < tail_size; ++r)
+ vpinsrb(xmm_tmp, xmm_tmp,
+ ptr[aux_reg_src + aux_src_off + r], r);
+ vpbroadcastd(
+ zmm_inp(jj, jcp.nb_oc_blocking), xmm_tmp);
+ } else {
+ vpbroadcastd(zmm_inp(jj, jcp.nb_oc_blocking),
+ EVEX_compress_addr(
+ aux_reg_src, aux_src_off));
+ }
+ if (jcp.signed_input)
+ vpsubb(zmm_inp(jj, jcp.nb_oc_blocking),
+ zmm_inp(jj, jcp.nb_oc_blocking), zmm_shift);
+ } else {
+ /* fill padded area with shifted values */
+ if (jcp.signed_input) {
+ Zmm inp = zmm_inp(jj, jcp.nb_oc_blocking);
+ vpxord(inp, inp, inp);
+ vpsubb(inp, inp, zmm_shift);
+ }
+ }
+ }
+ }
+ for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
+ int aux_filt_off = kernel_offset(ocb, icb1, ki);
+
+ if (_end - _start > 0) {
+ if (jcp.is_depthwise)
+ vpmovsxbd(zmm_wei,
+ EVEX_compress_addr(aux_reg_filt, aux_filt_off));
+ else
+ vmovups(zmm_wei,
+ EVEX_compress_addr(aux_reg_filt, aux_filt_off));
+ }
+ for (int jj = _start; jj < _end; jj += ur_w_stride) {
+ Zmm inp = (h_padded == true) ?
+ zmm_inp(0, jcp.nb_oc_blocking) :
+ zmm_inp(jj, jcp.nb_oc_blocking);
+ compute(zmm_out(jj, ocb), zmm_wei, inp);
+ }
+ }
+ }
+ }
+}
+
+void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::kh_loop(int ur_w,
+ int l_overflow, int r_overflow, ker_block_t last_ic_block_flag) {
+
+ int ch_block_all = jcp.ch_block * jcp.ic_block * jcp.oc_block;
+ int shift_src_ih = jcp.typesize_in * (jcp.dilate_h + 1) * jcp.iw
+ * jcp.ngroups * jcp.ic_without_padding;
+ const int stride_h = jcp.signed_input ? 1 : jcp.stride_h;
+ int shift_filt_kh = jcp.typesize_in * jcp.kw * ch_block_all * stride_h;
+
+ Label kh_loop_label, skip_kh_loop;
+ Label t_overflow_label, no_t_overflow_label, b_overflow_label,
+ no_b_overflow_label;
+
+ mov(aux_reg_src, reg_src);
+ mov(aux_reg_filt, reg_filt);
+
+ if (jcp.signed_input && jcp.ndims > 3) {
+ /* Weights are transposed, so first compute 'bottom' padding. */
+ mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
+ cmp(reg_overflow, 0);
+ je(no_b_overflow_label, T_NEAR);
+ L(b_overflow_label); {
+ compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
+
+ add(aux_reg_filt, shift_filt_kh);
+ dec(reg_overflow);
+ cmp(reg_overflow, 0);
+ jg(b_overflow_label, T_NEAR);
+ }
+ L(no_b_overflow_label);
+ }
+
+ mov(reg_kh, ptr[param1 + GET_OFF(kh_padding)]);
+
+ if (jcp.signed_input || ((!jcp.signed_input)
+ && ((min(jcp.t_pad, jcp.b_pad) < 0)
+ || ((jcp.kh - 1) * (jcp.dilate_h + 1)
+ < nstl::max(jcp.t_pad, jcp.b_pad))))) {
+ cmp(reg_kh, 0);
+ je(skip_kh_loop, T_NEAR);
+ }
+
+ L(kh_loop_label); {
+ compute_ker(ur_w, l_overflow, r_overflow, last_ic_block_flag, false);
+ sub(aux_reg_src, shift_src_ih);
+ add(aux_reg_filt, shift_filt_kh);
+ dec(reg_kh);
+
+ /* Insert weight compensation in stride 'holes' */
+ if (jcp.signed_input && jcp.stride_h > 1) {
+ Label kh_comp_loop;
+
+ cmp(reg_kh, 0);
+ je(skip_kh_loop, T_NEAR);
+ mov(reg_comp_strides, jcp.stride_h - 1);
+ L(kh_comp_loop);
+ {
+ compute_ker(
+ ur_w, 0, 0, last_ic_block_flag, true);
+ add(aux_reg_filt, shift_filt_kh);
+ dec(reg_comp_strides);
+ cmp(reg_comp_strides, 0);
+ jg(kh_comp_loop, T_NEAR);
+ }
+ }
+ cmp(reg_kh, 0);
+ jg(kh_loop_label, T_NEAR);
+ }
+ L(skip_kh_loop);
+ if (jcp.signed_input && jcp.ndims > 3) {
+ mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
+ cmp(reg_overflow, 0);
+ je(no_t_overflow_label, T_NEAR);
+ L(t_overflow_label); {
+ compute_ker(ur_w, 0, 0, last_ic_block_flag, true);
+
+ add(aux_reg_filt, shift_filt_kh);
+ dec(reg_overflow);
+ cmp(reg_overflow, 0);
+ jg(t_overflow_label, T_NEAR);
+ }
+ L(no_t_overflow_label);
+ }
+}
+
+void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::prepare_output(int ur_w) {
+ for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
+ for (int ur = 0; ur < ur_w; ur++) {
+ zmm_t zmm = zmm_out(ur, ocb);
+ vpxord(zmm, zmm, zmm);
+ }
+ }
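+ /* zmm_shift = broadcast(-128): vpsubb by it moves s8 sources into the
+ * u8 range expected by vpdpbusd/vpmaddubsw; the precomputed weights
+ * compensation added back in store_output cancels this shift. */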
+ if (jcp.signed_input) {
+ xor_(reg_scratch, reg_scratch);
+ Reg8 _t8 = reg_scratch.cvt8();
+ mov(_t8, (int8_t)-128);
+ vpbroadcastb(zmm_shift, _t8);
+ }
+}
+
+void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::cvt2ps(
+ data_type_t type_in, zmm_t zmm_in, const Operand &op, bool mask_flag) {
+ zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
+ switch (type_in) {
+ case data_type::f32:
+ case data_type::s32: vmovups(zmm, op); break;
+ case data_type::s8: vpmovsxbd(zmm, op); break;
+ case data_type::u8: vpmovzxbd(zmm, op); break;
+ default: assert(!"unsupported data type");
+ }
+ if (type_in != data_type::f32)
+ vcvtdq2ps(zmm_in, zmm_in);
+}
+
+void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::store_output(
+ int ur_w, bool last_oc_block) {
+ mov(reg_bias, ptr[param1 + GET_OFF(bias)]);
+ mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
+
+ if (jcp.signed_input)
+ mov(reg_compensation, ptr[param1 + GET_OFF(compensation)]);
+
+ const auto &p = attr_.post_ops_;
+ const int sum_idx = p.find(primitive_kind::sum);
+ const float *p_sum_scale
+ = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
+ if (p_sum_scale && *p_sum_scale != 1.f)
+ mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
+
+ if (jcp.with_bias && jcp.signed_input && jcp.ver != ver_vnni) {
+ mov(reg_bias_alpha, float2int(jcp.wei_adj_scale));
+ vmovq(xmm_bias_alpha(), reg_bias_alpha);
+ vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
+ }
+
+ for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
+ const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
+ int scale_offset
+ = jcp.is_oc_scale * (sizeof(float) * ocb * jcp.oc_block);
+
+ auto zmm_bias = zmm_tmp;
+ if (jcp.with_bias) {
+ int bias_offset = jcp.typesize_bia * ocb * jcp.oc_block;
+ auto bias_addr = EVEX_compress_addr(reg_bias, bias_offset);
+ cvt2ps(jcp.bia_dt, zmm_bias, bias_addr, mask_flag);
+ if (jcp.signed_input && jcp.ver != ver_vnni)
+ vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
+ }
+ if (jcp.signed_input) {
+ int comp_offset = sizeof(int32_t) * ocb * jcp.oc_block;
+ auto comp_addr = EVEX_compress_addr(reg_compensation, comp_offset);
+ cvt2ps(data_type::s32, zmm_comp, comp_addr, mask_flag);
+ }
+
+ for (int ur = 0; ur < ur_w; ur++) {
+ zmm_t zmm = zmm_out(ur, ocb);
+ vcvtdq2ps(zmm, zmm);
+ if (jcp.signed_input)
+ vaddps(zmm, zmm, zmm_comp);
+ if (jcp.with_bias)
+ vaddps(zmm, zmm, zmm_bias);
+ zmm_t mask_zmm = mask_flag ? zmm | ktail_mask | T_z : zmm;
+ vmulps(mask_zmm, zmm,
+ EVEX_compress_addr(reg_ptr_scales, scale_offset));
+ }
+ }
+ if (maybe_eltwise(0))
+ compute_eltwise(ur_w);
+ if (p_sum_scale) { // post_op: sum
+ for (int k = 0; k < jcp.nb_oc_blocking; k++) {
+ const bool mask_flag
+ = last_oc_block == 1 && k == jcp.nb_oc_blocking - 1;
+ for (int j = 0; j < ur_w; j++) {
+ int aux_output_offset
+ = jcp.typesize_out
+ * (k * jcp.oc_block
+ + j * jcp.oc_without_padding * jcp.ngroups);
+ auto addr = EVEX_compress_addr(reg_dst, aux_output_offset);
+ Zmm zmm = zmm_out(j, k);
+ cvt2ps(jcp.dst_dt, zmm_prev_dst, addr, mask_flag);
+ if (*p_sum_scale == 1.f)
+ vaddps(zmm, zmm_prev_dst);
+ else
+ vfmadd231ps(zmm, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
+ }
+ }
+ }
+ if (maybe_eltwise(1))
+ compute_eltwise(ur_w);
+
+ for (int ocb = 0; ocb < jcp.nb_oc_blocking; ocb++) {
+ const bool mask_flag = last_oc_block && ocb == jcp.nb_oc_blocking - 1;
+ for (int ur = 0; ur < ur_w; ur++) {
+ zmm_t zmm = zmm_out(ur, ocb);
+ if (jcp.dst_dt == data_type::u8) {
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+ vmaxps(zmm, zmm_zero, zmm);
+ }
+ if (jcp.dst_dt != data_type::f32)
+ vcvtps2dq(zmm, zmm);
+ }
+ for (int ur = 0; ur < ur_w; ur++) {
+ int aux_dst_off = jcp.typesize_out
+ * (ur * jcp.ngroups * jcp.oc_without_padding
+ + ocb * jcp.oc_block);
+ auto addr = EVEX_compress_addr(reg_dst, aux_dst_off);
+
+ zmm_t zmm = zmm_out(ur, ocb);
+ zmm_t r_zmm = mask_flag ? zmm | ktail_mask : zmm;
+ switch (jcp.dst_dt) {
+ case data_type::f32:
+ case data_type::s32: vmovups(addr, r_zmm); break;
+ case data_type::s8: vpmovsdb(addr, r_zmm); break;
+ case data_type::u8: vpmovusdb(addr, r_zmm); break;
+ default: assert(!"unknown dst_dt");
+ }
+ }
+ }
+}
+
+void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::icb_loop(
+ int ur_w, int l_overflow, int r_overflow, bool is_last_sp_block) {
+
+ int shift_src_icb = jcp.typesize_in * jcp.ic_block;
+ int shift_filt_icb
+ = jcp.typesize_in * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block;
+
+ prepare_output(ur_w);
+
+ Label skip_icb_loop, icb_loop_label;
+
+ mov(reg_icb, jcp.nb_ic);
+ L(icb_loop_label); {
+
+ if (jcp.ic_without_padding != jcp.ic) {
+ Label common_ker, end_ker;
+ cmp(reg_icb, 1);
+ jg(common_ker, T_NEAR);
+
+ kh_loop(ur_w, l_overflow, r_overflow,
+ is_last_sp_block ? last_sp_block : last_ic_block);
+ jmp(end_ker, T_NEAR);
+
+ L(common_ker);
+ kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
+
+ L(end_ker);
+ } else {
+ kh_loop(ur_w, l_overflow, r_overflow, no_last_block);
+ }
+
+ add(reg_src, shift_src_icb);
+ add(reg_filt, shift_filt_icb);
+ dec(reg_icb);
+ cmp(reg_icb, 0);
+ jg(icb_loop_label, T_NEAR);
+ }
+
+ /* rewind the src/filt pointers */
+ sub(reg_src, jcp.nb_ic * shift_src_icb);
+ sub(reg_filt, jcp.nb_ic * shift_filt_icb);
+ L(skip_icb_loop);
+
+ if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
+ Label common_store, end_store;
+ mov(reg_oc_blocks, ptr[param1 + GET_OFF(oc_blocks)]);
+ if (jcp.is_depthwise)
+ cmp(reg_oc_blocks, jcp.nb_ch - 1);
+ else
+ cmp(reg_oc_blocks, jcp.nb_oc - jcp.nb_oc_blocking);
+ jne(common_store, T_NEAR);
+
+ store_output(ur_w, true);
+ jmp(end_store, T_NEAR);
+
+ L(common_store);
+ store_output(ur_w, false);
+
+ L(end_store);
+
+ } else {
+ store_output(ur_w, false);
+ }
+}
+
+void jit_avx512_core_x8s8s32x_deconv_fwd_kernel::generate() {
+ preamble();
+
+ xor_(reg_scratch, reg_scratch);
+ Reg16 _t = reg_scratch.cvt16();
+ mov(_t, 0x1);
+ vpbroadcastw(zmm_one, _t);
+
+ if (jcp.ngroups % jcp.ch_block != 0 || jcp.oc_without_padding != jcp.oc) {
+ int tail_size = jcp.is_depthwise ?
+ jcp.ngroups % jcp.ch_block :
+ jcp.oc_without_padding % jcp.oc_block;
+ int mask = (1 << tail_size) - 1;
+ Reg32 regw_tmp = reg_nur_w.cvt32();
+ mov(regw_tmp, mask);
+ kmovw(ktail_mask, regw_tmp);
+ }
+
+ mov(reg_src, ptr[param1 + GET_OFF(src)]);
+ mov(reg_filt, ptr[param1 + GET_OFF(filt)]);
+ mov(reg_dst, ptr[param1 + GET_OFF(dst)]);
+
+ int dst_shift = jcp.typesize_out * jcp.ur_w * jcp.ngroups
+ * jcp.oc_without_padding;
+ int src_shift = jcp.typesize_in * (jcp.ur_w / jcp.stride_w) * jcp.ngroups
+ * jcp.ic_without_padding;
+
+ int l_overflow = max(
+ 0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - jcp.l_pad) / jcp.stride_w);
+ int r_overflow
+ = max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1) - max(0, jcp.r_pad))
+ / jcp.stride_w);
+
+ int r_overflow1
+ = nstl::max(0, ((jcp.kw - 1) * (jcp.dilate_w + 1)
+ - nstl::max(0, jcp.r_pad) - jcp.ur_w_tail)
+ / jcp.stride_w);
+ int nur_w = jcp.ow / jcp.ur_w;
+ if (r_overflow1 > 0)
+ nur_w--;
+
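+ /* The output row is processed in ur_w-wide blocks: the first block
+ * absorbs the left overflow, the last full block absorbs the right
+ * overflow, and a ur_w_tail block (if any) finishes the row. */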
+ if (jcp.ur_w == jcp.ow) {
+ icb_loop(jcp.ur_w, l_overflow, r_overflow, true);
+ } else if (nur_w == 0) {
+ icb_loop(jcp.ur_w, l_overflow, r_overflow1, jcp.ur_w_tail == 0);
+ add(reg_src, src_shift);
+ add(reg_dst, dst_shift);
+ if (jcp.ur_w_tail != 0)
+ icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
+ } else {
+ xor_(reg_nur_w, reg_nur_w);
+ if (l_overflow > 0) {
+ icb_loop(jcp.ur_w, l_overflow, 0, false);
+ add(reg_src, src_shift);
+ add(reg_dst, dst_shift);
+ inc(reg_nur_w);
+ }
+ if ((l_overflow <= 0 && nur_w > 0) || (l_overflow > 0 && nur_w > 1)) {
+ Label ow_loop_label;
+ L(ow_loop_label);
+ {
+ icb_loop(jcp.ur_w, 0, 0, false);
+ add(reg_src, src_shift);
+ add(reg_dst, dst_shift);
+ inc(reg_nur_w);
+ cmp(reg_nur_w, nur_w);
+ jl(ow_loop_label, T_NEAR);
+ }
+ }
+ if (r_overflow1 > 0) {
+ icb_loop(jcp.ur_w, 0, r_overflow1, jcp.ur_w_tail == 0);
+ add(reg_src, src_shift);
+ add(reg_dst, dst_shift);
+ }
+ if (jcp.ur_w_tail != 0) {
+ icb_loop(jcp.ur_w_tail, 0, r_overflow, true);
+ }
+ }
+ postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
+ dst_type>::execute_forward_1d(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ auto &jcp = kernel_->jcp;
+
+ int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
+ int nb_groups = jcp.nb_ch;
+
+ const float *oscales = pd()->attr()->output_scales_.scales_;
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ auto local_scales
+ = scratchpad(ctx).template get<float>(key_conv_adjusted_scales);
+ size_t count = pd()->attr()->output_scales_.count_;
+ float factor = 1.f / pd()->jcp_.wei_adj_scale;
+ if (count == 1) {
+ utils::array_set(local_scales, oscales[0] * factor, 16);
+ } else {
+ for (size_t c = 0; c < count; c++)
+ local_scales[c] = oscales[c] * factor;
+ }
+ oscales = local_scales;
+ }
+ size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
+ auto w = const_cast<wei_data_t *>(weights);
+ int32_t *compensation
+ = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ int start{ 0 }, end{ 0 };
+ int work_amount = jcp.mb * nb_groups * oc_chunks;
+ balance211(work_amount, nthr, ithr, start, end);
+
+ auto p = jit_deconv_call_s();
+
+ int n{ 0 }, g{ 0 }, occ{ 0 };
+ if (jcp.loop_order == loop_ngc)
+ nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks);
+ else if (jcp.loop_order == loop_cgn)
+ nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb);
+ else
+ assert(!"unsupported loop order");
+ while (start < end) {
+
+ int ocb = occ * jcp.nb_oc_blocking;
+ int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
+ int g_ic = g * jcp.ch_block * jcp.ic;
+
+ p.dst = dst + dst_d.blk_off(n, g_oc);
+ p.src = src + src_d.blk_off(n, g_ic);
+ p.filt = weights + wht_blk_off(weights_d, g, ocb, 0);
+ p.bias = jcp.with_bias ?
+ bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
+ 0;
+ p.compensation = (jcp.signed_input) ? compensation + g_oc : 0;
+ p.scales = &oscales[jcp.is_oc_scale * g_oc];
+ p.t_overflow = 0;
+ p.b_overflow = 0;
+ p.kh_padding = jcp.kh;
+ p.oc_blocks = jcp.is_depthwise ? g : ocb;
+
+ kernel_->jit_ker(&p);
+
+ ++start;
+ if (jcp.loop_order == loop_ngc)
+ nd_iterator_step(n, jcp.mb, g, nb_groups, occ, oc_chunks);
+ else if (jcp.loop_order == loop_cgn)
+ nd_iterator_step(occ, oc_chunks, g, nb_groups, n, jcp.mb);
+ else
+ assert(!"unsupported loop order");
+ }
+ });
+}
+
+template <data_type_t src_type, data_type_t dst_type>
+void _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type,
+ dst_type>::execute_forward_2d(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ auto &jcp = kernel_->jcp;
+
+ int oc_chunks = jcp.nb_oc / jcp.nb_oc_blocking;
+ int nb_groups = jcp.nb_ch;
+
+ size_t src_h_stride = src_d.blk_off(0, 0, 1);
+ size_t dst_h_stride = dst_d.blk_off(0, 0, 1);
+ size_t wht_kh_stride = wht_blk_off(weights_d, 0, 0, 0, 1);
+
+ const float *oscales = pd()->attr()->output_scales_.scales_;
+ if (jcp.signed_input && jcp.ver != ver_vnni) {
+ auto local_scales
+ = scratchpad(ctx).template get<float>(key_conv_adjusted_scales);
+ size_t count = pd()->attr()->output_scales_.count_;
+ float factor = 1.f / pd()->jcp_.wei_adj_scale;
+ if (count == 1) {
+ utils::array_set(local_scales, oscales[0] * factor, 16);
+ } else {
+ for (size_t c = 0; c < count; c++)
+ local_scales[c] = oscales[c] * factor;
+ }
+ oscales = local_scales;
+ }
+ size_t offset = (size_t)jcp.ngroups * jcp.oc * jcp.ic * jcp.kh * jcp.kw;
+ auto w = const_cast<wei_data_t *>(weights);
+ int32_t *compensation
+ = (jcp.signed_input) ? reinterpret_cast<int32_t *>(&w[offset]) : 0;
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ int start{ 0 }, end{ 0 };
+ int work_amount = jcp.mb * nb_groups * oc_chunks * jcp.oh;
+ balance211(work_amount, nthr, ithr, start, end);
+
+ auto p = jit_deconv_call_s();
+
+ /* loop order: cgn or ngc */
+ int n{ 0 }, g{ 0 }, occ{ 0 }, oh_s{ 0 };
+ if (jcp.loop_order == loop_ngc)
+ nd_iterator_init(start, n, jcp.mb, g, nb_groups, occ, oc_chunks,
+ oh_s, jcp.oh);
+ else if (jcp.loop_order == loop_cgn)
+ nd_iterator_init(start, occ, oc_chunks, g, nb_groups, n, jcp.mb,
+ oh_s, jcp.oh);
+ else
+ assert(!"unsupported loop order");
+ while (start < end) {
+
+ int ocb = occ * jcp.nb_oc_blocking;
+ int g_oc = (g * jcp.ch_block * jcp.nb_oc + ocb) * jcp.oc_block;
+ int g_ic = g * jcp.ch_block * jcp.ic;
+ int work_rem = end - start;
+ int oh_e = oh_s + work_rem > jcp.oh ? jcp.oh : oh_s + work_rem;
+
+ auto dst_w = dst + dst_d.blk_off(n, g_oc);
+ auto src_w = src + src_d.blk_off(n, g_ic);
+ auto wht_w = weights + wht_blk_off(weights_d, g, ocb, 0);
+ auto bias_w = jcp.with_bias ?
+ bias + (bias_d.blk_off(g_oc) * jcp.typesize_bia) :
+ 0;
+ int32_t *compensation_w
+ = (jcp.signed_input) ? compensation + g_oc : 0;
+
+ auto scales = &oscales[jcp.is_oc_scale * g_oc];
+ for (int oj = oh_s; oj < oh_e; oj++) {
+ int ih_max = 0, kh_lo = 0, kh_len = 0;
+ if (jcp.dilate_h != 0 && jcp.stride_h == 1) {
+ /* dilation */
+ int dilate_h = jcp.dilate_h + 1;
+ // Note: use div_up to account for "holes" in filter
+ int o_t_overflow = div_up(
+ max(0, (jcp.kh - 1) * dilate_h - oj - jcp.t_pad),
+ dilate_h);
+ int o_b_overflow
+ = div_up(max(0, (jcp.kh - 1) * dilate_h + 1 - jcp.oh
+ + oj - jcp.b_pad),
+ dilate_h);
+ kh_len = jcp.kh - o_t_overflow - o_b_overflow;
+ kh_lo = o_b_overflow;
+ ih_max = oj + jcp.t_pad - o_b_overflow * dilate_h;
+ } else {
+ int o_t_overflow = max(
+ 0, (jcp.kh - (oj + 1 + jcp.t_pad)) / jcp.stride_h);
+ int o_b_overflow
+ = max(0, ((oj + jcp.kh) - (jcp.oh + jcp.b_pad))
+ / jcp.stride_h);
+ int overflow_kh_hi = jcp.kh - 1
+ - abs(jcp.oh + jcp.b_pad - (oj + 1)) % jcp.stride_h;
+ int overflow_kh_lo = (oj + jcp.t_pad) % jcp.stride_h;
+
+ kh_len = (overflow_kh_hi - overflow_kh_lo) / jcp.stride_h
+ + 1 - o_t_overflow - o_b_overflow;
+ kh_lo = overflow_kh_lo + o_b_overflow * jcp.stride_h;
+ ih_max = (oj + jcp.t_pad - kh_lo) / jcp.stride_h;
+ }
+
+ int wei_stride
+ = (!jcp.signed_input) ? kh_lo * wht_kh_stride : 0;
+ p.src = src_w + ih_max * src_h_stride;
+ p.dst = dst_w + oj * dst_h_stride;
+ p.filt = wht_w + wei_stride;
+ p.bias = bias_w;
+ p.compensation = compensation_w;
+ p.t_overflow = max(
+ 0, jcp.kh - (kh_lo + max(0, kh_len - 1) * jcp.stride_h
+ + 1));
+ p.b_overflow = kh_lo;
+ p.kh_padding = kh_len;
+ p.scales = scales;
+ p.oc_blocks = jcp.is_depthwise ? g : ocb;
+ kernel_->jit_ker(&p);
+ }
+ if (jcp.loop_order == loop_ngc)
+ nd_iterator_jump(start, end, n, jcp.mb, g, nb_groups, occ,
+ oc_chunks, oh_s, jcp.oh);
+ else if (jcp.loop_order == loop_cgn)
+ nd_iterator_jump(start, end, occ, oc_chunks, g, nb_groups, n,
+ jcp.mb, oh_s, jcp.oh);
+ else
+ assert(!"unsupported loop order");
+ }
+ });
+}
+
+template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
+ data_type::u8>;
+template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
+ data_type::s8>;
+template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
+ data_type::f32>;
+template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::u8,
+ data_type::s32>;
+template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
+ data_type::u8>;
+template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
+ data_type::s8>;
+template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
+ data_type::f32>;
+template struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<data_type::s8,
+ data_type::s32>;
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp
new file mode 100644
index 0000000000..901038fa48
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_avx512_core_x8s8s32x_deconvolution.hpp
@@ -0,0 +1,237 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP
+#define CPU_JIT_AVX512_CORE_U8S8S32X_DECONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "cpu_primitive.hpp"
+#include "cpu_memory.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "nstl.hpp"
+
+#include "cpu_deconvolution_pd.hpp"
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+typedef enum {
+ no_last_block = 0x1U,
+ last_ic_block = 0x2U,
+ last_sp_block = 0x4U,
+} ker_block_t;
+
+struct jit_avx512_core_x8s8s32x_deconv_fwd_kernel : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_x8s8s32x_deconv_fwd_ker_t);
+
+ jit_avx512_core_x8s8s32x_deconv_fwd_kernel(
+ const jit_conv_conf_t &ajcp, const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr) {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<avx512_common>(
+ this, jcp.eltwise);
+ generate();
+ jit_ker = (void (*)(jit_deconv_call_s *))getCode();
+ }
+
+ ~jit_avx512_core_x8s8s32x_deconv_fwd_kernel() {
+ delete eltwise_injector_;
+ }
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const deconvolution_desc_t &cd,
+ memory_desc_t &src_md,
+ memory_desc_t &weights_md,
+ memory_desc_t &dst_md,
+ const bool with_bias,
+ memory_desc_t &bias_md,
+ const primitive_attr_t &attr);
+
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp, const primitive_attr_t &attr);
+
+ const jit_conv_conf_t &jcp;
+ const primitive_attr_t &attr_;
+ void (*jit_ker)(jit_deconv_call_s *);
+private:
+ jit_uni_eltwise_injector_f32<avx512_common> *eltwise_injector_;
+ using reg64_t = const Xbyak::Reg64;
+ using zmm_t = const Xbyak::Zmm;
+ using xmm_t = const Xbyak::Xmm;
+
+ reg64_t reg_src = r8;
+ reg64_t reg_filt = r9;
+ reg64_t reg_dst = r10;
+ reg64_t param1 = abi_param1;
+ reg64_t reg_kh = abi_not_param1;
+ reg64_t reg_nur_w = rbx;
+ reg64_t reg_bias = rdx;
+ reg64_t reg_icb = reg_bias;
+ reg64_t reg_ptr_scales = rax;
+ reg64_t reg_oc_blocks = rsi;
+
+ reg64_t aux_reg_src = r11;
+ reg64_t aux_reg_filt = r12;
+
+ reg64_t reg_compensation = r14;
+ reg64_t reg_scratch = r14;
+ reg64_t reg_ptr_sum_scale = r11;
+ reg64_t reg_bias_alpha = abi_not_param1;
+ reg64_t reg_overflow = rax;
+ reg64_t reg_comp_strides = reg_overflow;
+
+ Xbyak::Opmask ktail_mask = Xbyak::Opmask(2);
+ zmm_t zmm_tmp = zmm_t(28);
+ zmm_t zmm_one = zmm_t(29);
+ /* used during write-out section of store_output */
+ zmm_t zmm_zero = zmm_t(31);
+ zmm_t zmm_wei = zmm_t(31);
+
+ /* signed input */
+ zmm_t zmm_shift = zmm_t(30);
+ zmm_t zmm_comp = zmm_t(30);
+ zmm_t zmm_bias = zmm_t(31);
+ zmm_t zmm_prev_dst = zmm_t(31);
+
+ zmm_t zmm_out(int i_ur, int i_oc) {
+ int idx = i_ur * jcp.nb_oc_blocking + i_oc;
+ assert(idx < 31);
+ return zmm_t(idx);
+ }
+ zmm_t zmm_inp(int i_ic, int nb_x_blocking) {
+ int idx = i_ic + nb_x_blocking * jcp.ur_w;
+ assert(idx < 31);
+ return zmm_t(idx);
+ }
+ zmm_t zmm_bias_alpha() {
+ return zmm_t(jcp.nb_oc_blocking * jcp.ur_w);
+ }
+ xmm_t xmm_bias_alpha() {
+ return xmm_t(jcp.nb_oc_blocking * jcp.ur_w);
+ }
+
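+ /* A kernel tap ki only contributes to output columns that are reachable
+ * given l_pad and stride_w; these helpers return the first and
+ * one-past-last such column within the current ur_w block, clipped by
+ * the caller-supplied l_overflow / r_overflow. */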
+ int get_ow_start(int ki, int l_overflow) {
+ int res = (jcp.ow - 1 + jcp.r_pad) % jcp.stride_w
+ + l_overflow * jcp.stride_w
+ - (jcp.kw - 1 - ki) * (jcp.dilate_w + 1);
+ while (res < 0)
+ res += jcp.stride_w;
+ return res;
+ }
+
+ int get_ow_end(int ur_w, int ki, int r_overflow) {
+ if (utils::one_of(ur_w, jcp.ow, jcp.ur_w_tail))
+ ur_w += nstl::min(0, jcp.r_pad); // remove negative padding
+ int res = (ur_w - 1 + jcp.l_pad) % jcp.stride_w
+ + r_overflow * jcp.stride_w - ki * (jcp.dilate_w + 1);
+ while (res < 0)
+ res += jcp.stride_w;
+ return ur_w - res;
+ }
+ bool maybe_eltwise(int position);
+ void compute_eltwise(int ur_w);
+ void prepare_output(int ur_w);
+ void store_output(int ur_w, bool last_oc_block);
+ void compute_ker(int ur_w, int l_overflow, int r_overflow,
+ ker_block_t last_ic_block_flag, bool h_padded = false);
+ void kh_loop(int ur_w, int pad_l, int pad_r, ker_block_t last_ker_block);
+ void icb_loop(int ur_w, int pad_l, int pad_r, bool last_block);
+ void generate();
+ void cvt2ps(data_type_t type_in, zmm_t zmm_in, const Xbyak::Operand &op,
+ bool mask_flag);
+};
+
+template <impl::data_type_t src_type, impl::data_type_t dst_type>
+struct _jit_avx512_core_x8s8s32x_deconvolution_fwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_deconvolution_fwd_pd_t {
+ using cpu_deconvolution_fwd_pd_t::cpu_deconvolution_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_deconvolution:", avx512_core, ""),
+ _jit_avx512_core_x8s8s32x_deconvolution_fwd_t<src_type, dst_type>);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && (desc()->alg_kind & alg_kind::deconvolution_direct)
+ && desc()->src_desc.data_type == src_type
+ && desc()->dst_desc.data_type == dst_type
+ && IMPLICATION(with_bias(), utils::one_of(
+ desc()->bias_desc.data_type, data_type::f32,
+ data_type::s32, data_type::s8, data_type::u8))
+ && desc()->accum_data_type == data_type::s32;
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_avx512_core_x8s8s32x_deconv_fwd_kernel::
+ init_conf(jcp_, *desc(), src_md_, weights_md_, dst_md_,
+ with_bias(), bias_md_, *attr());
+
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_avx512_core_x8s8s32x_deconv_fwd_kernel::init_scratchpad(scratchpad,
+ jcp_, *attr());
+
+ return status::success;
+ }
+
+ jit_conv_conf_t jcp_;
+ };
+
+ _jit_avx512_core_x8s8s32x_deconvolution_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ {
+ kernel_ = new jit_avx512_core_x8s8s32x_deconv_fwd_kernel(pd()->jcp_,
+ *pd()->attr());
+ }
+
+ ~_jit_avx512_core_x8s8s32x_deconvolution_fwd_t() { delete kernel_; }
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<data_type::s8>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ if (pd()->ndims() == 3)
+ execute_forward_1d(ctx);
+ else
+ execute_forward_2d(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward_1d(const exec_ctx_t &ctx) const;
+ void execute_forward_2d(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ jit_avx512_core_x8s8s32x_deconv_fwd_kernel *kernel_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp
new file mode 100644
index 0000000000..c09592d5c9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_generator.hpp
@@ -0,0 +1,773 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_AVX2_GENERATOR_HPP
+#define CPU_JIT_AVX2_GENERATOR_HPP
+
+#include <limits.h>
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "cpu_isa_traits.hpp"
+#include "jit_utils/jit_utils.hpp"
+
+#if defined(_WIN32) && !defined(__GNUC__)
+# define STRUCT_ALIGN(al, ...) __declspec(align(al)) __VA_ARGS__
+#else
+# define STRUCT_ALIGN(al, ...) __VA_ARGS__ __attribute__((__aligned__(al)))
+#endif
+
+#if defined(_WIN32)
+# define OFFSET_SHADOWSPACE 0x28
+#endif
+
+#define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \
+ const char *name() const override { return STRINGIFY(jit_name); } \
+ const char *source_file() const override { return __FILE__; }
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+// TODO: move this to jit_generator class?
+namespace {
+
+typedef enum {
+ PAGE_4K = 4096,
+ PAGE_2M = 2097152,
+} cpu_page_size_t;
+
+// TODO: move this somewhere else? Although this is only used by jit kernels
+// (Roma)
+static inline int float2int(float x) {
+ union {
+ float vfloat;
+ int vint;
+ } cvt;
+ cvt.vfloat = x;
+ return cvt.vint;
+}
+
+// TODO: A GPR class that hides ABI details from the JIT kernels and allows
+// numbering registers from 0 to 14 (x86_64) / 6 (x32) (gpr0, gpr1, ...) and
+// stack register (sr).
+//
+// This will allow using syntax like this:
+//
+// param = gpr0;
+// reg_input = gpr0;
+// reg_output = gpr1;
+// ...
+//
+// #ifndef XBYAK64
+// mov(param, ptr[sr])
+// #endif
+//
+// (Roma)
+
+#ifdef XBYAK64
+constexpr Xbyak::Operand::Code abi_save_gpr_regs[] = {
+ Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
+ Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15,
+#ifdef _WIN32
+ Xbyak::Operand::RDI, Xbyak::Operand::RSI,
+#endif
+};
+
+#ifdef _WIN32
+static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RCX),
+ abi_param2(Xbyak::Operand::RDX),
+ abi_param3(Xbyak::Operand::R8),
+ abi_param4(Xbyak::Operand::R9),
+ abi_not_param1(Xbyak::Operand::RDI);
+#else
+static const Xbyak::Reg64 abi_param1(Xbyak::Operand::RDI),
+ abi_param2(Xbyak::Operand::RSI),
+ abi_param3(Xbyak::Operand::RDX),
+ abi_param4(Xbyak::Operand::RCX),
+ abi_param5(Xbyak::Operand::R8),
+ abi_param6(Xbyak::Operand::R9),
+ abi_not_param1(Xbyak::Operand::RCX);
+#endif
+#endif
+
+inline unsigned int get_cache_size(int level, bool per_core = true){
+ unsigned int l = level - 1;
+ // Currently, if XByak is not able to fetch the cache topology
+ // we default to 32KB of L1, 512KB of L2 and 1MB of L3 per core.
+ if (cpu.getDataCacheLevels() == 0){
+ const int L1_cache_per_core = 32000;
+ const int L2_cache_per_core = 512000;
+ const int L3_cache_per_core = 1024000;
+ int num_cores = per_core ? 1 : mkldnn_get_max_threads();
+ switch(l){
+ case(0): return L1_cache_per_core * num_cores;
+ case(1): return L2_cache_per_core * num_cores;
+ case(2): return L3_cache_per_core * num_cores;
+ default: return 0;
+ }
+ }
+ if (l < cpu.getDataCacheLevels()) {
+ return cpu.getDataCacheSize(l)
+ / (per_core ? cpu.getCoresSharingDataCache(l) : 1);
+ } else
+ return 0;
+}
+
+}
+
+class jit_generator : public Xbyak::CodeGenerator
+{
+private:
+ const size_t xmm_len = 16;
+#ifdef _WIN32
+ const size_t xmm_to_preserve_start = 6;
+ const size_t xmm_to_preserve = 10;
+#else
+ const size_t xmm_to_preserve_start = 0;
+ const size_t xmm_to_preserve = 0;
+#endif
+
+ const size_t num_abi_save_gpr_regs
+ = sizeof(abi_save_gpr_regs) / sizeof(abi_save_gpr_regs[0]);
+
+ const size_t size_of_abi_save_regs
+ = num_abi_save_gpr_regs * rax.getBit() / 8
+ + xmm_to_preserve * xmm_len;
+
+public:
+ enum {
+ _cmp_eq_oq = 0u,
+ _cmp_lt_os = 1u,
+ _cmp_le_os = 2u,
+ _cmp_neq_uq = 4u,
+ _cmp_nlt_us = 5u,
+ _cmp_nle_us = 6u,
+
+ _op_floor = 1u,
+ _op_mxcsr = 4u,
+ };
+
+ Xbyak::Reg64 param1 = abi_param1;
+ const int EVEX_max_8b_offt = 0x200;
+ const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
+
+ inline size_t get_size_of_abi_save_regs() {
+ return size_of_abi_save_regs;
+ }
+
+ void preamble() {
+ if (xmm_to_preserve) {
+ sub(rsp, xmm_to_preserve * xmm_len);
+ for (size_t i = 0; i < xmm_to_preserve; ++i)
+ movdqu(ptr[rsp + i * xmm_len], Xbyak::Xmm(xmm_to_preserve_start + i));
+ }
+ for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
+ push(Xbyak::Reg64(abi_save_gpr_regs[i]));
+ if (mayiuse(avx512_common)) {
+ mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
+ }
+ }
+
+ void mic_prefetcht0(Xbyak::Address a) {
+ if (mayiuse(avx512_mic))
+ prefetcht0(a);
+ }
+
+ void mic_prefetcht1(Xbyak::Address a) {
+ if (mayiuse(avx512_mic))
+ prefetcht1(a);
+ }
+
+ void mic_prefetcht2(Xbyak::Address a) {
+ if (mayiuse(avx512_mic))
+ prefetcht2(a);
+ }
+
+ void uni_vzeroupper() {
+ if (mayiuse(avx) && !mayiuse(avx512_mic))
+ vzeroupper();
+ }
+
+ void postamble() {
+ for (size_t i = 0; i < num_abi_save_gpr_regs; ++i)
+ pop(Xbyak::Reg64(abi_save_gpr_regs[num_abi_save_gpr_regs - 1 - i]));
+ if (xmm_to_preserve) {
+ for (size_t i = 0; i < xmm_to_preserve; ++i)
+ movdqu(Xbyak::Xmm(xmm_to_preserve_start + i), ptr[rsp + i * xmm_len]);
+ add(rsp, xmm_to_preserve * xmm_len);
+ }
+ uni_vzeroupper();
+ ret();
+ }
+
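+ /* Re-expresses a large constant offset as base + reg_EVEX_max_8b_offt *
+ * scale + small_offt, keeping the remaining displacement inside the
+ * range the EVEX compressed disp8 encoding can reach
+ * (reg_EVEX_max_8b_offt is preloaded with 2 * EVEX_max_8b_offt in
+ * preamble()). */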
+ template<typename T>
+ Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base,
+ T raw_offt, bool bcast = false)
+ {
+ using Xbyak::Zmm;
+ using Xbyak::Reg64;
+ using Xbyak::Address;
+ using Xbyak::RegExp;
+
+ assert(raw_offt <= INT_MAX);
+ auto offt = static_cast<int>(raw_offt);
+
+ int scale = 0;
+
+ if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
+ offt = offt - 2 * EVEX_max_8b_offt;
+ scale = 1;
+ } else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
+ offt = offt - 4 * EVEX_max_8b_offt;
+ scale = 2;
+ }
+
+ auto re = RegExp() + base + offt;
+ if (scale)
+ re = re + reg_EVEX_max_8b_offt * scale;
+
+ if (bcast)
+ return zword_b [re];
+ else
+ return zword [re];
+ }
+
+ Xbyak::Address make_safe_addr(const Xbyak::Reg64 &reg_out, size_t offt,
+ const Xbyak::Reg64 &tmp_reg, bool bcast = false) {
+ if (offt > INT_MAX) {
+ mov(tmp_reg, offt);
+ return bcast ? ptr_b[reg_out + tmp_reg] : ptr[reg_out + tmp_reg];
+ } else {
+ return bcast ? ptr_b[reg_out + offt] : ptr[reg_out + offt];
+ }
+ }
+
+ Xbyak::Address EVEX_compress_addr_safe(const Xbyak::Reg64 &base,
+ size_t raw_offt, const Xbyak::Reg64 &reg_offt, bool bcast = false) {
+ if (raw_offt > INT_MAX) {
+ return make_safe_addr(base, raw_offt, reg_offt, bcast);
+ } else {
+ return EVEX_compress_addr(base, raw_offt, bcast);
+ }
+ }
+
+ void safe_add(const Xbyak::Reg64 &base, size_t raw_offt,
+ const Xbyak::Reg64 &reg_offt) {
+ if (raw_offt > INT_MAX) {
+ mov(reg_offt, raw_offt);
+ add(base, reg_offt);
+ } else {
+ add(base, raw_offt);
+ }
+ }
+
+ void safe_sub(const Xbyak::Reg64 &base, size_t raw_offt,
+ const Xbyak::Reg64 &reg_offt) {
+ if (raw_offt > INT_MAX) {
+ mov(reg_offt, raw_offt);
+ sub(base, reg_offt);
+ } else {
+ sub(base, raw_offt);
+ }
+ }
+
+ // Disallow char-based labels completely
+ void L(const char *label) = delete;
+ void L(Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
+
+ void uni_vpxor(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op) {
+ assert(x1.getIdx() == x2.getIdx());
+ pxor(x2, op);
+ }
+ void uni_vpxor(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+ const Xbyak::Operand &op) {
+ if (mayiuse(avx2)) {
+ vpxor(x1, x2, op);
+ } else {
+ vxorps(x1, x2, op);
+ }
+ }
+ void uni_vpxor(const Xbyak::Zmm &x1, const Xbyak::Zmm &x2,
+ const Xbyak::Operand &op) {
+ vpxord(x1, x2, op);
+ }
+
+ void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
+ movss(addr, x);
+ }
+ void uni_vmovss(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
+ vmovss(addr, x);
+ }
+ void uni_vmovss(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
+ movss(x, addr);
+ }
+ void uni_vmovss(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
+ vmovss(x, addr);
+ }
+
+ void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Xmm &x) {
+ movsd(addr, x);
+ }
+ void uni_vmovsd(const Xbyak::Address& addr, const Xbyak::Ymm &x) {
+ vmovsd(addr, x);
+ }
+ void uni_vmovsd(const Xbyak::Xmm &x, const Xbyak::Address& addr) {
+ movsd(x, addr);
+ }
+ void uni_vmovsd(const Xbyak::Ymm &x, const Xbyak::Address& addr) {
+ vmovsd(x, addr);
+ }
+
+ void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
+ movdqu(addr, x);
+ }
+ void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
+ vmovdqu(addr, x);
+ }
+ void uni_vmovdqu(const Xbyak::Address &addr, const Xbyak::Zmm &x) {
+ vmovdqu32(addr, x);
+ }
+
+ void uni_vmovdqu(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
+ movdqu(x, addr);
+ }
+ void uni_vmovdqu(const Xbyak::Ymm &x, const Xbyak::Address &addr) {
+ vmovdqu(x, addr);
+ }
+ void uni_vmovdqu(const Xbyak::Zmm &x, const Xbyak::Address &addr) {
+ vmovdqu32(x, addr);
+ }
+
+ void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
+ movups(addr, x);
+ }
+ void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
+ vmovups(addr, x);
+ }
+
+ void uni_vmovups(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+ movups(x, op);
+ }
+ void uni_vmovups(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
+ vmovups(x, op);
+ }
+
+ void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
+ movntps(addr, x);
+ }
+ void uni_vmovntps(const Xbyak::Address &addr, const Xbyak::Ymm &x) {
+ vmovntps(addr, x);
+ }
+
+ void uni_vbroadcastss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+ movss(x, op);
+ shufps(x, x, 0x0);
+ }
+ void uni_vbroadcastss(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
+ if (op.isMEM() || mayiuse(avx2)) {
+ vbroadcastss(x, op);
+ } else {
+ Xbyak::Xmm t(x.getIdx());
+ if (t.getIdx() != op.getIdx()) movss(t, op);
+ vinsertf128(x, x, t, 1);
+ vshufps(x, x, x, 0);
+ }
+ }
+
+ void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+ movsd(x, op);
+ pshufd(x, x, 0x0);
+ }
+ void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
+ if (mayiuse(avx2)) {
+ vpbroadcastd(x, op);
+ } else {
+ Xbyak::Xmm t(x.getIdx());
+ if (t.getIdx() != op.getIdx()) movsd(t, op);
+ vinsertf128(x, x, t, 1);
+ vshufps(x, x, x, 0);
+ }
+ }
+
+ void uni_vrcpss(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+ rcpss(x, op);
+ }
+ void uni_vrcpss(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2) {
+ Xbyak::Xmm x1_(x1.getIdx());
+ Xbyak::Xmm x2_(x2.getIdx());
+ vrcpss(x1_, x1_, x2_);
+ }
+ void uni_vrcpss(const Xbyak::Ymm &x, const Xbyak::Address &op) {
+ Xbyak::Xmm x_(x.getIdx());
+ vrcpss(x_, x_, op);
+ }
+
+ void uni_vrcpps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+ rcpps(x, op);
+ }
+ void uni_vrcpps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
+ vrcpps(x, op);
+ }
+ void uni_vrcpps(const Xbyak::Zmm &x, const Xbyak::Operand &op) {
+ vrcp14ps(x, op);
+ }
+
+ void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ assert(x.getIdx() == op1.getIdx());
+ divps(x, op2);
+ }
+ void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ vdivps(x, op1, op2);
+ }
+
+ void uni_vdivps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
+ movups(buf, op1);
+ divps(buf, op2);
+ if (x.getIdx() != buf.getIdx()) {
+ movups(x, buf);
+ }
+ }
+
+ void uni_vdivps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
+ vdivps(x, op1, op2);
+ }
+
+ void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ assert(x.getIdx() == op1.getIdx());
+ addps(x, op2);
+ }
+ void uni_vaddps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ vaddps(x, op1, op2);
+ }
+
+ void uni_vpsignd(const Xbyak::Xmm& x1, const Xbyak::Xmm& x2,
+ const Xbyak::Operand& op) {
+ assert(x1.getIdx() == x2.getIdx());
+ psignd(x1, op);
+ }
+ void uni_vpsignd(const Xbyak::Ymm& x1, const Xbyak::Ymm& x2,
+ const Xbyak::Operand& op) {
+ vpsignd(x1, x2, op);
+ }
+
+ void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ assert(x.getIdx() == op1.getIdx());
+ subps(x, op2);
+ }
+ void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ vsubps(x, op1, op2);
+ }
+
+ void uni_vsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2, const Xbyak::Xmm &buf) {
+ movups(buf, op1);
+ subps(buf, op2);
+ if (x.getIdx() != buf.getIdx()) {
+ movups(x, buf);
+ }
+ }
+
+ void uni_vsubps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2, const Xbyak::Ymm &buf) {
+ vsubps(x, op1, op2);
+ }
+
+ void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ assert(x.getIdx() == op1.getIdx());
+ mulps(x, op2);
+ }
+ void uni_vmulps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ vmulps(x, op1, op2);
+ }
+
+ void uni_vfmadd213ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op) {
+ mulps(x1, x2);
+ addps(x1, op);
+ }
+ void uni_vfmadd213ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+ const Xbyak::Operand &op) {
+ vfmadd213ps(x1, x2, op);
+ }
+
+ void uni_vfmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op) {
+ mulps(x2, op);
+ addps(x1, x2);
+ }
+ void uni_vfmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+ const Xbyak::Operand &op) {
+ vfmadd231ps(x1, x2, op);
+ }
+
+ void uni_vfnmadd231ps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op) {
+ mulps(x2, op);
+ subps(x1, x2);
+ }
+
+ void uni_vfnmadd231ps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+ const Xbyak::Operand &op) {
+ vfnmadd231ps(x1, x2, op);
+ }
+
+ void uni_vsqrtps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+ sqrtps(x, op);
+ }
+ void uni_vsqrtps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
+ vsqrtps(x, op);
+ }
+
+ void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op) {
+ assert(x1.getIdx() == x2.getIdx());
+ paddd(x2, op);
+ }
+ void uni_vpaddd(const Xbyak::Ymm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op) {
+ vpaddd(x1, x2, op);
+ }
+
+ void uni_vandps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op = Xbyak::Operand()) {
+ assert(x1.getIdx() == x2.getIdx());
+ andps(x1, op);
+ }
+ void uni_vandps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+ const Xbyak::Operand &op = Xbyak::Operand()) {
+ if (!mayiuse(avx512_common) || x1.getBit() < 512)
+ vandps(x1, x2, op);
+ else
+ vpandd(x1, x2, op);
+ }
+
+ void uni_vorps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op = Xbyak::Operand()) {
+ assert(x1.getIdx() == x2.getIdx());
+ orps(x1, op);
+ }
+ void uni_vorps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+ const Xbyak::Operand &op = Xbyak::Operand()) {
+ if (!mayiuse(avx512_common) || x1.getBit() < 512)
+ vorps(x1, x2, op);
+ else
+ vpord(x1, x2, op);
+ }
+
+ void uni_vpslld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
+ const int imm) {
+ assert(x.getIdx() == op.getIdx());
+ pslld(x, imm);
+ }
+ void uni_vpslld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
+ const int imm) {
+ vpslld(x, op, imm);
+ }
+
+ void uni_vpsrld(const Xbyak::Xmm &x, const Xbyak::Operand &op,
+ const int imm) {
+ assert(x.getIdx() == op.getIdx());
+ psrld(x, imm);
+ }
+ void uni_vpsrld(const Xbyak::Ymm &x, const Xbyak::Operand &op,
+ const int imm) {
+ vpsrld(x, op, imm);
+ }
+
+ void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ assert(x.getIdx() == op1.getIdx());
+ maxps(x, op2);
+ }
+ void uni_vmaxps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ vmaxps(x, op1, op2);
+ }
+
+ void uni_vminps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ assert(x.getIdx() == op1.getIdx());
+ minps(x, op2);
+ }
+ void uni_vminps(const Xbyak::Ymm &x, const Xbyak::Operand &op1,
+ const Xbyak::Operand &op2 = Xbyak::Operand()) {
+ vminps(x, op1, op2);
+ }
+
+ void uni_vcmpgtps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op) {
+ assert(x1.getIdx() == x2.getIdx());
+ cmpps(x1, op, _cmp_nle_us);
+ }
+
+ void uni_vcmpgtps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+ const Xbyak::Operand &op) {
+ vcmpgtps(x1, x2, op);
+ }
+
+ void uni_vcmpgeps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op) {
+ assert(x1.getIdx() == x2.getIdx());
+ cmpps(x1, op, _cmp_nlt_us);
+ }
+
+ void uni_vcmpgeps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+ const Xbyak::Operand &op) {
+ vcmpps(x1, x2, op, _cmp_nlt_us);
+ }
+
+ void uni_vtestps(const Xbyak::Xmm &x1, const Xbyak::Operand &op) {
+ ptest(x1, op);
+ }
+
+ void uni_vtestps(const Xbyak::Ymm &x1, const Xbyak::Operand &op) {
+ assert(!(x1.isZMM() || op.isZMM()));
+ vtestps(x1, op);
+ }
+
+ void uni_vblendvps(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+ const Xbyak::Operand &op, const Xbyak::Xmm &msk) {
+ assert(x1.getIdx() == x2.getIdx());
+ assert(msk.getIdx() == 0);
+ blendvps(x1, op);
+ }
+ void uni_vblendvps(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+ const Xbyak::Operand &op, const Xbyak::Ymm &msk) {
+ vblendvps(x1, x2, op, msk);
+ }
+
+ void uni_vroundps(const Xbyak::Xmm &x, const Xbyak::Operand &op,
+ const int imm) {
+ roundps(x, op, imm);
+ }
+ void uni_vroundps(const Xbyak::Ymm &x, const Xbyak::Operand &op,
+ const int imm) {
+ vroundps(x, op, imm);
+ }
+
+ void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+ cvtps2dq(x, op);
+ }
+ void uni_vcvtps2dq(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
+ vcvtps2dq(x, op);
+ }
+
+ void uni_vcvtdq2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+ cvtdq2ps(x, op);
+ }
+ void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
+ vcvtdq2ps(x, op);
+ }
+
+ void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Xmm &x2) {
+ movmskps(x1.cvt64(), x2);
+ }
+ void uni_vmovmskps(const Xbyak::Reg &x1, const Xbyak::Ymm &x2) {
+ vmovmskps(x1, x2);
+ }
+
+ void uni_vpackssdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
+ assert(x1.getIdx() == x2.getIdx());
+ packssdw(x1, op);
+ }
+ void uni_vpackssdw(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
+ vpackssdw(x1, x2, op);
+ }
+
+ void uni_vpackuswb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op){
+ assert(x1.getIdx() == x2.getIdx());
+ packuswb(x1, op);
+ }
+ void uni_vpackuswb(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op){
+ vpackuswb(x1, x2, op);
+ }
+
+
+ void mul_by_const(const Xbyak::Reg &out,
+ const Xbyak::Reg64 &tmp, int value) {
+ // Generates a shift + add sequence for multiplying the contents of the
+ // out register by a known JIT-time value. Clobbers the tmp register.
+ //
+ // Pros compared to mul/imul:
+ // - does not require using known registers
+ // - not microcoded on Intel(R) Xeon Phi(TM) processors
+ // Still, there are probably a lot of cases when mul/imul is faster on
+ // Intel(R) Core(TM) processors. Not intended for critical path.
+
+ // TODO: detect when overflow is imminent (Roma)
+ // TODO: detect when using mul/imul is a better option (Roma)
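+ //
+ // Example: value == 5 (binary 101) emits
+ //   tmp = out; out <<= 2; tmp += out; out = tmp;   // out = 5 * original out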
+
+ int p = 0; // the current power of 2
+ int old_p = 0; // the last seen power of 2 such that value[old_p] != 0
+
+ xor_(tmp, tmp);
+ while (value) {
+ if (value & 1) {
+ int shift = p - old_p;
+ if (shift) {
+ shl(out, shift);
+ old_p = p;
+ }
+ add(tmp, out);
+ }
+ value >>= 1;
+ p++;
+ }
+ mov(out, tmp);
+ }
+
+public:
+ jit_generator(
+ void *code_ptr = nullptr,
+ size_t code_size = 256 * 1024
+ ) : Xbyak::CodeGenerator(code_size, code_ptr)
+ {
+ }
+ virtual ~jit_generator() {}
+
+ virtual const char *name() const = 0;
+ virtual const char *source_file() const = 0;
+
+ const Xbyak::uint8 *getCode() {
+ const Xbyak::uint8 *code = CodeGenerator::getCode();
+ size_t code_size = getSize();
+ jit_utils::register_jit_code(code, code_size, name(), source_file());
+ return code;
+ }
+
+ template<typename F> const F getCode() {
+ return (const F)getCode();
+ }
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp
new file mode 100644
index 0000000000..56d7f592e2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_primitive_conf.hpp
@@ -0,0 +1,481 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_PRIMITIVE_CONF_HPP
+#define JIT_PRIMITIVE_CONF_HPP
+
+#include <stdint.h>
+
+#include "common/primitive_attr.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+/* convolution */
+enum conv_version_t {ver_unused, ver_fma, ver_avx512_core, ver_4fma, ver_vnni};
+enum conv_loop_order_t {loop_cgn, loop_gnc, loop_ngc, loop_gncw, loop_cwgn,
+ loop_ngcw, loop_nhwcg, loop_nwcg};
+enum conv_1x1_loop_order_t {loop_rbl, loop_rlb, loop_lbr, loop_lrb, loop_blr,
+ loop_brl};
+enum conv_kernel_kind_t {embd_bcast, expl_bcast};
+
+enum {
+ FLAG_MB_FIRST = 1 << 0, FLAG_MB_LAST = 1 << 1,
+ FLAG_OC_FIRST = 1 << 2, FLAG_OC_LAST = 1 << 3,
+ FLAG_IC_FIRST = 1 << 4, FLAG_IC_LAST = 1 << 5,
+ FLAG_SP_FIRST = 1 << 6, FLAG_SP_LAST = 1 << 7,
+ FLAG_REDUCE_FIRST = 1<<8, FLAG_REDUCE_LAST = 1<<9,
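+ /* Note: the flags below reuse the low bit positions above; they appear to
+ be consumed by a separate driver path (e.g. via the exec_flags field of
+ jit_dw_conv_call_s) rather than combined with the MB/OC/IC/SP/REDUCE
+ flags. */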
+ FLAG_ZERO_FILTER = 1 << 0, /* Controls whether the inner kernel skips
+ loading weights-data from memory; this
+ needs to happen on the first Group/16
+ iteration. */
+ FLAG_ZERO_BIAS = 1 << 1, /* Controls whether the inner kernel skips
+ loading bias data from memory */
+ FLAG_COMPUTE_BIAS = 1 << 2, /* Controls bias computation during execution
+ pass */
+};
+
+struct jit_conv_conf_t {
+ prop_kind_t prop_kind;
+ conv_version_t ver;
+ conv_loop_order_t loop_order;
+
+ int simd_w;
+ int ndims;
+ int mb;
+ int ngroups, ic, oc, oc_without_padding, ic_without_padding;
+ int id, ih, iw, od, oh, ow;
+ int f_pad, l_pad, t_pad;
+ int back_pad, r_pad, b_pad;
+ int kd, kh, kw;
+ int stride_d, stride_h, stride_w;
+ int dilate_d, dilate_h, dilate_w;
+ format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
+ bool with_bias;
+ bool with_sum;
+ bool with_eltwise;
+
+ post_ops_t::entry_t::eltwise_t eltwise;
+
+ int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
+
+ int idp, ihp, iwp, ohp, owp;
+ int nb_ic, ic_block;
+ int nb_oc, oc_block;
+ int nb_ow, ow_block;
+ int nb_oc_blocking; /* used in jit kernels for nb_oc work blocking, taking
+ into account vector registers distribution */
+ int nb_oc_blocking_thr_chunk; /* used for distribution of nb_oc work
+ within threads */
+ int nb_ic_blocking, nb_ic_blocking_max; // blocking of nb_ic work
+ int nb_ic_L2;
+ int h_blocking;
+ int nb_oc_L2;
+ int ur_h, ur_w;
+ int ur_w_tail;
+ bool is_1stconv;
+ int nonblk_group_off;
+ /* fma avx512_core */
+ conv_kernel_kind_t kernel_kind;
+ /* 4fma */
+ int tr_iw;
+ int tr_src_num_guard_elems;
+ /* 1st conv: 4fma */
+ int tr_ld;
+ int kh_step;
+ /* 4vnni */
+ int typesize_in;
+ int typesize_out;
+ int typesize_bia;
+ int typesize_acc;
+ /* avx512_u8s8u8 */
+ int ic_nb1, ic_nb2;
+ int oc_nb1;
+ int ur_ow_max, ur_ow, ur_ow_tail;
+ int ur_ow_nsteps;
+ data_type_t bia_dt;
+ data_type_t dst_dt;
+ /* avx512: max possible value is nregs(32) - aux_regs(4) */
+ int src_offsets[28];
+ int src_count;
+ bool expl_bcast;
+ bool large_spatial;
+ int is_oc_scale;
+ int max_regs_ur; // maximum accumulation registers
+ // dw conv
+ int nb_ch, ch_block, nb_ch_blocking;
+ bool is_depthwise, is_fast_depthwise, is_resrc_depthwise;
+ int aligned_threads;
+ // large spatial
+ int oh_blk_size;
+ // s8s8 convolution
+ bool signed_input;
+ float wei_adj_scale;
+};
+
+struct jit_conv_conf_2x3_wino_t {
+ conv_version_t ver;
+
+ int m;
+ int r;
+ int alpha;
+ int tile_h, tile_w;
+
+ int mb;
+ int ngroups, ic, oc, oc_without_padding;
+ int ih, iw, oh, ow;
+ int l_pad, t_pad;
+ int r_pad, b_pad;
+ int kh, kw;
+ int stride_h, stride_w;
+ int dilate_h, dilate_w;
+
+ int nb_ic, ic_block;
+ int nb_oc, oc_block;
+
+ int w_block_size, h_block_size;
+
+ data_type_t bia_dt;
+ data_type_t dst_dt;
+
+ int is_oc_scale;
+ int typesize_in;
+ int typesize_out;
+ int typesize_bia;
+ int typesize_acc;
+
+ format_tag_t src_tag, dst_tag; // temporary workaround
+ bool with_bias;
+ bool small_mb;
+
+ int xb, yb;
+ int inp_stride;
+ int out_stride;
+ int wei_stride;
+ int bia_stride;
+
+ int M, N, K;
+ int m_block, n_block, k_block;
+ int n2_block, n_chunks;
+ int k2_block, k_chunks;
+
+ int mb_block, nb_mb;
+
+ size_t size_wino_src, size_wino_wei, size_wino_dst;
+
+ int nthr;
+};
+
+/*
+ Winograd sched policy:
+
+ Computation Unit:
+ W: weights transform
+ S: src transform
+ D: dst transform
+ G: gemm
+
+ Thread grouping by:
+ i: nb_ic
+ o: nb_oc
+ t: tile_block
+ e: element in tile
+
+ Note: 'i' and 'o' are omitted if
+ i. not combined with t or
+ ii. with discrete transforms
+
+ Current policies supported:
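+
+ e.g. WSCHED_WEI_SDGt_W reads (per the convention above) roughly as: src/dst
+ transforms and gemm grouped per tile_block ('SDGt'), with the weights
+ transform ('W') as a separate pass.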
+*/
+enum winograd_sched_t {
+ WSCHED_INVALID = 0,
+
+ /* Forward & backward-data */
+ /* W_S_G_D implements discrete transforms */
+ WSCHED_DATA_W_S_G_D,
+ /* W_SGD implements tiled transforms s.t. GEMM could reuse data in L2 */
+ WSCHED_DATA_W_SGD,
+
+ /* Backward-weights */
+ WSCHED_WEI_S_D_G_W,
+ WSCHED_WEI_SDGtWo,
+ WSCHED_WEI_S_D_Giot_W,
+ WSCHED_WEI_SDGt_W,
+};
+
+struct jit_conv_winograd_conf_t : public jit_conv_conf_t {
+ int itiles;
+ int jtiles;
+ int ntiles;
+ int ic_simd_block=16;
+ int tile_4fma_padding;
+ int tile_4fma;
+ int oc_simd_block=16;
+ int oc_reg_block;
+ int ic_reg_block;
+ int tile_block;
+ int tile_block_ur;
+ int nb_tile_block_ur;
+
+ bool double_buffering;
+ bool with_relu_postsum;
+ int zmm_start;
+ int nb_reg;
+
+ int dimK;
+ int dimK_4fma;
+ int dimK_reg_block;
+ int dimK_block;
+ int dimK_nb_block;
+
+ int dimM;
+ int dimM_reg_block;
+ int dimM_simd_block;
+ int dimM_block;
+ int dimM_nb_block;
+
+ int dimN;
+ int dimN_reg_block;
+ int dimN_bcast_ur;
+ int dimN_block;
+ int dimN_nb_block;
+
+ winograd_sched_t sched_policy;
+};
+
+struct jit_conv_call_s {
+ const void *src; /* hack, non-const for backward_data */
+ const void *dst; /* hack, non-const for forward */
+ const void *filt; /* hack, non-const for backward_weights */
+ const void *bias; /* hack, non-const for backward_bias */
+ const void *src_prf;
+ const void *dst_prf;
+ const void *filt_prf;
+ const void *bias_prf;
+ const void *scales;
+ const void *acc_s32;
+ const void *compensation;
+ size_t kd_offset;
+ size_t kd_offset_prf;
+ size_t d_index;
+ size_t d_index_prf;
+ size_t d_worksize;
+ size_t d_worksize_prf;
+ size_t kd_padding;
+ size_t kd_padding_prf;
+ size_t kh_padding;
+ size_t kh_padding_prf;
+ size_t owb;
+ size_t owb_prf;
+ size_t kw_padding;
+ size_t channel;
+ size_t channel_prf;
+ size_t oc_blocks;
+ size_t ur_w;
+ size_t ur_str_w;
+ size_t ch_blocks;
+ size_t t_overflow;
+ size_t b_overflow;
+ int flags;
+};
+
+struct jit_deconv_call_s {
+ const void *src; /* hack, non-const for backward_data */
+ const void *dst; /* hack, non-const for forward */
+ const void *filt; /* hack, non-const for backward_weights */
+ const void *bias; /* hack, non-const for backward_bias */
+ const void *scales;
+ const void *compensation;
+ size_t t_overflow;
+ size_t b_overflow;
+ size_t kh_padding;
+ size_t oc_blocks;
+};
+
+struct jit_dw_conv_call_s {
+ const void *input;
+ const void *output;
+ const void *filter;
+ const void *bias;
+ size_t kh_count;
+ size_t oh_count;
+ size_t oh_index;
+ size_t filter_pad_off;
+ unsigned char
+ exec_flags; /* Flags passed by the driver to the inner kernel */
+};
+
+struct jit_wino_transform_call_s {
+ size_t tile_block;
+ size_t tile_block_ur;
+ size_t nb_tile_block_ur;
+ size_t tile_count;
+ size_t tj;
+ size_t ti;
+ void *src;
+ void *dst;
+ void *Mw;
+ void *M;
+ void *T;
+ void *G;
+ void *bias;
+};
+
+struct jit_1x1_conv_conf_t {
+ prop_kind_t prop_kind;
+ conv_version_t ver;
+
+ int mb;
+ int ngroups, ic, oc, oc_without_padding, ic_without_padding;
+ int iw, ih, ow, oh;
+ int l_pad, t_pad;
+ int kh, kw;
+ int stride_h, stride_w;
+ format_tag_t src_tag, wei_tag, dst_tag; // temporary workaround
+ bool with_bias;
+ bool with_sum;
+ bool with_eltwise;
+
+ post_ops_t::entry_t::eltwise_t eltwise;
+
+ int is, os;
+ int ic_block, oc_block;
+
+ int ur, ur_tail;
+
+ int reduce_dim, reduce_block, nb_reduce,
+ nb_reduce_blocking, nb_reduce_blocking_max;
+ int load_dim, load_block, nb_load,
+ nb_load_blocking, nb_load_blocking_max, nb_load_chunk;
+ int bcast_dim, bcast_block, nb_bcast,
+ nb_bcast_blocking, nb_bcast_blocking_max;
+
+ int reduce_loop_unroll, reduce_loop_bcast_step, reduce_loop_load_step;
+ int load_loop_load_step, load_loop_iter_step;
+ int bcast_loop_output_step, bcast_loop_output_substep;
+ int bcast_loop_bcast_step, bcast_loop_bcast_substep;
+ int fma_step;
+ int load_grp_count;
+ conv_1x1_loop_order_t loop_order;
+ bool use_vmovntps;
+ /* avx512 core */
+ bool expl_bcast;
+ /* 4vnni */
+ int typesize_in;
+ int typesize_out;
+ int typesize_bia;
+ int typesize_acc;
+ /* 4fma */
+ bool transpose_src;
+ int tr_is;
+ int nthr, nthr_mb, nthr_g, nthr_oc_b, nthr_ic_b;
+ int is_oc_scale;
+ data_type_t bia_dt;
+ data_type_t dst_dt;
+ bool signed_input;
+ float wei_adj_scale;
+};
+
+struct jit_gemm_conv_conf_t {
+ prop_kind_t prop_kind;
+
+ int mb;
+ int ngroups, ic, oc;
+ int iw, ih, id, ow, oh, od;
+ int l_pad, t_pad, f_pad;
+ int kh, kw, kd;
+ int stride_h, stride_w, stride_d;
+ int dilate_h, dilate_w, dilate_d;
+ bool with_bias;
+
+ int is, os, ks;
+ int ic_block, oc_block;
+
+ int nthr;
+ ptrdiff_t im2col_sz;
+ bool need_wei_reduction;
+ bool signed_input;
+ int oh_block;
+ int ow_block;
+ bool outer_threading;
+};
+
+struct jit_1x1_conv_call_s {
+ const void *bcast_data;
+ const void *load_data;
+ const void *output_data;
+ const void *bias_data; // used in forward and backward_weights only
+ const void *acc_s32;
+ const void *scales;
+ const void *compensation;
+
+ size_t load_dim;
+ size_t bcast_dim;
+ size_t reduce_dim;
+
+ size_t output_stride; // used in backward_weights only
+
+ size_t first_last_flag;
+};
+
+/* pooling */
+struct jit_pool_conf_t {
+ int ndims;
+ int mb, c;
+ int id, ih, iw, od, oh, ow;
+ int stride_d, stride_h, stride_w;
+ int kd, kh, kw;
+ int f_pad, t_pad, l_pad;
+ alg_kind_t alg;
+ bool is_training;
+ bool pad_w_is_null;
+ bool is_backward;
+ bool simple_alg;
+ data_type_t ind_dt;
+
+ int c_block, c_tail, nb_c;
+ int ur_c, ur_c_tail;
+ int ur_w;
+ int ur_w_tail;
+ size_t tail[4];
+ data_type_t src_dt;
+ data_type_t dst_dt;
+};
+
+struct jit_pool_call_s {
+ const float *src;
+ const float *dst;
+ const void *indices;
+ const float *src_prf;
+ const float *dst_prf;
+ const void *indices_prf;
+ size_t oh;
+ size_t kd_padding;
+ size_t kh_padding;
+ size_t kh_padding_shift;
+ size_t kd_padding_shift;
+ size_t kw_padding;
+ const float* init_value;
+ float ker_area_h;
+};
+
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp
new file mode 100644
index 0000000000..94d2101d6e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.cpp
@@ -0,0 +1,677 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "cpu_memory.hpp"
+
+#include "jit_sse42_1x1_conv_kernel_f32.hpp"
+
+#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::utils;
+
+using namespace Xbyak;
+
+void jit_sse42_1x1_conv_kernel_f32::generate_bcast_loop(int load_loop_blk)
+{
+ mov(aux1_reg_bcast_data, reg_bcast_data);
+ mov(aux_reg_output_data, reg_output_data);
+ mov(bcast_loop_iter, reg_bcast_loop_work);
+
+ Label bcast_loop;
+ Label bcast_loop_tail;
+
+ cmp(bcast_loop_iter, jcp.ur);
+ jl(bcast_loop_tail, T_NEAR);
+
+ L(bcast_loop); {
+ assert(jcp.bcast_block % jcp.ur == 0);
+ int num_substeps = jcp.bcast_block / jcp.ur;
+ assert(num_substeps > 0 && num_substeps < 10);
+ for (int i = 0; i < num_substeps; i++) {
+ generate_reduce_loop(load_loop_blk, jcp.ur);
+ if (i < num_substeps - 1) {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
+ add(aux_reg_output_data, jcp.bcast_loop_output_substep);
+ } else {
+ add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
+ - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
+ add(aux_reg_output_data, jcp.bcast_loop_output_step
+ - (num_substeps - 1) * jcp.bcast_loop_output_substep);
+ }
+ }
+ sub(bcast_loop_iter, jcp.bcast_block);
+ cmp(bcast_loop_iter, jcp.bcast_block);
+ jge(bcast_loop, T_NEAR);
+ }
+
+ L(bcast_loop_tail);
+ if (jcp.ur_tail) {
+ Label bcast_loop_tail_out;
+ cmp(bcast_loop_iter, 0);
+ jz(bcast_loop_tail_out, T_NEAR);
+ generate_reduce_loop(load_loop_blk, jcp.ur_tail);
+ L(bcast_loop_tail_out);
+ }
+}
+
+void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop(
+ int load_loop_blk, int ur)
+{
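+ // Register layout used below: xmm1 .. xmm(2*ur*load_loop_blk) hold the
+ // accumulators (two XMMs per 8-wide oc block), the next 2*load_loop_blk
+ // XMMs hold the current weights, and xmm15 (reg_bcast) holds the
+ // broadcast input value.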
+ auto reg_load = [=](int i, int n) {
+ return Xmm(2*ur * load_loop_blk + 2*i + n + 1);
+ };
+
+ auto reg_accum = [=](int i, int j, int n) {
+ return Xmm(2*j * load_loop_blk + 2*i + n + 1);
+ };
+
+ auto bias_ptr = [=](int i, int n) {
+ return ptr[reg_bias_data + sizeof(float) * jcp.oc_block * i + n*4*sizeof(float)];
+ };
+
+ auto bcast_ptr = [=](int u, int j) {
+ assert(j < jcp.ur);
+ assert(u <= jcp.reduce_loop_unroll);
+ size_t offt;
+ if (one_of(jcp.prop_kind,
+ forward_training, forward_inference, backward_data)) {
+ assert(jcp.reduce_loop_unroll == ((jcp.prop_kind == backward_data)
+ ? jcp.oc_block : jcp.ic_block));
+ auto height = (jcp.prop_kind == backward_data) ? jcp.os : jcp.is;
+ offt = (u == jcp.reduce_loop_unroll)
+ ? (height + j) * jcp.reduce_loop_unroll
+ : j * jcp.reduce_loop_unroll + u;
+ } else
+ offt = u * jcp.ic_block + j;
+ return ptr[aux_reg_bcast_data + sizeof(float) * offt];
+ };
+
+ auto load_ptr = [=](int u, int i, int n) {
+ size_t offt;
+ size_t u0 = u % jcp.reduce_loop_unroll;
+ size_t u1 = u / jcp.reduce_loop_unroll;
+ switch (jcp.prop_kind) {
+ case backward_data:
+ offt = (i * jcp.oc_block + u0) * jcp.ic_block;
+ break;
+ case backward_weights:
+ offt = (i * jcp.os + u0) * jcp.oc_block;
+ break;
+ default:
+ offt = (i * jcp.ic + u0) * jcp.oc_block;
+ }
+ return ptr[aux_reg_load_data
+ + u1 * jcp.reduce_loop_load_step + sizeof(float) * offt + n * 4 * sizeof(float)];
+ };
+
+ auto output_ptr = [=](int i, int j, int n) {
+ switch (jcp.prop_kind) {
+ case backward_data:
+ return ptr[aux_reg_output_data +
+ (i * jcp.is + j) * jcp.ic_block * sizeof(float) + n * 4 * sizeof(float)];
+ case backward_weights:
+ return ptr[aux_reg_output_data
+ + (i ? reg_output_stride * i : 0) // TODO: Xbyak should allow 0 scale
+ + sizeof(float) * jcp.oc_block * j + n * 4 * sizeof(float)];
+ default:
+ return ptr[aux_reg_output_data +
+ (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
+ }
+ };
+
+ auto init = [=]() {
+ Label init_done;
+ Label init_zero;
+
+ if (jcp.with_bias && one_of(jcp.prop_kind, forward_training,
+ forward_inference)) {
+ test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
+ jz(init_zero);
+
+ for (int i = 0; i < load_loop_blk; i++)
+ for (int j = 0; j < ur; ++j) {
+ movups(reg_accum(i, j, 0), bias_ptr(i, 0));
+ movups(reg_accum(i, j, 1), bias_ptr(i, 1));
+ }
+ jmp(init_done);
+ }
+
+ L(init_zero);
+ for (int i = 0; i < load_loop_blk; ++i)
+ for (int j = 0; j < ur; ++j) {
+ auto r0 = reg_accum(i, j, 0);
+ auto r1 = reg_accum(i, j, 1);
+ xorps(r0, r0);
+ xorps(r1, r1);
+ }
+
+ L(init_done);
+
+ // load weights
+ for (int i = 0; i < load_loop_blk; ++i) {
+ movups(reg_load(i, 0), load_ptr(0, i, 0));
+ movups(reg_load(i, 1), load_ptr(0, i, 1));
+ }
+
+ movss(reg_bcast, bcast_ptr(0, 0));
+ shufps(reg_bcast, reg_bcast, 0);
+ }; // init()
+
+ auto store = [=]() {
+ Label store_noadd;
+
+ if (!jcp.with_sum) {
+ test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
+ jnz(store_noadd, T_NEAR);
+ }
+
+ for (int j = 0; j < ur; ++j)
+ for (int i = 0; i < load_loop_blk; ++i) {
+ auto r0 = reg_accum(i, j, 0);
+ auto r1 = reg_accum(i, j, 1);
+ addps(r0, output_ptr(i, j, 0));
+ addps(r1, output_ptr(i, j, 1));
+ }
+
+ L(store_noadd);
+
+ if (jcp.with_eltwise) {
+ assert(ur * load_loop_blk < 14);
+
+ Label store_norelu;
+ test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
+ jz(store_norelu, T_NEAR);
+
+ eltwise_injector_->compute_vector_range(1,
+ 2 * ur * load_loop_blk + 1);
+
+ L(store_norelu);
+ }
+
+ for (int j = 0; j < ur; ++j)
+ for (int i = 0; i < load_loop_blk; ++i) {
+ movups(output_ptr(i, j, 0), reg_accum(i, j, 0));
+ movups(output_ptr(i, j, 1), reg_accum(i, j, 1));
+ }
+ };
+
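+ // Note: SSE has no non-destructive 3-operand FMA, so fma_block multiplies
+ // the weight registers in place and reloads them (for the next u) once the
+ // last ur iteration has consumed them.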
+ auto fma_block = [=](bool last_block) {
+ for (int u = 0; u < jcp.reduce_loop_unroll; ++u) {
+ for (int j = 0; j < ur; ++j) {
+ for (int i = 0; i < load_loop_blk; ++i) {
+ mulps(reg_load(i, 0), reg_bcast);
+ mulps(reg_load(i, 1), reg_bcast);
+ addps(reg_accum(i, j, 0), reg_load(i, 0));
+ addps(reg_accum(i, j, 1), reg_load(i, 1));
+
+ if (j == ur - 1 && !(last_block && u == jcp.reduce_loop_unroll - 1)) {
+ movups(reg_load(i, 0), load_ptr(u + 1, i, 0));
+ movups(reg_load(i, 1), load_ptr(u + 1, i, 1));
+ }
+ }
+ if (j < ur - 1) {
+ movss(reg_bcast, bcast_ptr(u, j + 1));
+ shufps(reg_bcast, reg_bcast, 0);
+ }
+ } // for ur
+ if (!last_block || u < jcp.reduce_loop_unroll - 1) {
+ movss(reg_bcast, bcast_ptr(u + 1, 0));
+ shufps(reg_bcast, reg_bcast, 0);
+ }
+ } // for reduce_loop_unroll
+ };
+
+ Label reduce_loop;
+ Label reduce_loop_tail;
+
+ mov(aux_reg_load_data, reg_load_data);
+ mov(aux_reg_bcast_data, aux1_reg_bcast_data);
+
+ init();
+
+ mov(reduce_loop_iter, reg_reduce_loop_work);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jle(reduce_loop_tail, T_NEAR);
+
+ L(reduce_loop); {
+ fma_block(false);
+ add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
+ add(aux_reg_load_data, jcp.reduce_loop_load_step);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jg(reduce_loop, T_NEAR);
+ }
+
+ L(reduce_loop_tail);
+ fma_block(true);
+
+ store();
+} // reduce_loop()
+
+void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
+{
+ if (!jcp.with_bias || jcp.prop_kind != backward_weights)
+ return;
+
+ Label diff_bias_loop, diff_bias_loop_out, diff_bias_init_out;
+ Label diff_bias_load;
+
+ auto diff_bias_ptr = [=](int i, int n) {
+ return ptr[reg_diff_bias_data + i * jcp.oc_block * sizeof(float)+ 4*n*sizeof(float)];
+ };
+
+ auto load_ptr = [=](int u, int i, int n) {
+ return ptr[aux_reg_load_data
+ + (i * jcp.os + u) * jcp.oc_block * sizeof(float) + 4*n*sizeof(float)];
+ };
+
+ auto diff_bias_reg = [=](int i, int n) { return Xmm(2*i + n + 1); };
+
+ mov(reg_diff_bias_data, ptr[rsp + reg_diff_bias_data_stack_offt]);
+ cmp(reg_diff_bias_data, 0);
+ je(diff_bias_loop_out, T_NEAR);
+
+ test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
+ jz(diff_bias_load, T_NEAR);
+
+ for (int i = 0; i < load_loop_blk; ++i) {
+ auto r0 = diff_bias_reg(i, 0);
+ auto r1 = diff_bias_reg(i, 1);
+ xorps(r0, r0);
+ xorps(r1, r1);
+ }
+ jmp(diff_bias_init_out, T_NEAR);
+
+ L(diff_bias_load);
+ for (int i = 0; i < load_loop_blk; ++i) {
+ movups(diff_bias_reg(i, 0), diff_bias_ptr(i, 0));
+ movups(diff_bias_reg(i, 1), diff_bias_ptr(i, 1));
+ }
+
+ L(diff_bias_init_out);
+ mov(aux_reg_load_data, reg_load_data);
+ mov(reduce_loop_iter, reg_reduce_loop_work);
+ L(diff_bias_loop); {
+ for(int u = 0; u < jcp.reduce_loop_unroll; ++u)
+ for (int i = 0; i < load_loop_blk; ++i) {
+ addps(diff_bias_reg(i, 0), load_ptr(u, i, 0));
+ addps(diff_bias_reg(i, 1), load_ptr(u, i, 1));
+ }
+ assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
+ add(aux_reg_load_data, jcp.reduce_loop_load_step);
+ sub(reduce_loop_iter, jcp.reduce_loop_unroll);
+ jnz(diff_bias_loop, T_NEAR);
+ }
+
+ for (int i = 0; i < load_loop_blk; i++) {
+ movups(diff_bias_ptr(i, 0), diff_bias_reg(i, 0));
+ movups(diff_bias_ptr(i, 1), diff_bias_reg(i, 1));
+ }
+
+ add(reg_diff_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
+ mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
+
+ L(diff_bias_loop_out);
+}
+
+void jit_sse42_1x1_conv_kernel_f32::generate()
+{
+ preamble();
+
+ mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
+ mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
+ mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
+ if (jcp.with_bias) {
+ if (jcp.prop_kind == backward_weights) {
+ sub(rsp, stack_space_needed);
+ mov(reg_diff_bias_data, ptr[param1 + GET_OFF(bias_data)]);
+ mov(ptr[rsp + reg_diff_bias_data_stack_offt], reg_diff_bias_data);
+ } else
+ mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
+ }
+
+ mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
+ mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
+ mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
+ mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
+ if (jcp.prop_kind == backward_weights)
+ mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
+
+ auto generate_load_loop_body = [=] (int load_loop_blk) {
+ generate_bcast_loop(load_loop_blk);
+ add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
+ switch (jcp.prop_kind) {
+ case forward_training:
+ case forward_inference:
+ add(reg_bias_data, load_loop_blk * jcp.oc_block * sizeof(float));
+ add(reg_output_data,
+ load_loop_blk * jcp.os * jcp.oc_block * sizeof(float));
+ break;
+ case backward_data:
+ add(reg_output_data,
+ load_loop_blk * jcp.is * jcp.ic_block * sizeof(float));
+ break;
+ case backward_weights:
+ for (int i = 0; i < load_loop_blk; i++)
+ add(reg_output_data, reg_output_stride);
+ break;
+ default:
+ assert(!"invalid prop_kind");
+ }
+ sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+ };
+
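+ // Dispatch on the remaining load (oc) work: with oc_block = 8, the
+ // 24/16/8 thresholds below select bodies that carry 3, 2 or 1 output
+ // channel blocks per load-loop iteration.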
+ Label load_loop_blk_8;
+ Label load_loop_blk_16;
+ Label load_loop_blk_24;
+ Label load_loop_blk_end;
+
+ cmp(reg_load_loop_work, 8);
+ jle(load_loop_blk_8, T_NEAR);
+
+ cmp(reg_load_loop_work, 32);
+ je(load_loop_blk_16, T_NEAR);
+
+ cmp(reg_load_loop_work, 16);
+ jle(load_loop_blk_16, T_NEAR);
+
+ L(load_loop_blk_24); {
+ generate_diff_bias_loop(3);
+ generate_load_loop_body(3);
+ cmp(reg_load_loop_work, 32);
+ je(load_loop_blk_16);
+ cmp(reg_load_loop_work, 24);
+ jge(load_loop_blk_24);
+ }
+
+ cmp(reg_load_loop_work, 8);
+ jle(load_loop_blk_8, T_NEAR);
+
+ L(load_loop_blk_16); {
+ generate_diff_bias_loop(2);
+ generate_load_loop_body(2);
+ cmp(reg_load_loop_work, 16);
+ jge(load_loop_blk_16);
+ }
+
+ L(load_loop_blk_8); {
+ cmp(reg_load_loop_work, 0);
+ je(load_loop_blk_end, T_NEAR);
+ generate_diff_bias_loop(1);
+ generate_load_loop_body(1);
+ }
+
+ L(load_loop_blk_end);
+
+ if (jcp.with_bias && jcp.prop_kind == backward_weights)
+ add(rsp, stack_space_needed);
+
+ postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+}
+
+bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok(
+ jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+ switch (p.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
+ case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
+ default: return false;
+ }
+
+ return false;
+}
+
+status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr)
+{
+ if (!mayiuse(sse42))
+ return status::unimplemented;
+
+ // TODO (Roma): this code is duplicated from the generic kernel; maybe the
+ // configuration struct could do some stuff below
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ const int ndims = src_d.ndims();
+
+ jcp.prop_kind = cd.prop_kind;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+
+ jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
+ jcp.iw = src_d.dims()[ndims - 1];
+ jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[ndims - 1];
+
+ jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
+ jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][ndims - 3];
+
+ jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
+ jcp.stride_w = cd.strides[ndims - 3];
+
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ jcp.os = jcp.oh * jcp.ow;
+ jcp.is = jcp.ih * jcp.iw;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ jcp.with_sum = p.find(primitive_kind::sum) != -1;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise)
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+
+ const int is_bwd_d = jcp.prop_kind == backward_data;
+
+ format_tag_t dat_tag = ndims == 3 ? nCw8c : nChw8c;
+ format_tag_t wei_tag = with_groups
+ ? utils::pick(2 * ndims - 6 + is_bwd_d, gOIw8i8o, gOIw8o8i, gOIhw8i8o,
+ gOIhw8o8i)
+ : utils::pick(2 * ndims - 6 + is_bwd_d, OIw8i8o, OIw8o8i, OIhw8i8o,
+ OIhw8o8i);
+
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+
+ bool args_ok = true
+ && jcp.ngroups == 1
+ && jcp.src_tag == dat_tag
+ && jcp.wei_tag == wei_tag
+ && jcp.dst_tag == dat_tag;
+ if (!args_ok) return status::unimplemented;
+
+ const int simd_w = 4;
+ jcp.ic_block = jcp.oc_block = simd_w*2;
+
+ args_ok = true
+ && jcp.oc % jcp.oc_block == 0
+ && jcp.ic % jcp.ic_block == 0
+ && jcp.t_pad == 0 && jcp.l_pad == 0
+ && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
+ && jcp.kh == 1 && jcp.kw == 1;
+ if (!args_ok) return status::unimplemented;
+
+ jcp.ur = 1;
+
+ int load_blocking{ 0 };
+ int load_blocking_max{ 0 };
+ int bcast_blocking{ 0 };
+ int bcast_blocking_max{ 0 };
+ int reduce_blocking{ 0 };
+
+ if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
+ jcp.reduce_dim = jcp.ic;
+ jcp.reduce_block = jcp.ic_block;
+
+ jcp.load_dim = jcp.oc;
+ jcp.load_block = jcp.oc_block;
+
+ jcp.bcast_dim = jcp.is;
+ jcp.bcast_block = jcp.ur;
+
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.is * sizeof(float);
+ jcp.reduce_loop_load_step
+ = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
+
+ jcp.bcast_loop_output_step = jcp.ur * jcp.oc_block * sizeof(float);
+ jcp.bcast_loop_output_substep = -1; // unused
+ jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_block * sizeof(float);
+ jcp.bcast_loop_bcast_substep = -1; // unused
+
+ jcp.load_loop_load_step = jcp.ic * jcp.oc_block * sizeof(float);
+ jcp.load_loop_iter_step = jcp.oc_block;
+
+ load_blocking = 120; // assumes the kernel is jcp.ur x 3
+ load_blocking_max = 144;
+ bcast_blocking = 128; // affects load balancing across threads
+ bcast_blocking_max = 192;
+ reduce_blocking = 128; // affects L1$ utilization
+ } else if (jcp.prop_kind == backward_data) {
+ jcp.reduce_dim = jcp.oc;
+ jcp.reduce_block = jcp.oc_block;
+
+ jcp.load_dim = jcp.ic;
+ jcp.load_block = jcp.oc_block;
+
+ jcp.bcast_dim = jcp.os;
+ jcp.bcast_block = jcp.ur;
+
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.os * sizeof(float);
+ jcp.reduce_loop_load_step
+ = jcp.reduce_loop_unroll * jcp.ic * sizeof(float);
+
+ jcp.bcast_loop_output_step = jcp.ur * jcp.ic_block * sizeof(float);
+ jcp.bcast_loop_output_substep = -1; // unused
+ jcp.bcast_loop_bcast_step = jcp.ur * jcp.oc_block * sizeof(float);
+ jcp.bcast_loop_bcast_substep = -1; // unused
+
+ jcp.load_loop_load_step = jcp.oc_block * jcp.ic_block * sizeof(float);
+ jcp.load_loop_iter_step = jcp.ic_block;
+
+ load_blocking = 96; // assumes the kernel is jcp.ur x 3
+ load_blocking_max = 144;
+ bcast_blocking = 128; // affects load balancing across threads
+ bcast_blocking_max = 196;
+ reduce_blocking = 64; // affects L1$ utilization
+ } else if (jcp.prop_kind == backward_weights) {
+ jcp.reduce_dim = jcp.os;
+ jcp.reduce_block = 1;
+
+ jcp.load_dim = jcp.oc;
+ jcp.load_block = jcp.oc_block;
+
+ jcp.bcast_dim = jcp.ic;
+ jcp.bcast_block = jcp.ic_block;
+
+ jcp.reduce_loop_unroll = jcp.reduce_block;
+ jcp.reduce_loop_bcast_step
+ = jcp.reduce_loop_unroll * jcp.ic_block * sizeof(float);
+ jcp.reduce_loop_load_step
+ = jcp.reduce_loop_unroll * jcp.oc_block * sizeof(float);
+
+ jcp.bcast_loop_output_step = jcp.oc_block * jcp.ic_block * sizeof(float);
+ jcp.bcast_loop_output_substep = jcp.oc_block * jcp.ur * sizeof(float);
+ jcp.bcast_loop_bcast_step = jcp.ic_block * jcp.is * sizeof(float);
+ jcp.bcast_loop_bcast_substep = jcp.ur * sizeof(float);
+
+ jcp.load_loop_load_step = jcp.oc_block * jcp.os * sizeof(float);
+ jcp.load_loop_iter_step = jcp.oc_block;
+
+ /* --- */
+
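+ // Heuristic: shrink the blocking (counted in load_block / bcast_block
+ // units) by factors of 2 and 3 until it is small enough, so that the
+ // final blocking still divides load_dim / bcast_dim exactly (see the
+ // asserts below).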
+ load_blocking = div_up(jcp.load_dim, jcp.load_block);
+ while (true) {
+ if (load_blocking <= 32) break;
+ else if (load_blocking % 2 == 0) load_blocking /= 2;
+ else if (load_blocking % 3 == 0) load_blocking /= 3;
+ else break;
+ }
+ load_blocking *= jcp.load_block;
+ load_blocking_max = load_blocking;
+ assert(jcp.load_dim % load_blocking == 0);
+
+ bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
+ while (true) {
+ if (bcast_blocking <= 9) break;
+ else if (bcast_blocking % 2 == 0) bcast_blocking /= 2;
+ else if (bcast_blocking % 3 == 0) bcast_blocking /= 3;
+ else break;
+ }
+ bcast_blocking *= jcp.bcast_block;
+ bcast_blocking_max = bcast_blocking;
+ assert(jcp.bcast_dim % bcast_blocking == 0);
+
+ reduce_blocking = 128; // affects L1$ utilization
+ } else
+ return status::unimplemented;
+
+ assert(load_blocking);
+ assert(load_blocking_max);
+ assert(bcast_blocking);
+ assert(bcast_blocking_max);
+ assert(reduce_blocking);
+
+ assert(jcp.bcast_block % jcp.ur == 0);
+ jcp.ur_tail = jcp.bcast_dim % jcp.ur;
+
+ jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
+ jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
+ jcp.nb_load_blocking = load_blocking / jcp.load_block;
+ jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
+ jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
+
+ jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
+ jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
+ jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
+
+ return status::success;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp
new file mode 100644
index 0000000000..b314a5098c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_conv_kernel_f32.hpp
@@ -0,0 +1,104 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_SSE42_1x1_CONV_KERNEL_F32_HPP
+#define JIT_SSE42_1x1_CONV_KERNEL_F32_HPP
+
+#include "c_types_map.hpp"
+#include "cpu_memory.hpp"
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_sse42_1x1_conv_kernel_f32: public jit_generator {
+ jit_sse42_1x1_conv_kernel_f32(jit_1x1_conv_conf_t ajcp,
+ const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
+ {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<sse42>(this,
+ jcp.eltwise);
+
+ this->generate();
+ jit_ker = (void (*)(jit_1x1_conv_call_s *))this->getCode();
+ }
+
+ ~jit_sse42_1x1_conv_kernel_f32() {
+ delete eltwise_injector_;
+ }
+
+ static bool post_ops_ok(jit_1x1_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_1x1_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr);
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_1x1_conv_kernel_f32)
+
+ jit_1x1_conv_conf_t jcp;
+ const primitive_attr_t &attr_;
+ void (*jit_ker)(jit_1x1_conv_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ using xmm_t = const Xbyak::Xmm;
+
+ reg64_t reg_bcast_data = rax;
+ reg64_t reg_load_data = rsi;
+ reg64_t reg_output_data = rbx;
+ reg64_t aux_reg_bcast_data = rdx;
+ reg64_t aux1_reg_bcast_data = abi_not_param1;
+ reg64_t aux_reg_load_data = abi_param1;
+ reg64_t aux_reg_output_data = rbp;
+ reg64_t reg_load_loop_work = r9;
+ reg64_t reg_bcast_loop_work = r10;
+ reg64_t reg_reduce_loop_work = r11;
+ reg64_t load_loop_iter = r13;
+ reg64_t imm_addr64 = load_loop_iter;
+ reg64_t bcast_loop_iter = r14;
+ reg64_t reduce_loop_iter = r15;
+ reg64_t reg_reduce_pos_flag = r8;
+ reg64_t reg_output_stride = r12;
+ reg64_t reg_bias_data = r12;
+ reg64_t reg_diff_bias_data = bcast_loop_iter;
+
+ int reg_diff_bias_data_stack_offt = 0;
+ int stack_space_needed = 8;
+
+ xmm_t reg_bcast = xmm_t(15);
+
+ jit_uni_eltwise_injector_f32<sse42> *eltwise_injector_;
+
+ void generate_bcast_loop(int load_loop_blk);
+ void generate_reduce_loop(int load_loop_blk, int ur);
+ void generate_diff_bias_loop(int load_loop_blk);
+
+ void generate();
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp
new file mode 100644
index 0000000000..30c137641e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.cpp
@@ -0,0 +1,134 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "jit_sse42_1x1_convolution.hpp"
+#include "utils.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+#define data_blk_off(f, n, c, h, w) \
+ ((ndims == 3) \
+ ? (f).blk_off(n, c, w) \
+ : (f).blk_off(n, c, h, w))
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+
+void jit_sse42_1x1_convolution_fwd_t::execute_forward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+ const int ndims = src_d.ndims();
+
+ const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;
+
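+ // Work is decomposed over (mb, groups, spatial bcast blocks); each thread
+ // then walks the oc (load) blocks and ic (reduce) blocks for its share.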
+ parallel(0, [&](const int ithr, const int nthr) {
+ // TODO (Roma): remove this restriction
+ assert(jcp.stride_w == 1 && jcp.stride_h == 1);
+
+ auto par_conv = jit_1x1_conv_call_s();
+
+ const int nb_oc = jcp.nb_load;
+ const int nb_ic = jcp.nb_reduce;
+ const int nb_ic_blocking = jcp.nb_reduce_blocking;
+ const int os_block = jcp.bcast_block;
+
+ int start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+
+ int iwork = start;
+ while (iwork < end) {
+ int n{0}, g{0}, osb{0};
+ nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb,
+ jcp.nb_bcast);
+
+ const int bcast_step_rem = jcp.nb_bcast - osb;
+ int bcast_step = bcast_step_rem <= jcp.nb_bcast_blocking_max
+ ? bcast_step_rem : jcp.nb_bcast_blocking;
+ bcast_step = nstl::min<int>(bcast_step, end - iwork);
+
+ const int os = osb * os_block;
+ const int ow = os % jcp.ow;
+ const int oh = os / jcp.ow;
+ const int iw = nstl::max<int>(ow * jcp.stride_w - jcp.l_pad, 0);
+ const int ih = nstl::max<int>(oh * jcp.stride_h - jcp.t_pad, 0);
+
+ par_conv.bcast_dim = this_block_size(os, jcp.os,
+ bcast_step * os_block);
+
+ int ocb = 0;
+ while (ocb < jcp.nb_load) {
+ const int load_step_rem = jcp.nb_load - ocb;
+ const int load_step = load_step_rem < jcp.nb_load_blocking_max
+ ? load_step_rem : jcp.nb_load_blocking;
+
+ const size_t _ocb = g * nb_oc + ocb;
+ par_conv.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc,
+ load_step * jcp.oc_block);
+
+ const size_t dst_off = data_blk_off(dst_d, n, _ocb, oh, ow);
+ par_conv.output_data = &dst[dst_off];
+
+ par_conv.bias_data = &bias[_ocb * jcp.oc_block];
+
+ for (int icb = 0; icb < nb_ic; icb += nb_ic_blocking) {
+ par_conv.first_last_flag = 0
+ | (icb == 0) * FLAG_REDUCE_FIRST
+ | (icb + nb_ic_blocking >= nb_ic) * FLAG_REDUCE_LAST;
+
+ par_conv.reduce_dim = this_block_size(icb * jcp.ic_block,
+ jcp.ic, nb_ic_blocking * jcp.ic_block);
+
+ const size_t _icb = g * nb_ic + icb;
+ const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw);
+ par_conv.bcast_data = &src[src_off];
+
+ par_conv.load_data = &weights[pd()->with_groups()
+ ? weights_d.blk_off(g, ocb, icb)
+ : weights_d.blk_off(ocb, icb)];
+
+ kernel_->jit_ker(&par_conv);
+ }
+
+ ocb += load_step;
+ }
+
+ iwork += bcast_step;
+ }
+ });
+
+ if (pd()->wants_zero_pad_dst())
+ ctx.memory(MKLDNN_ARG_DST)->zero_pad();
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp
new file mode 100644
index 0000000000..b32b1e4784
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_1x1_convolution.hpp
@@ -0,0 +1,96 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_SSE42_1x1_CONVOLUTION_HPP
+#define CPU_JIT_SSE42_1x1_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+#include "jit_sse42_1x1_conv_kernel_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_sse42_1x1_convolution_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_1x1:", sse42, ""),
+ jit_sse42_1x1_convolution_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ return jit_sse42_1x1_conv_kernel_f32::init_conf(jcp_, *desc(),
+ *src_md(), *weights_md(), *dst_md(), *attr());
+ }
+
+ jit_1x1_conv_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto wei_tag = with_groups()
+ ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o)
+ : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o);
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ jit_sse42_1x1_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {
+ kernel_ = new jit_sse42_1x1_conv_kernel_f32(pd()->jcp_, *pd()->attr());
+ }
+ ~jit_sse42_1x1_convolution_fwd_t() { delete kernel_; }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ jit_sse42_1x1_conv_kernel_f32 *kernel_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp
new file mode 100644
index 0000000000..17cabc1186
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp
@@ -0,0 +1,497 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "cpu_memory.hpp"
+
+#include "jit_sse42_conv_kernel_f32.hpp"
+
+#define GET_OFF(field) offsetof(jit_conv_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::utils;
+
+using namespace Xbyak;
+
+void jit_sse42_conv_fwd_kernel_f32::oh_step_unroll_kw(int ur_w,
+ int pad_l, int pad_r, int oc_blocks)
+{
+ int iw = jcp.iw;
+ int ih = jcp.ih;
+ int kw = jcp.kw;
+ int kh = jcp.kh;
+ int nb_ic = jcp.nb_ic;
+ int stride_w = jcp.stride_w;
+ int dilate_w = jcp.dilate_w + 1;
+ int ic_blk = jcp.ic_block;
+ int oc_blk = jcp.oc_block;
+
+ for (int ki = 0; ki < kw; ki++) {
+ int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
+ int jj_end = ur_w
+ - nstl::max(0, div_up(ki*dilate_w + pad_r - (kw-1)*dilate_w, stride_w));
+ for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ int inp_off;
+ if (one_of(jcp.src_tag, ncw, nchw))
+ inp_off = ifm2*ih*iw + (ki*dilate_w + jj*stride_w - pad_l);
+ else
+ inp_off = (ki*dilate_w + jj*stride_w - pad_l)*ic_blk + ifm2;
+
+ movss(Xmm(oc_blocks * ur_w + jj + 1),
+ ptr[aux_reg_input + sizeof(float) * inp_off]);
+ shufps(Xmm(oc_blocks * ur_w + jj + 1),
+ Xmm(oc_blocks * ur_w + jj + 1), 0x0);
+ }
+
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ int ker_off = ii * nb_ic * kh * kw * ic_blk * oc_blk
+ + ki * ic_blk * oc_blk + ifm2 * oc_blk;
+
+ for (int jj = jj_start; jj < jj_end; jj++)
+ {
+ movups(xmm0,
+ ptr[aux_reg_kernel + sizeof(float) * ker_off]);
+ mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
+ addps(Xmm(ur_w * ii + jj + 1), xmm0);
+ }
+ }
+ }
+ }
+}
+
+void jit_sse42_conv_fwd_kernel_f32::oh_step_nopad(int ur_w,
+ int pad_l, int pad_r, int oc_blocks)
+{
+ Label kw_loop;
+
+ int iw = jcp.iw;
+ int ih = jcp.ih;
+ int kw = jcp.kw;
+ int kh = jcp.kh;
+ int nb_ic = jcp.nb_ic;
+ int stride_w = jcp.stride_w;
+ int dilate_w = jcp.dilate_w + 1;
+ int ic_blk = jcp.ic_block;
+ int oc_blk = jcp.oc_block;
+
+ xor_(ki_iter, ki_iter);
+ L(kw_loop);
+ {
+ int jj_start = 0;
+ int jj_end = ur_w;
+ for (int ifm2 = 0; ifm2 < ic_blk; ifm2++) {
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ int inp_off;
+ if (one_of(jcp.src_tag, ncw, nchw))
+ inp_off = ifm2 * ih * iw + (jj * stride_w - pad_l);
+ else
+ inp_off = (jj * stride_w - pad_l) * ic_blk + ifm2;
+
+ movss(Xmm(oc_blocks * ur_w + jj + 1),
+ ptr[aux_reg_input + sizeof(float) * inp_off]);
+ shufps(Xmm(oc_blocks * ur_w + jj + 1),
+ Xmm(oc_blocks * ur_w + jj + 1), 0x0);
+ }
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ int aux_kernel_offset = ii * nb_ic * kh * kw * ic_blk * oc_blk
+ + ifm2 * oc_blk;
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ movups(xmm0,
+ ptr[aux_reg_kernel + sizeof(float) * aux_kernel_offset]);
+ mulps(xmm0, Xmm(oc_blocks * ur_w + jj + 1));
+ addps(Xmm(ur_w * ii + jj + 1), xmm0);
+ }
+ }
+ }
+ add(aux_reg_kernel, sizeof(float) * oc_blk * ic_blk);
+ add(aux_reg_input, sizeof(float) * (one_of(jcp.src_tag, ncw, nchw) ?
+ dilate_w : ic_blk * dilate_w));
+
+ inc(ki_iter);
+ cmp(ki_iter, kw);
+ jl(kw_loop, T_NEAR);
+ }
+}
+
+void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
+ int pad_l, int pad_r, int oc_blocks)
+{
+ int iw = jcp.iw;
+ int kw = jcp.kw;
+ int ow = jcp.ow;
+ int oh = jcp.oh;
+ int dilate_h = jcp.dilate_h + 1;
+ int dilate_w = jcp.dilate_w + 1;
+ int ic_blk = jcp.ic_block;
+ int oc_blk = jcp.oc_block;
+ const int inp_mult = one_of(jcp.src_tag, ncw, nchw)
+ ? dilate_h : ic_blk * dilate_h;
+ const int inp_off = one_of(jcp.src_tag, ncw, nchw)
+ ? dilate_w : ic_blk * dilate_w;
+
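+ // The oc block is 8 floats wide but an XMM register holds 4, so the body
+ // below runs twice (simd_iter = 0, 1), advancing the kernel, output and
+ // bias pointers by 4 floats between the two halves.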
+ xor_(simd_iter, simd_iter);
+
+ mov(aux_reg_input, reg_input);
+ mov(aux_reg_kernel, reg_kernel);
+
+ Label init_simd_iter_loop;
+ Label init_done;
+ Label init_first;
+
+ L(init_simd_iter_loop);
+
+ if (!jcp.with_sum) {
+ test(reg_ci_flag, FLAG_IC_FIRST);
+ jne(init_first, T_NEAR);
+ }
+
+ for (int ii = 0; ii < oc_blocks; ii++)
+ for (int jj = 0; jj < ur_w; jj++)
+ movups(Xmm(ur_w * ii + jj + 1), xword[reg_output
+ + sizeof(float) * (ii * oh * ow + jj) * oc_blk]);
+
+ if (jcp.with_sum && jcp.with_bias) {
+ test(reg_ci_flag, FLAG_IC_FIRST);
+ je(init_done, T_NEAR);
+
+ for (int ii = 0; ii < oc_blocks; ii++)
+ for (int jj = 0; jj < ur_w; jj++)
+ addps(Xmm(ur_w * ii + jj + 1),
+ xword[reg_bias + sizeof(float) * ii * oc_blk]);
+ }
+
+ jmp(init_done);
+
+ L(init_first);
+ if (this->jcp.with_bias) {
+ for (int ii = 0; ii < oc_blocks; ii++)
+ for (int jj = 0; jj < ur_w; jj++)
+ movups(Xmm(ur_w * ii + jj + 1),
+ xword[reg_bias + sizeof(float) * ii * oc_blk]);
+ } else {
+ for (int ii = 0; ii < oc_blocks; ii++)
+ for (int jj = 0; jj < ur_w; jj++)
+ pxor(Xmm(ur_w * ii + jj + 1), Xmm(ur_w * ii + jj + 1));
+ }
+
+ L(init_done);
+
+ Label skip_kh_loop;
+ mov(kj, reg_kh);
+ if ((jcp.dilate_h >= jcp.ih)
+ || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
+ cmp(kj, 0);
+ je(skip_kh_loop, T_NEAR);
+ }
+ Label kh_loop;
+ L(kh_loop);
+ {
+ if (jcp.kw >= 5 && pad_l == 0 && pad_r == 0) {
+ oh_step_nopad(ur_w, pad_l, pad_r, oc_blocks);
+ sub(aux_reg_input, sizeof(float) * kw * inp_off);
+ add(aux_reg_input, sizeof(float) * iw * inp_mult);
+ } else {
+ oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
+ add(aux_reg_kernel, sizeof(float) * kw * oc_blk * ic_blk);
+ add(aux_reg_input, sizeof(float) * iw * inp_mult);
+ }
+
+ dec(kj);
+ cmp(kj, 0);
+ jg(kh_loop, T_NEAR);
+ }
+
+ L(skip_kh_loop);
+
+ if (jcp.with_eltwise) {
+ Label regular_store;
+ test(reg_ci_flag, FLAG_IC_LAST);
+ je(regular_store, T_NEAR);
+
+ eltwise_injector_->compute_vector_range(1, oc_blocks * ur_w + 1);
+
+ L(regular_store);
+ }
+
+ for (int ii = 0; ii < oc_blocks; ii++) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ const size_t o_off = (ii * oh * ow + jj) * oc_blk;
+
+ Xmm reg_out = Xmm(ur_w * ii + jj + 1);
+ movups(xword[reg_output + sizeof(float) * o_off], reg_out);
+ }
+ }
+
+ mov(aux_reg_kernel, reg_kernel);
+ mov(aux_reg_input, reg_input);
+ add(aux_reg_kernel, sizeof(float) * 4);
+ add(reg_output, sizeof(float) * 4);
+ add(reg_bias, sizeof(float) * 4);
+
+ inc(simd_iter);
+ cmp(simd_iter, 2);
+ jl(init_simd_iter_loop, T_NEAR);
+
+ sub(reg_output, sizeof(float) * 8);
+ sub(reg_bias, sizeof(float) * 8);
+}
+
+inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks)
+{
+ int ur_w = jcp.ur_w;
+ int ur_w_tail = jcp.ur_w_tail;
+ int n_oi = jcp.ow / ur_w;
+ int iw = jcp.iw;
+ int kw = jcp.kw;
+ int ic_blk = jcp.ic_block;
+ int oc_blk = jcp.oc_block;
+ int dilate_w = jcp.dilate_w + 1;
+ int str_w = jcp.stride_w;
+ const int inp_mult = one_of(jcp.src_tag, ncw, nchw) ? 1 : ic_blk;
+
+ int l_pad = jcp.l_pad;
+ int r_pad = nstl::max(0, (int(jcp.ow) - 1) * str_w + (kw - 1) * dilate_w
+ - (iw + l_pad - 1));
+ int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
+ - (iw + l_pad - 1);
+ if (r_pad1 > 0) n_oi--;
+
+ if (l_pad > 0) {
+ n_oi--;
+ if (n_oi < 0 && r_pad1 > 0)
+ width_blk_step(ur_w, l_pad, r_pad1, oc_blocks); // "lrpad"
+ else
+ width_blk_step(ur_w, l_pad, 0, oc_blocks); // "lpad"
+ add(reg_input, sizeof(float) * (ur_w * str_w - l_pad) * inp_mult);
+ add(reg_output, sizeof(float) * ur_w * oc_blk);
+ }
+
+ Label ow_loop;
+ xor_(oi_iter, oi_iter);
+
+ if (n_oi > 0) {
+ L(ow_loop);
+
+ width_blk_step(ur_w, 0, 0, oc_blocks); // "middle"
+ add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
+ add(reg_output, sizeof(float) * ur_w * oc_blk);
+
+ inc(oi_iter);
+ cmp(oi_iter, n_oi);
+ jl(ow_loop, T_NEAR);
+ }
+
+ if (r_pad1 > 0 && n_oi >=0) {
+ width_blk_step(ur_w, 0, r_pad1, oc_blocks); // "rpad"
+ add(reg_input, sizeof(float) * ur_w * str_w * inp_mult);
+ add(reg_output, sizeof(float) * ur_w * oc_blk);
+ }
+
+ if (ur_w_tail != 0)
+ width_blk_step(ur_w_tail, 0, r_pad, oc_blocks); // "tail"
+}
+
+void jit_sse42_conv_fwd_kernel_f32::generate()
+{
+ this->preamble();
+
+ mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
+ mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
+ mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
+ if (jcp.with_bias)
+ mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
+ mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
+ mov(reg_ci_flag, ptr[this->param1 + GET_OFF(flags)]);
+ mov(reg_oc_blocks, ptr[this->param1 + GET_OFF(oc_blocks)]);
+
+ int nb_oc_tail = jcp.nb_oc % jcp.nb_oc_blocking;
+ Label tail, exit;
+
+ cmp(reg_oc_blocks, jcp.nb_oc_blocking);
+ jne(nb_oc_tail ? tail : exit, T_NEAR);
+
+ solve_common(jcp.nb_oc_blocking);
+ jmp(exit, T_NEAR);
+
+ if (nb_oc_tail) {
+ L(tail);
+ cmp(reg_oc_blocks, nb_oc_tail);
+ jne(exit, T_NEAR);
+ solve_common(nb_oc_tail);
+ }
+
+ L(exit);
+
+ this->postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+}
+
+bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok(
+ jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+ switch (p.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
+ case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
+ default: return false;
+ }
+
+ return false;
+}
+
+status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr)
+{
+ if (!mayiuse(sse42)) return status::unimplemented;
+
+ jcp.prop_kind = cd.prop_kind;
+
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ const int ndims = src_d.ndims();
+ jcp.ndims = ndims;
+
+ jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
+ jcp.mb = src_d.dims()[0];
+
+ jcp.oc = dst_d.dims()[1] / jcp.ngroups;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+
+ jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
+ jcp.iw = src_d.dims()[ndims - 1];
+ jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[ndims - 1];
+
+ jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
+ jcp.kw = weights_d.dims()[with_groups + ndims - 1];
+
+ jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][ndims - 3];
+
+ jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
+ jcp.stride_w = cd.strides[ndims - 3];
+
+ jcp.dilate_h = (ndims == 3) ? 0 : cd.dilates[0];
+ jcp.dilate_w = cd.dilates[ndims - 3];
+ jcp.b_pad = (jcp.oh - 1) * jcp.stride_h + (jcp.kh - 1) * (jcp.dilate_h + 1)
+ - (jcp.ih + jcp.t_pad - 1);
+
+ if (ndims == 3) {
+ jcp.src_tag = src_d.matches_one_of_tag(ncw, nwc, nCw8c);
+ jcp.wei_tag = weights_d.matches_one_of_tag(
+ Owi8o, gOwi8o, OIw8i8o, gOIw8i8o);
+ jcp.dst_tag = dst_d.matches_one_of_tag(nCw8c);
+ } else if (ndims == 4) {
+ jcp.src_tag = src_d.matches_one_of_tag(nchw, nhwc, nChw8c);
+ jcp.wei_tag = weights_d.matches_one_of_tag(
+ Ohwi8o, gOhwi8o, OIhw8i8o, gOIhw8i8o);
+ jcp.dst_tag = dst_d.matches_one_of_tag(nChw8c);
+ }
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ jcp.with_sum = p.find(primitive_kind::sum) != -1;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise)
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+
+ const bool flat = jcp.ic == 3;
+ const bool mimo = !flat;
+
+ bool args_ok = true
+ && IMPLICATION(flat, one_of(jcp.src_tag, ncw, nwc, nchw, nhwc)
+ && one_of(jcp.wei_tag, Owi8o, gOwi8o, Ohwi8o, gOhwi8o))
+ && IMPLICATION(mimo, one_of(jcp.src_tag, nCw8c, nChw8c)
+ && one_of(jcp.wei_tag, OIw8i8o, gOIw8i8o, OIhw8i8o, gOIhw8i8o))
+ && one_of(jcp.dst_tag, nCw8c, nChw8c);
+ if (!args_ok) return status::unimplemented;
+
+ const int simd_w = 8; // 2 SSE vectors processed at once
+
+ jcp.ur_h = 1; /* no code-unrolling by h so far */
+ jcp.ur_w = 3;
+ if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
+ jcp.ur_w_tail = jcp.ow % jcp.ur_w;
+
+ jcp.nb_oc_blocking = 4; /* the optimal value for the kernel */
+
+ args_ok = true
+ && jcp.oc % simd_w == 0
+ && jcp.l_pad <= jcp.ur_w
+ && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
+ || (jcp.stride_w == 1 && jcp.stride_h == 1))
+ && IMPLICATION(mimo, jcp.ic % simd_w == 0);
+ if (!args_ok) return status::unimplemented;
+
+ int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+
+ // kernel needs 1 temporary XMM register (xmm0)
+ const int num_avail_regs = 15;
+ if (r_pad_no_tail > jcp.ur_w * jcp.stride_w && jcp.ow / jcp.ur_w > 1) {
+ /* recalculate ur_w, nb_oc_blocking and ur_w_tail */
+ jcp.ur_w = nstl::min(r_pad_no_tail / jcp.stride_w + jcp.ur_w_tail,
+ nstl::min(jcp.ow, num_avail_regs / 2));
+ jcp.nb_oc_blocking = (num_avail_regs - jcp.ur_w) / jcp.ur_w;
+ jcp.ur_w_tail = jcp.ow % jcp.ur_w;
+ /* check again ... */
+ r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
+ if (jcp.ur_w < nstl::max(jcp.l_pad, r_pad_no_tail))
+ return status::unimplemented;
+ }
+ assert(jcp.nb_oc_blocking > 0);
+ assert(jcp.ur_w * (jcp.nb_oc_blocking + 1) <= num_avail_regs);
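+ // channel blocking: oc is always processed in blocks of simd_w; ic stays a
+ // single block in the flat (ic == 3) case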
+
+ jcp.ic_block = (jcp.ic % simd_w != 0) ? jcp.ic : simd_w;
+ jcp.nb_ic = jcp.ic / jcp.ic_block;
+
+ jcp.oc_block = simd_w;
+ jcp.nb_oc = jcp.oc / jcp.oc_block;
+
+ if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
+ jcp.nb_ic_blocking = 12;
+ jcp.nb_ic_blocking_max = 16;
+ } else {
+ jcp.nb_ic_blocking = 1;
+ jcp.nb_ic_blocking_max = jcp.nb_ic_blocking;
+ }
+
+ return status::success;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp
new file mode 100644
index 0000000000..33c26ef081
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.hpp
@@ -0,0 +1,93 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_SSE42_CONV_KERNEL_F32_HPP
+#define JIT_SSE42_CONV_KERNEL_F32_HPP
+
+#include "c_types_map.hpp"
+#include "cpu_memory.hpp"
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_sse42_conv_fwd_kernel_f32: public jit_generator {
+ jit_sse42_conv_fwd_kernel_f32(jit_conv_conf_t ajcp,
+ const primitive_attr_t &attr)
+ : jcp(ajcp), attr_(attr), eltwise_injector_(nullptr)
+ {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<sse42>(this,
+ jcp.eltwise);
+
+ this->generate();
+ jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
+ }
+
+ ~jit_sse42_conv_fwd_kernel_f32() {
+ delete eltwise_injector_;
+ }
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_sse42_conv_fwd_kernel_f32)
+ jit_conv_conf_t jcp;
+ const primitive_attr_t &attr_;
+ void (*jit_ker)(jit_conv_call_s *);
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ reg64_t reg_input = rax;
+ reg64_t aux_reg_input = r8;
+ reg64_t reg_kernel = rdx;
+ reg64_t aux_reg_kernel = r9;
+ reg64_t reg_output = rsi;
+ reg64_t reg_bias = rbx;
+
+ reg64_t kj = r10;
+ reg64_t oi_iter = r11;
+ reg64_t ki_iter = r12;
+ reg64_t reg_kh = abi_not_param1;
+ reg64_t simd_iter = r15;
+ reg64_t reg_oc_blocks = r14;
+ reg64_t imm_addr64 = reg_oc_blocks;
+ Xbyak::Reg32 reg_ci_flag = r13d;
+
+ jit_uni_eltwise_injector_f32<sse42> *eltwise_injector_;
+
+ inline void oh_step_unroll_kw(int ur_w, int pad_l, int pad_r,
+ int oc_blocks);
+ inline void oh_step_nopad(int ur_w, int pad_l, int pad_r, int oc_blocks);
+ inline void width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks);
+ inline void solve_common(int oc_blocks);
+
+ void generate();
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp
new file mode 100644
index 0000000000..5f77d692f5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.cpp
@@ -0,0 +1,136 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "jit_sse42_convolution.hpp"
+#include "mkldnn_thread.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+
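+// offset helpers: drop the height index for 1D (ndims == 3) tensors and the
+// leading group index when the weights have no groups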
+#define src_blk_off(f, n, c, h, w) \
+ (pd()->ndims() == 3) \
+ ? (f).blk_off(n, c, w) \
+ : (f).blk_off(n, c, h, w)
+
+#define wht_blk_off_(f, g, ...) \
+ pd()->with_groups() \
+ ? (f).blk_off(g, __VA_ARGS__) \
+ : (f).blk_off(__VA_ARGS__)
+#define wht_blk_off(f, g, oc, ic, kh, kw) \
+ pd()->ndims() == 3 \
+ ? wht_blk_off_(f, g, oc, ic, kw) \
+ : wht_blk_off_(f, g, oc, ic, kh, kw)
+
+void jit_sse42_convolution_fwd_t::execute_forward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const auto &jcp = kernel_->jcp;
+
+ int ocb_work = div_up(jcp.nb_oc, jcp.nb_oc_blocking);
+ const size_t work_amount = jcp.mb * jcp.ngroups * ocb_work * jcp.oh;
+
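+ // Threads share the (mb, group, oc-block chunk, oh) space; input-channel
+ // blocks are swept in an outer loop, and FLAG_IC_FIRST / FLAG_IC_LAST tell
+ // the kernel when to add the bias and when to apply the eltwise post-op.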
+ parallel(0, [&](const int ithr, const int nthr) {
+ size_t start{ 0 }, end{ 0 };
+ balance211(work_amount, nthr, ithr, start, end);
+
+ int icbb = 0;
+ while (icbb < jcp.nb_ic) {
+ int icb_step = jcp.nb_ic_blocking;
+ int icb_step_rem = jcp.nb_ic - icbb;
+ if (icb_step_rem < jcp.nb_ic_blocking_max)
+ icb_step = icb_step_rem;
+
+ size_t n{0}, g{0}, ocbb{0}, oh{0};
+ nd_iterator_init(start, n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
+ oh, jcp.oh);
+ for (size_t iwork = start; iwork < end; ++iwork) {
+ int ocb = ocbb * jcp.nb_oc_blocking;
+ int ocb_num = jcp.nb_oc_blocking;
+
+ for (int icb = icbb; icb < icbb + icb_step; ++icb) {
+ auto par_conv = jit_conv_call_s();
+
+ const int ij = oh * jcp.stride_h;
+ const int i_t_overflow = nstl::max(0, jcp.t_pad - ij);
+ const int i_b_overflow = nstl::max(jcp.ih, ij
+ + (jcp.kh-1) * (jcp.dilate_h+1) - jcp.t_pad+1) - jcp.ih;
+
+ const size_t _oc = g * jcp.nb_oc + ocb;
+ const size_t _ic = g * jcp.nb_ic + icb;
+
+ const int ih = nstl::max(ij - jcp.t_pad
+ + div_up(i_t_overflow,
+ (jcp.dilate_h+1)) * (jcp.dilate_h + 1), 0);
+ par_conv.src = &src[src_blk_off(src_d, n,
+ jcp.ic == 3 ? 0 : _ic, ih, 0)];
+
+ par_conv.dst = &dst[src_blk_off(dst_d, n, _oc, oh, 0)];
+
+ const int wh = div_up(i_t_overflow, (jcp.dilate_h + 1));
+ par_conv.filt = &weights[wht_blk_off(weights_d, g, ocb,
+ jcp.ic == 3 ? 0 : icb, wh, 0)];
+
+ if (icb == 0) {
+ if (bias)
+ par_conv.bias =
+ &bias[bias_d.blk_off(_oc * jcp.oc_block)];
+ par_conv.flags |= FLAG_IC_FIRST;
+ }
+
+ if (jcp.with_eltwise && icb + 1 == jcp.nb_ic) {
+ par_conv.flags |= FLAG_IC_LAST;
+ }
+
+ par_conv.oc_blocks =
+ nstl::min(ocb + ocb_num, jcp.nb_oc) - ocb;
+
+ par_conv.kw_padding = 0;
+ const int kh_padding = jcp.kh
+ - div_up(i_t_overflow, (jcp.dilate_h + 1))
+ - div_up(i_b_overflow, (jcp.dilate_h + 1));
+ par_conv.kh_padding = nstl::max(0, kh_padding);
+ kernel_->jit_ker(&par_conv);
+ }
+ nd_iterator_step(n, jcp.mb, g, jcp.ngroups, ocbb, ocb_work,
+ oh, jcp.oh);
+ }
+ icbb += icb_step;
+ }
+ });
+
+ if (pd()->wants_zero_pad_dst())
+ ctx.memory(MKLDNN_ARG_DST)->zero_pad();
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp
new file mode 100644
index 0000000000..d2f0a38c5c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_sse42_convolution.hpp
@@ -0,0 +1,103 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_SSE42_CONVOLUTION_HPP
+#define CPU_JIT_SSE42_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_primitive_conf.hpp"
+#include "jit_sse42_conv_kernel_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_sse42_convolution_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", sse42, ""),
+ jit_sse42_convolution_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ return jit_sse42_conv_fwd_kernel_f32::init_conf(jcp_, *desc(),
+ *src_md(), *weights_md(), *dst_md(), *attr());
+ }
+
+ jit_conv_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ const bool flat = IC() == 3;
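+ // utils::pick index: 0 for 1D, 1 for 2D, 2 for 3D; for weights it is doubled
+ // and "+ flat" selects the non-blocked-ic layouts (Owi8o / Ohwi8o / ...)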
+ auto src_tag = flat
+ ? utils::pick(ndims() - 3, ncw, nchw, ncdhw)
+ : utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto dst_tag =
+ utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
+ auto wei_tag = with_groups()
+ ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
+ gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
+ : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
+ OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
+
+ return set_default_formats_common(src_tag, wei_tag, dst_tag);
+ }
+ };
+
+ jit_sse42_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
+ { kernel_ = new jit_sse42_conv_fwd_kernel_f32(pd()->jcp_, *pd()->attr()); }
+ ~jit_sse42_convolution_fwd_t() { delete kernel_; }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ jit_sse42_conv_fwd_kernel_f32 *kernel_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp
new file mode 100644
index 0000000000..0e734f7265
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.cpp
@@ -0,0 +1,1192 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+#include "jit_generator.hpp"
+#include "cpu_barrier.hpp"
+
+#include "jit_transpose_src_utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace Xbyak;
+
+#define GET_OFF(x) offsetof(ctx_t, x)
+
+struct jit_trans_iw_ic_t: public jit_trans_src_t, public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_t)
+
+ jit_trans_iw_ic_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) {
+ generate();
+ ker_ = (decltype(ker_))this->getCode();
+ }
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ using reg32_t = const Xbyak::Reg32;
+ using opmask_t = const Xbyak::Opmask;
+
+ enum { typesize = sizeof(float), transpose_size = 16, small_spatial = 14 };
+ int src_stride, tr_src_stride;
+ int tail;
+ bool enable_prefetch;
+
+ opmask_t k3333 = k1;
+ opmask_t k5555 = k2;
+ opmask_t kAAAA = k3;
+ opmask_t kCCCC = k4;
+ opmask_t k0F0F = k5;
+ opmask_t kF0F0 = k6;
+ opmask_t kTail = k7;
+
+ reg64_t reg_src = r8;
+ reg64_t reg_tr_src = r9;
+ reg64_t reg_src_prf = r10;
+ reg64_t reg_tr_src_prf = r11;
+ reg64_t reg_loop = r12;
+ reg64_t reg_tr_src_tmp = r13;
+ reg32_t regw_tmp = r14d;
+
+ void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores);
+ void generate();
+};
+
+void jit_trans_iw_ic_t::transpose(int nrows, int l_pad, int r_pad,
+ bool nontemporal_stores) {
+ assert(nrows >= 0 && nrows <= transpose_size);
+ static_assert(transpose_size == 16, "Unsupported transpose size");
+ if (!nrows)
+ return;
+
+ auto pf_src_t0 = [=](int i) {
+ if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_src,
+ (transpose_size + i) * src_stride));
+ };
+
+ auto pf_tr_src_t0 = [=](int i) {
+ int offset = (transpose_size) * typesize + i * tr_src_stride;
+ if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src, offset));
+ if(enable_prefetch) prefetcht0(EVEX_compress_addr(reg_tr_src,
+ offset + 64));
+ };
+
+ auto pf_src_t1 = [=](int i) {
+ if(enable_prefetch) prefetcht1(EVEX_compress_addr(reg_src_prf,
+ i * src_stride));
+ };
+
+ auto pf_tr_src_t1 = [=](int i) {
+ if(enable_prefetch) prefetchwt1(EVEX_compress_addr(reg_tr_src_prf,
+ i * tr_src_stride));
+ };
+
+ auto src_zmm = [=](int i) {
+ assert(i >= 0 && i < 16);
+ return Zmm(i);
+ };
+
+ auto tmp_zmm = [=](int i) {
+ assert(i >= 0 && i < 16);
+ return Zmm(16 + i);
+ };
+
+ auto load = [=](int i) {
+ vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride));
+ };
+
+ auto store = [=](Zmm r, int i) {
+ auto kmovw = [=](Opmask k, unsigned w) {
+ mov(regw_tmp, w);
+ jit_generator::kmovw(k, regw_tmp);
+ };
+
+ auto padding = [=] (Reg64 reg, int pad) {
+ kmovw(kTail, (1 << pad) - 1);
+ auto k = kTail;
+ auto base = reg;
+ base.setOpmaskIdx(k.getIdx(), true);
+
+ auto zmm_zero = r;
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+ auto addr = EVEX_compress_addr(base, i * tr_src_stride);
+ vmovups(addr, zmm_zero);
+ };
+
+ mov(reg_tr_src_tmp, reg_tr_src);
+ if (l_pad > 0)
+ add(reg_tr_src_tmp, l_pad * typesize);
+
+ if (tail != transpose_size)
+ kmovw(kTail, (1 << tail) - 1);
+
+ // Xbyak does not allow k0 to be specified explicitly via the '|'
+ // operator, so we have to do this via a method call (implicitly
+ // EVEX encoding uses k0 to mean 'no mask')
+ bool partial_store = nrows < 16;
+ auto k = partial_store ? kTail : k0;
+ auto base = reg_tr_src_tmp;
+ base.setOpmaskIdx(k.getIdx(), true);
+
+ auto addr = EVEX_compress_addr(base, i * tr_src_stride);
+ if (nontemporal_stores && !partial_store)
+ vmovntps(addr, r);
+ else
+ vmovups(addr, r);
+
+ if (r_pad > 0) {
+ add(reg_tr_src_tmp, tail * typesize);
+ padding(reg_tr_src_tmp, r_pad);
+ }
+
+ if (l_pad > 0) {
+ padding(reg_tr_src, l_pad);
+ }
+ };
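+
+ // 16x16 in-register transpose: swap elements at strides 1, 2 and 4 inside
+ // each 8-row half (transpose16x8), then fixup16x16 merges the two halves
+ // with 8-wide shuffles and stores the result.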
+
+ auto transpose16x8 = [=](int base_idx) {
+ assert(base_idx == 0 || base_idx == 8);
+
+ // swap 1
+ for (int i = 0; i < 4; i++) {
+ int src_idx0 = base_idx + i * 2;
+ int src_idx1 = src_idx0 + 1;
+
+ int next_src_idx0 = src_idx0 + 2;
+ int next_src_idx1 = src_idx1 + 2;
+ bool load_next = base_idx == 0 || i < 3;
+
+ if (base_idx == 0 && i == 0) {
+ load(src_idx0);
+ load(src_idx1);
+ }
+
+ auto tmp0 = tmp_zmm(src_idx0);
+ auto tmp1 = tmp_zmm(src_idx1);
+ auto src0 = src_zmm(src_idx0);
+ auto src1 = src_zmm(src_idx1);
+
+ if (next_src_idx0 < nrows && load_next)
+ load(next_src_idx0);
+ valignd(tmp0, src0, src0, 0x1);
+ pf_src_t1(base_idx + i);
+
+ if (next_src_idx1 < nrows && load_next)
+ load(next_src_idx1);
+ valignd(tmp1, src1, src1, 0xf);
+ pf_src_t0(base_idx + i);
+
+ vmovaps(src0 | kAAAA, tmp1);
+ vmovaps(src1 | k5555, tmp0);
+ }
+ // swap 2
+ for (int i = 0; i < 4; i++) {
+ int select_half = (i < 2) ? 0 : 2;
+ int src_idx0 = base_idx + i + select_half + 0;
+ int src_idx2 = src_idx0 + 2;
+
+ auto tmp0 = tmp_zmm(src_idx0);
+ auto tmp1 = tmp_zmm(src_idx2);
+ auto src0 = src_zmm(src_idx0);
+ auto src2 = src_zmm(src_idx2);
+
+ valignd(tmp0, src0, src0, 0x2);
+ pf_src_t1(base_idx + 4 + i);
+ valignd(tmp1, src2, src2, 0xe);
+ pf_src_t0(base_idx + 4 + i);
+ vmovaps(src2 | k3333, tmp0);
+ vmovaps(src0 | kCCCC, tmp1);
+ }
+
+ // swap 4
+ for (int i = 0; i < 4; i++) {
+ int src_idx0 = base_idx + i;
+ int src_idx4 = src_idx0 + 4;
+
+ auto tmp0 = tmp_zmm(src_idx0);
+ auto src0 = src_zmm(src_idx0);
+ auto src4 = src_zmm(src_idx4);
+
+ vmovaps(tmp0, src0);
+ vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
+ pf_tr_src_t1(base_idx / 2 + i);
+ vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
+ pf_tr_src_t0(base_idx / 2 + i);
+ }
+ };
+
+ auto fixup16x16 = [=]() {
+ // swap 8
+ for (int i = 0; i < 8; i++) {
+ auto tmp = tmp_zmm(i);
+ auto src0 = src_zmm(i);
+ auto src8 = src_zmm(8 + i);
+ vshuff64x2(tmp, src0, src8, 0x44);
+ store(tmp, i);
+ if (i % 2 == 0) {
+ pf_tr_src_t1(8 + i / 2);
+ pf_tr_src_t0(8 + i / 2);
+ }
+ }
+
+ for (int i = 0; i < 8; i++) {
+ auto tmp = tmp_zmm(8 + i);
+ auto src0 = src_zmm(i);
+ auto src8 = src_zmm(8 + i);
+ vshuff64x2(tmp, src0, src8, 0xee);
+ store(tmp, 8 + i);
+ if (i % 2 == 0) {
+ pf_tr_src_t1(12 + i / 2);
+ pf_tr_src_t0(12 + i / 2);
+ }
+ }
+ };
+
+ transpose16x8(0);
+ transpose16x8(8);
+ fixup16x16();
+}
+
+void jit_trans_iw_ic_t::generate() {
+ preamble();
+
+ const int ic_block = conf_->ic_block;
+ const int iw = conf_->iw;
+ const int tr_iw = conf_->tr_iw;
+ const int transposes = utils::div_up(iw, transpose_size);
+ int loop_iters = nstl::max(0, transposes - 1);
+ tail = iw - loop_iters * transpose_size;
+
+ src_stride = ic_block * typesize;
+ assert(src_stride == 64);
+ tr_src_stride = tr_iw * typesize;
+
+ bool nontemporal_stores = false;
+ enable_prefetch = iw > small_spatial ? 1 : 0;
+
+ assert(transpose_size == ic_block);
+ const int src_step = ic_block * transpose_size * typesize;
+ const int tr_src_step = ic_block * typesize;
+
+ const int left_pad = conf_->l_pad;
+ const int right_pad = tr_iw - iw - left_pad;
+
+ mov(reg_src, ptr [param1 + GET_OFF(src)]);
+ mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]);
+ mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]);
+ mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]);
+
+ auto kmovw = [=](Opmask k, unsigned w) {
+ mov(regw_tmp, w);
+ jit_generator::kmovw(k, regw_tmp);
+ };
+
+ kmovw(k3333, 0x3333); // 0011001100110011
+ kmovw(k5555, 0x5555); // 0101010101010101
+ kmovw(kAAAA, 0xaaaa); // 1010101010101010
+ kmovw(kCCCC, 0xcccc); // 1100110011001100
+ kmovw(k0F0F, 0x0f0f); // 0000111100001111
+ kmovw(kF0F0, 0xf0f0); // 1111000011110000
+
+ if (left_pad > 0 && loop_iters > 0) {
+ loop_iters--;
+ transpose(transpose_size, left_pad, 0, nontemporal_stores);
+ add(reg_src, src_step);
+ add(reg_tr_src, tr_src_step + left_pad * typesize);
+ add(reg_src_prf, src_step);
+ add(reg_tr_src_prf, tr_src_step + left_pad * typesize);
+ }
+
+ if (loop_iters) {
+ mov(reg_loop, loop_iters);
+ Label loop;
+ L(loop); {
+ transpose(transpose_size, 0, 0, nontemporal_stores);
+ add(reg_src, src_step);
+ add(reg_tr_src, tr_src_step);
+ add(reg_src_prf, src_step);
+ add(reg_tr_src_prf, tr_src_step);
+ sub(reg_loop, 1);
+ jnz(loop);
+ }
+ }
+ if (transposes > 1)
+ transpose(tail, 0, right_pad, nontemporal_stores);
+ else
+ transpose(tail, left_pad, right_pad, nontemporal_stores);
+
+ postamble();
+}
+
+struct jit_trans_iw_ic_int16_t: public jit_trans_src_t, public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_ic_int16_t)
+ jit_trans_iw_ic_int16_t(const jit_conv_conf_t *conf):
+ jit_trans_src_t(conf) {
+ generate();
+ ker_ = (decltype(ker_))this->getCode();
+ }
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ using reg32_t = const Xbyak::Reg32;
+ using opmask_t = const Xbyak::Opmask;
+
+ enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 };
+ int src_stride, tr_src_stride;
+ int tail;
+ bool enable_prefetch;
+
+ opmask_t kFFFF = k1;
+ opmask_t k5555 = k2;
+ opmask_t kAAAA = k3;
+ opmask_t kAA = k4;
+ opmask_t k55 = k5;
+ opmask_t kCC = k6;
+ opmask_t k33 = k7;
+ opmask_t kTail = k1;
+
+ reg64_t reg_src = r8;
+ reg64_t reg_tr_src = r9;
+ reg64_t reg_src_prf = r10;
+ reg64_t reg_tr_src_prf = r11;
+ reg64_t reg_loop = r12;
+ reg64_t reg_tr_src_tmp = r13;
+ reg32_t regw_tmp = r14d;
+ reg64_t imm_addr64 = rbx;
+
+ Xbyak::Zmm vidx1 = zmm31;
+ Xbyak::Zmm vidx2 = zmm30;
+ Xbyak::Zmm vidx3 = zmm29;
+ Xbyak::Zmm vidx4 = zmm28;
+ Xbyak::Zmm vidx5 = zmm27;
+ Xbyak::Zmm zmm_tmp = zmm26;
+
+
+ void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores);
+ void generate();
+};
+
+void jit_trans_iw_ic_int16_t::transpose(int nrows, int l_pad, int r_pad,
+ bool nontemporal_stores) {
+ assert(nrows >= 0 && nrows <= transpose_size);
+ static_assert(transpose_size == 16, "Unsupported transpose size");
+ if (!nrows)
+ return;
+
+ auto src_zmm = [=](int i) {
+ return Zmm(i);
+ };
+
+ auto src_ymm = [=](int i) {
+ assert(i >= 0 && i < 16);
+ return Ymm(i);
+ };
+
+ auto load_ymm = [=](int i) {
+ vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride));
+ };
+
+ auto kmovw = [=](Opmask k, unsigned w) {
+ mov(regw_tmp, w);
+ jit_generator::kmovw(k, regw_tmp);
+ };
+
+ auto store = [=](Zmm r, int i) {
+
+ auto padding = [=] (Reg64 reg, int pad) {
+ kmovw(kTail, (1 << pad) - 1);
+ auto k = kTail;
+ auto base = reg;
+ base.setOpmaskIdx(k.getIdx(), true);
+
+ auto zmm_zero = zmm_tmp;
+ vpxord(zmm_zero, zmm_zero, zmm_zero);
+ auto addr = EVEX_compress_addr(base, i * tr_src_stride);
+ vmovups(addr, zmm_zero);
+ };
+
+ int store_tail = (nrows%2) ? nrows+1 : nrows;
+
+ int store_pad = (l_pad%2) ? l_pad/2 + 1 : l_pad/2;
+ mov(reg_tr_src_tmp, reg_tr_src);
+ if (l_pad > 0) {
+ padding(reg_tr_src, store_pad);
+ add(reg_tr_src_tmp, l_pad * typesize);
+ }
+ if (r_pad > 0) {
+ store_pad = (r_pad%2) ? r_pad/2 + 1 : r_pad/2;
+ int addr_shift = (r_pad%2) ? 1 : 0;
+ add(reg_tr_src_tmp, (nrows - addr_shift) * typesize);
+ padding(reg_tr_src_tmp, store_pad);
+ }
+
+ mov(reg_tr_src_tmp, reg_tr_src);
+ add(reg_tr_src_tmp, l_pad * typesize);
+
+ kmovw(kTail, (1 << store_tail/2) - 1);
+ auto k = kTail;
+ auto base = reg_tr_src_tmp;
+ base.setOpmaskIdx(k.getIdx(), true);
+
+ auto addr = EVEX_compress_addr(base, i * tr_src_stride);
+ vmovups(addr, r);
+
+ };
+
+ kmovw(kFFFF, 0xffff);
+ // clear all source registers, then load and pack the rows
+ for (int i=0; i<16; i++){
+ vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
+ }
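+
+ // interleave rows 2*i and 2*i+1 word-by-word, then permute with vidx4 so
+ // each zmm holds the pair in transposed order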
+
+ for (int i = 0; i < nrows/2; i++) {
+ auto src0 = src_ymm(2*i);
+ auto src1 = src_ymm(2*i+1);
+ auto zmm_src0 = src_zmm(2*i);
+ load_ymm(2*i);
+
+ vpunpcklwd(src1, src0,
+ EVEX_compress_addr(reg_src, (2*i+1) * src_stride));
+ vpunpckhwd(src0, src0,
+ EVEX_compress_addr(reg_src, (2*i+1) * src_stride));
+ vinserti64x4(zmm_src0, zmm_src0, src1, 1);
+ vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0);
+ }
+
+ // for an odd number of rows, mix the last row with zeroes
+ if (nrows%2) {
+ int i = nrows-1;
+ auto src0 = src_ymm(i);
+ auto src1 = src_ymm(i+1); //zero
+
+ auto zmm_src0 = src_zmm(i);
+ vpxor(src1, src1, src1);
+
+ load_ymm(i);
+ vpunpckhwd(src0, src0, src1);
+ vinserti64x4(zmm_tmp, zmm_tmp, src0, 0);
+ vpxor(src0, src0, src0);
+ load_ymm(i);
+ vpunpcklwd(src1, src0, src1);
+ vinserti64x4(zmm_tmp, zmm_tmp, src1, 1);
+ vpxord(zmm_src0, zmm_src0, zmm_src0);
+ vmovups(zmm_src0, zmm_tmp);
+ vpermps(zmm_src0 | kFFFF, vidx4, zmm_src0);
+ }
+
+ // swap 1
+ for (int i=0; i<4; i++) {
+ auto zmm0 = src_zmm(4*i);
+ auto zmm1 = src_zmm(4*i+2);
+ auto tmp0 = src_zmm(4*i+1);
+ auto tmp1 = src_zmm(4*i+3);
+
+ vmovups(tmp0, zmm0);
+ vmovups(tmp1, zmm1);
+
+ vpermps(tmp0 | kAAAA, vidx3, zmm1);
+ vpermps(tmp1 | k5555, vidx3, zmm0);
+ }
+ // swap 2
+ int base_idx;
+ base_idx=0;
+ for (int i=0; i<2; i++) {
+ auto zmm0 = src_zmm(base_idx+2*i+1);
+ auto zmm1 = src_zmm(base_idx+2*i+5);
+
+ auto tmp0 = src_zmm(base_idx+2*i);
+ auto tmp1 = src_zmm(base_idx+2*i+4);
+
+ vmovupd(tmp0, zmm0);
+ vmovupd(tmp1, zmm1);
+
+ vpermpd(tmp0 | kAA, vidx2, zmm1);
+ vpermpd(tmp1 | k55, vidx2, zmm0);
+ }
+ base_idx=8;
+ for (int i=0; i<2; i++) {
+ auto zmm0 = src_zmm(base_idx+2*i+1);
+ auto zmm1 = src_zmm(base_idx+2*i+5);
+
+ auto tmp0 = src_zmm(base_idx+2*i);
+ auto tmp1 = src_zmm(base_idx+2*i+4);
+
+ vmovupd(tmp0, zmm0);
+ vmovupd(tmp1, zmm1);
+
+ vpermpd(tmp0 | kAA, vidx2, zmm1);
+ vpermpd(tmp1 | k55, vidx2, zmm0);
+ }
+
+ // swap 3
+ for (int i=0; i<4; i++) {
+ auto zmm0 = src_zmm(2*i);
+ auto zmm1 = src_zmm(2*i+8);
+
+ auto tmp0 = src_zmm(2*i+1);
+ auto tmp1 = src_zmm(2*i+9);
+
+ vmovupd(tmp0, zmm0);
+ vmovupd(tmp1, zmm1);
+
+ vpermpd(tmp0 | kCC, vidx1, zmm1);
+ vpermpd(tmp1 | k33, vidx1, zmm0);
+ }
+
+ // all stores
+ for (int i=0; i<8; i++)
+ vextracti64x4(src_ymm(2*i), src_zmm(2*i+1), 1);
+
+ store(src_zmm(1), 0);
+ store(src_zmm(0), 1);
+ store(src_zmm(3), 2);
+ store(src_zmm(2), 3);
+ store(src_zmm(9), 4);
+ store(src_zmm(8), 5);
+ store(src_zmm(11), 6);
+ store(src_zmm(10), 7);
+ store(src_zmm(5), 8);
+ store(src_zmm(4), 9);
+ store(src_zmm(7), 10);
+ store(src_zmm(6), 11);
+ store(src_zmm(13), 12);
+ store(src_zmm(12), 13);
+ store(src_zmm(15), 14);
+ store(src_zmm(14), 15);
+
+}
+
+void jit_trans_iw_ic_int16_t::generate() {
+ preamble();
+
+ alignas(64) static constexpr const int64_t idx1[8]
+ = { 2, 3, 0, 1, 6, 7, 4, 5 };
+ alignas(64) static constexpr const int64_t idx2[8]
+ = { 1, 0, 3, 2, 5, 4, 7, 6 };
+ alignas(64) static constexpr const int32_t idx3[16]
+ = { 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14 };
+ alignas(64) static constexpr const int32_t idx4[16]
+ = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 };
+ alignas(64) static constexpr const int32_t idx5[16]
+ = { 8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7 };
+
+ const int ic_block = conf_->ic_block;
+ const int iw = conf_->iw;
+ const int tr_iw = conf_->tr_iw;
+ const int transposes = utils::div_up(iw, transpose_size);
+ int loop_iters = nstl::max(0, transposes - 1);
+ tail = iw - loop_iters * transpose_size;
+
+ src_stride = ic_block * typesize;
+ tr_src_stride = tr_iw * typesize;
+
+ bool nontemporal_stores = false;
+ enable_prefetch = iw > small_spatial ? 1 : 0;
+
+ assert(transpose_size == ic_block);
+ const int src_step = ic_block * transpose_size * typesize;
+ const int tr_src_step = ic_block * typesize;
+
+ const int left_pad = conf_->l_pad;
+ const int right_pad = tr_iw - iw - left_pad;
+
+ mov(reg_src, ptr [param1 + GET_OFF(src)]);
+ mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]);
+ mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]);
+ mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]);
+
+ auto kmovw = [=](Opmask k, unsigned w) {
+ mov(regw_tmp, w);
+ jit_generator::kmovw(k, regw_tmp);
+ };
+
+ kmovw(kFFFF, 0xffff);
+ kmovw(k5555, 0x5555);
+ kmovw(kAAAA, 0xaaaa);
+ kmovw(kAA, 0xaa);
+ kmovw(k55, 0x55);
+ kmovw(kCC, 0xcc);
+ kmovw(k33, 0x33);
+
+ auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
+ mov(imm_addr64, reinterpret_cast<size_t>(addr));
+ jit_generator::vmovdqa64(z, ptr[imm_addr64]);
+ };
+
+ auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
+ mov(imm_addr64, reinterpret_cast<size_t>(addr));
+ jit_generator::vmovdqa32(z, ptr[imm_addr64]);
+ };
+
+ vmovdqa64(vidx1, idx1);
+ vmovdqa64(vidx2, idx2);
+ vmovdqa32(vidx3, idx3);
+ vmovdqa32(vidx4, idx4);
+ vmovdqa32(vidx5, idx5);
+
+ if (left_pad > 0 && loop_iters > 0) {
+ loop_iters--;
+ transpose(transpose_size, left_pad, 0, nontemporal_stores);
+ add(reg_src, src_step);
+ add(reg_tr_src, tr_src_step + left_pad * typesize);
+ add(reg_src_prf, src_step);
+ add(reg_tr_src_prf, tr_src_step + left_pad * typesize);
+ }
+
+ if (loop_iters) {
+ mov(reg_loop, loop_iters);
+ Label loop;
+ L(loop); {
+ transpose(transpose_size, 0, 0, nontemporal_stores);
+ add(reg_src, src_step);
+ add(reg_tr_src, tr_src_step);
+ add(reg_src_prf, src_step);
+ add(reg_tr_src_prf, tr_src_step);
+ sub(reg_loop, 1);
+ jnz(loop);
+ }
+ }
+ if (transposes > 1)
+ transpose(tail, 0, right_pad, nontemporal_stores);
+ else
+ transpose(tail, left_pad, right_pad, nontemporal_stores);
+
+ postamble();
+
+}
+
+struct jit_trans_ow_oc_t: public jit_trans_dst_t, public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_ow_oc_t)
+ jit_trans_ow_oc_t(const jit_conv_conf_t *conf): jit_trans_dst_t(conf) {
+ generate();
+ ker_ = (decltype(ker_))this->getCode();
+ }
+
+private:
+ using reg64_t = const Xbyak::Reg64;
+ using reg32_t = const Xbyak::Reg32;
+ using opmask_t = const Xbyak::Opmask;
+ using zmm = const Xbyak::Zmm;
+
+ enum { typesize = sizeof(int16_t), transpose_size = 16, small_spatial = 14 };
+ int src_stride, tr_src_stride;
+ int tail;
+ bool enable_prefetch;
+
+ opmask_t kFF = k1;
+
+ zmm vidx1 = zmm31;
+
+ reg64_t reg_src = r8;
+ reg64_t reg_tr_src = r9;
+ reg64_t reg_src_prf = r10;
+ reg64_t reg_tr_src_prf = r11;
+ reg64_t reg_loop = r12;
+ reg64_t reg_tr_src_tmp = r13;
+ reg32_t regw_tmp = r14d;
+ reg64_t imm_addr64 = rbx;
+
+ void transpose(int nrows, int l_pad, int r_pad, bool nontemporal_stores);
+ void generate();
+};
+
+void jit_trans_ow_oc_t::transpose(int nrows, int l_pad, int r_pad,
+ bool nontemporal_stores) {
+ assert(nrows >= 0 && nrows <= transpose_size);
+ static_assert(transpose_size == 16, "Unsupported transpose size");
+ if (!nrows)
+ return;
+
+ auto src_zmm = [=](int i) {
+ return Zmm(i);
+ };
+
+ auto src_ymm = [=](int i) {
+ assert(i >= 0 && i < 16);
+ return Ymm(i);
+ };
+
+ auto load_ymm = [=](int i) {
+ vmovups(src_ymm(i), EVEX_compress_addr(reg_src, i * src_stride));
+ };
+
+
+ auto store = [=](Zmm r, int i) {
+ auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride);
+ if (nontemporal_stores)
+ vmovntps(addr, r);
+ else
+ vmovups(addr, r);
+ };
+
+ for (int i = 0; i < nrows/2; i++) {
+ auto src0 = src_ymm(2*i);
+ auto src1 = src_ymm(2*i+1);
+ auto zmm_src0 = src_zmm(2*i);
+ load_ymm(2*i);
+ vpunpcklwd(src1, src0,
+ EVEX_compress_addr(reg_src, (2*i+1) * src_stride));
+ vpunpckhwd(src0, src0,
+ EVEX_compress_addr(reg_src, (2*i+1) * src_stride));
+ vinserti64x4(zmm_src0, zmm_src0, src1, 1);
+ vpermpd(zmm_src0 | kFF, vidx1, zmm_src0);
+ store(zmm_src0, 2*i);
+ }
+ if (r_pad > 0) {
+ auto src0 = src_ymm(nrows-1);
+ auto src1 = src_ymm(nrows);
+ auto zmm_src0 = src_zmm(30);
+ load_ymm(nrows-1);
+
+ vpxor(src1, src1, src1);
+ vpunpckhwd(src1, src0, src1);
+ vinserti64x4(zmm_src0, zmm_src0, src1, 0);
+ vpxor(src1, src1, src1);
+ vpunpcklwd(src0, src0, src1);
+ vinserti64x4(zmm_src0, zmm_src0, src0, 1);
+ vpermpd(zmm_src0 | kFF, vidx1, zmm_src0);
+ store(zmm_src0, nrows-1);
+ }
+}
+
+void jit_trans_ow_oc_t::generate() {
+ preamble();
+
+ alignas(64) static constexpr const int64_t idx1[8]
+ = { 4, 5, 0, 1, 6, 7, 2, 3 };
+
+ const int oc_block = conf_->oc_block;
+ const int ow = conf_->ow;
+ const int transposes = utils::div_up(ow, transpose_size);
+ int loop_iters = nstl::max(0, transposes - 1);
+ tail = ow - loop_iters * transpose_size;
+
+ src_stride = oc_block * typesize;
+ tr_src_stride = oc_block * typesize;
+
+ bool nontemporal_stores = false;
+ enable_prefetch = ow > small_spatial ? 1 : 0;
+
+ const int src_step = oc_block * transpose_size * typesize;
+ const int tr_src_step = oc_block * transpose_size * typesize;
+ const int right_pad = ow % 2;
+
+ mov(reg_src, ptr [param1 + GET_OFF(src)]);
+ mov(reg_tr_src, ptr [param1 + GET_OFF(tr_src)]);
+ mov(reg_src_prf, ptr [param1 + GET_OFF(src_prf)]);
+ mov(reg_tr_src_prf, ptr [param1 + GET_OFF(tr_src_prf)]);
+
+ auto kmovw = [=](Opmask k, unsigned w) {
+ mov(regw_tmp, w);
+ jit_generator::kmovw(k, regw_tmp);
+ };
+
+ kmovw(kFF, 0xFF);
+
+ auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
+ mov(imm_addr64, reinterpret_cast<size_t>(addr));
+ jit_generator::vmovdqa64(z, ptr[imm_addr64]);
+ };
+
+ vmovdqa64(vidx1, idx1);
+ if (loop_iters) {
+ mov(reg_loop, loop_iters);
+ Label loop;
+ L(loop); {
+ transpose(transpose_size, 0, 0, nontemporal_stores);
+ add(reg_src, src_step);
+ add(reg_tr_src, tr_src_step);
+ add(reg_src_prf, src_step);
+ add(reg_tr_src_prf, tr_src_step);
+ sub(reg_loop, 1);
+ jnz(loop);
+ }
+ }
+ transpose(tail, 0, right_pad, nontemporal_stores);
+
+ postamble();
+}
+
+struct jit_trans_iw_x4_4x_t: public jit_trans_src_t, public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_trans_iw_x4_4x_t)
+
+ jit_trans_iw_x4_4x_t(const jit_conv_conf_t *conf): jit_trans_src_t(conf) {
+ generate();
+ ker_ = (decltype(ker_))this->getCode();
+ }
+
+ void generate();
+ enum { typesize = (int)sizeof(float) };
+};
+
+/** @brief transposition of the form [:][iw/4][4] -> [:][4][iw/4]
+ * required for 1st 4fma backward by weights convolution */
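+// e.g. a row stored as groups of four {a0 a1 a2 a3, b0 b1 b2 b3, ...} becomes
+// {a0 b0 ..., a1 b1 ..., a2 b2 ..., a3 b3 ...}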
+void jit_trans_iw_x4_4x_t::generate() {
+ using namespace utils;
+
+ /* TODO: put into code */
+ static int mask[16] = {
+ 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, };
+
+ const auto &c = *conf_;
+ const int simd_w = cpu_isa_traits<avx512_common>::vlen / typesize;
+ const int niters = c.tr_ld / simd_w;
+
+ assert(niters <= 4); /* [bwd_w:tr_src:r1] */
+
+ Reg64 reg_ptr_src = r8;
+ Reg64 reg_ptr_tr_src = r9;
+
+ Reg64 reg_ih = rax;
+ Reg64 reg_ih_end = rbx;
+
+ Reg64 reg_nthr_oc_b = rsi;
+ Reg64 reg_ptr_tr_src_bctx = abi_not_param1;
+
+ Reg64 reg_tmp = rdx;
+
+ Zmm vmsk = Zmm(31);
+ Opmask kmsk = k7;
+
+ auto emit_tr_sync = [&]() {
+ simple_barrier::generate(*this, reg_ptr_tr_src_bctx, reg_nthr_oc_b);
+ };
+
+ auto emit_tr_iw = [&]() {
+ auto vreg = [](int iter, int i) {
+ assert(4 * iter + i < 24);
+ return Zmm(4 * iter + i);
+ };
+ auto vtmp = [](int i) { return Zmm(24 + i); };
+
+ auto emit_load = [&](int iter) {
+ for (int i = 0; i < 4; ++i) {
+ auto v = vreg(iter, i);
+ const int off = (iter * 4 + i) * simd_w;
+
+ if (off + simd_w <= c.iw)
+ vmovups(v, ptr[reg_ptr_src + off * typesize]);
+ else if (off < c.iw)
+ vmovups(v | kmsk | T_z, ptr[reg_ptr_src + off * typesize]);
+ else
+ vpxord(v, v, v);
+ }
+ };
+
+ auto emit_tr = [&](int iter) {
+ for (int i = 0; i < 4; ++i)
+ vpermps(vreg(iter, i), vmsk, vreg(iter, i));
+
+ vshuff32x4(vtmp(0), vreg(iter, 0), vreg(iter, 1), 0x88);
+ vshuff32x4(vtmp(1), vreg(iter, 0), vreg(iter, 1), 0xdd);
+ vshuff32x4(vtmp(2), vreg(iter, 2), vreg(iter, 3), 0x88);
+ vshuff32x4(vtmp(3), vreg(iter, 2), vreg(iter, 3), 0xdd);
+
+ vshuff32x4(vreg(iter, 0), vtmp(0), vtmp(2), 0x88);
+ vshuff32x4(vreg(iter, 2), vtmp(0), vtmp(2), 0xdd);
+ vshuff32x4(vreg(iter, 1), vtmp(1), vtmp(3), 0x88);
+ vshuff32x4(vreg(iter, 3), vtmp(1), vtmp(3), 0xdd);
+ };
+
+ auto emit_store = [&]() {
+ for (int i = 0; i < 4; ++i) {
+ for (int iter = 0; iter < niters; ++iter) {
+ const size_t off = i * c.tr_ld + iter * simd_w;
+ vmovups(ptr[reg_ptr_tr_src + off * typesize], vreg(iter, i));
+ }
+ }
+ };
+
+ for (int iter = 0; iter < niters; ++iter)
+ emit_load(iter);
+
+ for (int iter = 0; iter < niters; ++iter)
+ emit_tr(iter);
+
+ emit_store();
+ };
+
+ preamble();
+
+ mov(reg_ptr_src, ptr[abi_param1 + GET_OFF(src)]);
+ mov(reg_ptr_tr_src, ptr[abi_param1 + GET_OFF(tr_src)]);
+
+ mov(reg_nthr_oc_b.cvt32(), ptr[abi_param1 + GET_OFF(nthr_oc_b)]);
+ mov(reg_ih.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_start)]);
+ mov(reg_ih_end.cvt32(), ptr[abi_param1 + GET_OFF(tr_src_ih_end)]);
+ mov(reg_ptr_tr_src_bctx, ptr[abi_param1 + GET_OFF(tr_src_bctx)]);
+
+ emit_tr_sync();
+
+ Label l_ih_loop, l_tr_done;
+ cmp(reg_ih, reg_ih_end);
+ je(l_tr_done, T_NEAR);
+
+ mov(reg_tmp, (size_t)&mask[0]);
+ vmovups(vmsk, ptr[reg_tmp]);
+
+ if (c.iw % simd_w) {
+ const char load_mask = (1 << (c.iw % simd_w)) - 1;
+ mov(reg_tmp, load_mask);
+ kmovw(kmsk, reg_tmp.cvt32());
+ }
+
+ /* src += ih_start * c.iw; */
+ imul(reg_tmp, reg_ih, c.iw * typesize);
+ add(reg_ptr_src, reg_tmp);
+ /* tr_src += ih_start * c.stride_w * c.tr_ld; */
+ imul(reg_tmp, reg_ih, c.stride_w * c.tr_ld * typesize);
+ add(reg_ptr_tr_src, reg_tmp);
+
+ L(l_ih_loop); {
+ emit_tr_iw();
+
+ add(reg_ptr_src, c.iw * typesize);
+ add(reg_ptr_tr_src, c.stride_w * c.tr_ld * typesize);
+
+ inc(reg_ih);
+ cmp(reg_ih, reg_ih_end);
+ jl(l_ih_loop, T_NEAR);
+ }
+
+ L(l_tr_done);
+
+ emit_tr_sync();
+
+ postamble();
+}
+
+/*
+// -------------------------------------------------
+// jit_transpose4x16_src
+// -------------------------------------------------
+*/
+
+void jit_transpose4x16_src::transpose(int nrows)
+{
+ assert(nrows >= 0 && nrows <= transpose_size);
+ static_assert(transpose_size == 4, "Unsupported transpose size");
+ if (!nrows)
+ return;
+
+ auto pf_src_t0 = [=](int i) {
+ if (tparams->src_pf0_distance)
+ prefetcht0(EVEX_compress_addr(
+ reg_src, (tparams->src_pf0_distance + i) * src_stride));
+ };
+
+ auto pf_tr_src_t0 = [=](int i) {
+ if (tparams->tr_src_pf0_distance)
+ prefetcht0(EVEX_compress_addr(reg_tr_src,
+ (tparams->tr_src_pf0_distance + i) * src_stride));
+ };
+
+ auto pf_src_t1 = [=](int i) {
+ if (tparams->src_pf1)
+ prefetcht1(EVEX_compress_addr(reg_src_prf, i * src_stride));
+ };
+
+ auto pf_tr_src_t1 = [=](int i) {
+ if (tparams->tr_src_pf1)
+ prefetchwt1(EVEX_compress_addr(reg_tr_src_prf, i * tr_src_stride));
+ };
+
+ auto src_zmm = [=](int i) {
+ assert(i >= 0 && i < 4);
+ return Zmm(i);
+ };
+
+ auto tmp_zmm = [=](int i) {
+ assert(i >= 0 && i < 4);
+ return Zmm(4 + i);
+ };
+
+ auto load = [=](int i) {
+ vmovups(src_zmm(i), EVEX_compress_addr(reg_src, i * src_stride));
+ };
+
+ auto store = [=](Zmm r, int i) {
+ vmovups(EVEX_compress_addr(reg_tr_src, i * tr_src_stride), r);
+ };
+
+ auto tmp0 = tmp_zmm(0);
+ auto tmp1 = tmp_zmm(1);
+ auto tmp2 = tmp_zmm(2);
+ auto tmp3 = tmp_zmm(3);
+
+ auto src0 = src_zmm(0);
+ auto src1 = src_zmm(1);
+ auto src2 = src_zmm(2);
+ auto src3 = src_zmm(3);
+ for (int i = 0; i < nrows; i++) {
+ load(i);
+ }
+
+ for (size_t i = nrows; i < 4; i++) {
+ vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
+ }
+
+ vmovupd(tmp0, src0);
+ vmovupd(tmp1, src1);
+ pf_src_t0(0);
+ vpermpd(tmp0 | kF0, vidx01, src2);
+ vpermpd(tmp1 | kF0, vidx01, src3);
+
+ valignd(src0, src0, src0, 8);
+ valignd(src1, src1, src1, 8);
+ pf_src_t0(1);
+ vmovupd(tmp2, src0);
+ vmovupd(tmp3, src1);
+ pf_src_t0(2);
+ vpermpd(tmp2 | kF0, vidx10, src2);
+ vpermpd(tmp3 | kF0, vidx10, src3);
+ pf_src_t0(3);
+
+ vmovupd(src0, tmp0);
+ pf_src_t1(0);
+ vmovupd(src1, tmp2);
+ pf_src_t1(1);
+ vmovupd(src2, tmp1);
+ pf_src_t1(2);
+ vmovupd(src3, tmp3);
+ pf_src_t1(3);
+ vpermpd(src0 | kCC, vidx1, tmp1);
+ vpermpd(src1 | kCC, vidx1, tmp3);
+ pf_tr_src_t0(0);
+ vpermpd(src2 | k33, vidx1, tmp0);
+ vpermpd(src3 | k33, vidx1, tmp2);
+ pf_tr_src_t0(1);
+
+ vmovupd(tmp0, src0);
+ vmovupd(tmp1, src2);
+ pf_tr_src_t0(2);
+ vmovupd(tmp2, src1);
+ vmovupd(tmp3, src3);
+ pf_tr_src_t0(3);
+ vpermps(tmp0 | kFFFF, vidxP, src0);
+ pf_tr_src_t1(0);
+ vpermps(tmp1 | kFFFF, vidxP, src2);
+ pf_tr_src_t1(1);
+ vpermps(tmp2 | kFFFF, vidxP, src1);
+ pf_tr_src_t1(3);
+ vpermps(tmp3 | kFFFF, vidxP, src3);
+ pf_tr_src_t1(4);
+
+ store(tmp0, 0);
+ store(tmp1, 1);
+ store(tmp2, 2);
+ store(tmp3, 3);
+}
+
+alignas(64) static constexpr const int64_t idx01[8]
+ = { 0, 0, 0, 0, 0, 1, 2, 3 };
+alignas(64) static constexpr const int64_t idx10[8]
+ = { 0, 0, 0, 0, 4, 5, 6, 7 };
+alignas(64) static constexpr const int64_t idx1[8] = { 2, 3, 0, 1, 6, 7, 4, 5 };
+alignas(64) static constexpr const int32_t idxP[16]
+ = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 };
+
+void jit_transpose4x16_src::generate()
+{
+ preamble();
+
+ const int ic_block = params->ic_block;
+ const int is = params->is;
+ int tail = is % transpose_size;
+
+ src_stride = ic_block * typesize;
+ assert(src_stride == 64);
+ tr_src_stride = ic_block * typesize;
+
+ const int src_step = ic_block * transpose_size * typesize;
+ const int tr_src_step = ic_block * transpose_size * typesize;
+
+#define GET_TR_OFF(x) offsetof(jit_src_transpose_s, x)
+ mov(reg_loop, ptr[param1 + GET_TR_OFF(size)]);
+ mov(reg_src, ptr[param1 + GET_TR_OFF(src)]);
+ mov(reg_tr_src, ptr[param1 + GET_TR_OFF(tr_src)]);
+ mov(reg_src_prf, ptr[param1 + GET_TR_OFF(src_prf)]);
+ mov(reg_tr_src_prf, ptr[param1 + GET_TR_OFF(tr_src_prf)]);
+#undef GET_TR_OFF
+
+ auto kmovw = [=](Opmask k, unsigned w) {
+ mov(regw_tmp, w);
+ jit_generator::kmovw(k, regw_tmp);
+ };
+
+ auto vmovdqa64 = [=](Zmm z, const int64_t *addr) {
+ mov(imm_addr64, reinterpret_cast<size_t>(addr));
+ jit_generator::vmovdqa64(z, ptr[imm_addr64]);
+ };
+
+ auto vmovdqa32 = [=](Zmm z, const int32_t *addr) {
+ mov(imm_addr64, reinterpret_cast<size_t>(addr));
+ jit_generator::vmovdqa32(z, ptr[imm_addr64]);
+ };
+
+ kmovw(kF0, 0xf0); // 11110000
+ kmovw(kCC, 0xcc); // 11001100
+ kmovw(k33, 0x33); // 00110011
+ kmovw(kFFFF, 0xffff); // 1111111111111111
+
+ vmovdqa64(vidx01, idx01);
+ vmovdqa64(vidx10, idx10);
+ vmovdqa64(vidx1, idx1);
+ vmovdqa32(vidxP, idxP);
+
+ Label loop_label;
+ Label tail_label;
+
+ cmp(reg_loop, transpose_size);
+ jl(tail_label, T_NEAR);
+
+ L(loop_label);
+ {
+ transpose(transpose_size);
+ add(reg_src, src_step);
+ add(reg_tr_src, tr_src_step);
+ add(reg_src_prf, src_step);
+ add(reg_tr_src_prf, tr_src_step);
+ sub(reg_loop, transpose_size);
+ cmp(reg_loop, transpose_size);
+ jge(loop_label, T_NEAR);
+ }
+ L(tail_label);
+ transpose(tail);
+
+ postamble();
+}
+
+jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf) {
+ if (conf->ver == ver_4fma && !conf->is_1stconv)
+ return new jit_trans_iw_ic_t(conf);
+ if (conf->ver == ver_4fma && conf->is_1stconv)
+ return new jit_trans_iw_x4_4x_t(conf);
+ assert(!"unsupported configuration");
+ return nullptr;
+}
+
+jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf) {
+ assert(!"unsupported configuration");
+ return nullptr;
+}
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp
new file mode 100644
index 0000000000..565e97e4fc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_transpose_src_utils.hpp
@@ -0,0 +1,145 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_TRANSPOSE_SRC_HPP
+#define CPU_JIT_TRANSPOSE_SRC_HPP
+
+#include "cpu_barrier.hpp"
+#include "jit_primitive_conf.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_trans_src_t {
+ struct ctx_t {
+ const void *src;
+ const void *tr_src;
+ const void *src_prf;
+ const void *tr_src_prf;
+
+ /* 1st conv 4fma: backward by weights */
+ int nthr_oc_b; /* number of threads processing a given src image */
+ int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */
+ simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */
+ };
+
+ jit_trans_src_t(const jit_conv_conf_t *conf)
+ : conf_(conf), ker_(nullptr) {}
+ virtual ~jit_trans_src_t() {}
+
+ void operator()(const ctx_t *ctx)
+ { assert(ker_); ker_(ctx); }
+
+ const jit_conv_conf_t *conf_;
+ void (*ker_)(const ctx_t *);
+};
+
+struct jit_src_transpose_s {
+ size_t size;
+ const void *src;
+ const void *tr_src;
+ const void *src_prf;
+ const void *tr_src_prf;
+};
+
+struct jit_trans_dst_t {
+ struct ctx_t {
+ const void *src;
+ const void *tr_src;
+ const void *src_prf;
+ const void *tr_src_prf;
+
+ /* 1st conv 4fma: backward by weights */
+ int nthr_oc_b; /* number of threads processing a given src image */
+ int tr_src_ih_start, tr_src_ih_end; /* thread's transposition bounds */
+ simple_barrier::ctx_t *tr_src_bctx; /* transposition synchronization */
+ };
+
+ jit_trans_dst_t(const jit_conv_conf_t *conf)
+ : conf_(conf), ker_(nullptr) {}
+ virtual ~jit_trans_dst_t() {}
+
+ void operator()(const ctx_t *ctx)
+ { assert(ker_); ker_(ctx); }
+
+ const jit_conv_conf_t *conf_;
+ void (*ker_)(const ctx_t *);
+};
+
+struct jit_transpose4x16_src_t {
+ int src_pf0_distance;
+ int tr_src_pf0_distance;
+ bool src_pf1;
+ bool tr_src_pf1;
+};
+
+struct jit_transpose4x16_src : public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_transpose4x16_src)
+
+ jit_transpose4x16_src(const jit_1x1_conv_conf_t *aparams,
+ jit_transpose4x16_src_t *tparams_)
+ : params(aparams), tparams(tparams_)
+ {
+ this->generate();
+ jit_ker = (decltype(jit_ker))this->getCode();
+ }
+
+ const jit_1x1_conv_conf_t *params;
+ const jit_transpose4x16_src_t *tparams;
+ void (*jit_ker)(jit_src_transpose_s *);
+
+ void operator()(jit_src_transpose_s *arg) { jit_ker(arg); }
+
+ static const int transpose_size = 4;
+private:
+ static const int typesize = sizeof(float);
+
+ int src_stride, tr_src_stride;
+
+ Xbyak::Reg64 imm_addr64 = rbx;
+
+ Xbyak::Opmask kF0 = k1;
+ Xbyak::Opmask kCC = k2;
+ Xbyak::Opmask k33 = k3;
+ Xbyak::Opmask kFFFF = k4;
+
+ Xbyak::Zmm vidx01 = zmm31;
+ Xbyak::Zmm vidx10 = zmm30;
+ Xbyak::Zmm vidx1 = zmm29;
+ Xbyak::Zmm vidxP = zmm28;
+
+ Xbyak::Reg64 reg_src = r8;
+ Xbyak::Reg64 reg_tr_src = r9;
+ Xbyak::Reg64 reg_src_prf = r10;
+ Xbyak::Reg64 reg_tr_src_prf = r11;
+ Xbyak::Reg64 reg_loop = r12;
+ Xbyak::Reg64 reg_tr_src_tmp = r13;
+ Xbyak::Reg32 regw_tmp = r14d;
+
+ void transpose_block(int ur, int nrows);
+ void transpose(int nrows);
+ void generate();
+};
+
+jit_trans_src_t *create_trans_src(const jit_conv_conf_t *conf);
+jit_trans_dst_t *create_trans_dst(const jit_conv_conf_t *conf);
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp
new file mode 100644
index 0000000000..53313f9f01
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_1x1_conv_utils.hpp
@@ -0,0 +1,327 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_UNI_1x1_CONV_UTILS_HPP
+#define JIT_UNI_1x1_CONV_UTILS_HPP
+
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+
+struct reduce_to_unit_stride_t {
+ convolution_desc_t conv_d_;
+ bool reduce_src_;
+ size_t space_per_thread_;
+};
+
+/* 1x1-kernel does not support non-unit strides so far, so the idea is:
+ * - for fwd or bwd_weights: to copy src to a scratch memory (with strides
+ * equal to 1) and then call the kernel
+ * - for bwd_data: reduce the problem to the one with unit stride by
+ * performing computations in a scratch memory (with strides equal to 1)
+ * and then copy the result to diff_src */
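+// e.g. a stride-2 1x1 forward convolution reads only every other input pixel;
+// rtus gathers those pixels into a dense per-thread scratch buffer so the
+// unit-stride kernel can run on it unchanged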
+template <typename conv_pd_t>
+inline void rtus_prepare(conv_pd_t *self, const convolution_desc_t *&conv_d,
+ const memory_desc_t *&src_d, const memory_desc_t *dst_d) {
+ const bool is_bwd_data = self->desc()->prop_kind
+ == prop_kind::backward_data;
+
+ const int ndims = src_d->ndims;
+ const auto dat_tag = ndims == 3
+ ? memory_desc_wrapper(dst_d).matches_one_of_tag(
+ format_tag::nCw8c, format_tag::nCw16c)
+ : memory_desc_wrapper(dst_d).matches_one_of_tag(
+ format_tag::nChw8c, format_tag::nChw16c);
+
+ bool rtus_applicable = true
+ && utils::pick(ndims - 3,
+ (conv_d->strides[0] != 1 && !one_of(conv_d->src_desc.data_type,
+ data_type::s32)),
+ (conv_d->strides[0] != 1 || conv_d->strides[1] != 1))
+ && dat_tag != format_tag::undef;
+ for (int d = 2; d < ndims; ++d) {
+ /* TODO: relax these conditions (by improving reducer) */
+ rtus_applicable = rtus_applicable
+ && conv_d->padding[0][d - 2] == 0
+ && dst_d->dims[d] * conv_d->strides[d - 2] == src_d->dims[d];
+ }
+
+ if (rtus_applicable) {
+ self->rtus_.reduce_src_ = true;
+ conv_d = &(self->rtus_.conv_d_ = *conv_d);
+ self->rtus_.conv_d_.strides[0] = 1;
+ if (ndims == 4)
+ self->rtus_.conv_d_.strides[1] = 1;
+ utils::array_set(self->rtus_.conv_d_.padding[0], 0, 2);
+ if (ndims == 4)
+ utils::array_set(self->rtus_.conv_d_.padding[1], 0, 2);
+ const int ic = src_d->dims[1];
+ if (is_bwd_data) {
+ src_d = &(self->rtus_.conv_d_.diff_src_desc = *dst_d);
+ self->rtus_.conv_d_.diff_src_desc.dims[1] = ic;
+ memory_desc_wrapper::compute_blocking(
+ self->rtus_.conv_d_.diff_src_desc, dat_tag);
+ } else {
+ data_type_t data_type = self->rtus_.conv_d_.src_desc.data_type;
+ src_d = &(self->rtus_.conv_d_.src_desc = *dst_d);
+ self->rtus_.conv_d_.src_desc.dims[1] = ic;
+ self->rtus_.conv_d_.src_desc.data_type = data_type;
+ memory_desc_wrapper::compute_blocking(
+ self->rtus_.conv_d_.src_desc, dat_tag);
+ }
+ }
+}
+
+template <typename conv_pd_t>
+inline void rtus_prepare_space_info(conv_pd_t *self,
+ memory_tracking::registrar_t &scratchpad) {
+ const auto &jcp = self->jcp_;
+
+ const int max_threads = mkldnn_get_max_threads();
+ const size_t factor = utils::pick_by_prop_kind(self->desc()->prop_kind,
+ jcp.nb_reduce, jcp.nb_load_blocking_max, jcp.nb_bcast_blocking);
+ size_t typesize = types::data_type_size(
+ conv_prop_invariant_src_d(self->desc())->data_type);
+
+ self->rtus_.space_per_thread_ = factor * jcp.is * jcp.ic_block;
+ scratchpad.book(memory_tracking::names::key_conv_rtus_space,
+ typesize * max_threads * self->rtus_.space_per_thread_);
+}
+
+template <cpu_isa_t isa>
+struct rtus_driver_t: public jit_generator {
+
+ struct call_params_t {
+ const void *ws; /* reduced image (w/ strides = 1) */
+ const void *src; /* source image (w/ non-unit strides) */
+ size_t icb;
+ size_t os;
+ size_t iw_start;
+ };
+
+ void (*ker_)(const call_params_t *p);
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(rtus_driver_t)
+
+ /* cpu specific part */
+ using Vmm = typename utils::conditional<isa == avx2, Xbyak::Ymm,
+ Xbyak::Zmm>::type;
+
+ Xbyak::Reg64 reg_ws = abi_param1;
+ Xbyak::Reg64 reg_src = abi_not_param1;
+ Xbyak::Reg64 reg_icb = rdx;
+ Xbyak::Reg64 reg_os = r11;
+ Xbyak::Reg64 reg_iw_start = r8;
+
+ Xbyak::Reg64 reg_cur_os = rax;
+ Xbyak::Reg64 reg_cur_iw = r9;
+ Xbyak::Reg64 reg_cur_src = r10;
+
+ int iw_, stride_w_;
+ int src_step_h_, src_step_icb_, ws_step_icb_, vlen_, vlen_shift_;
+ bool src_to_ws_;
+ size_t typesize_;
+ Vmm reg_zero;
+ Vmm reg_v;
+
+ rtus_driver_t(int iw, int stride_w, int src_step_h,
+ int src_step_icb, int ws_step_icb, bool src_to_ws, size_t typesize)
+ : iw_(iw), stride_w_(stride_w), src_step_h_(src_step_h)
+ , src_step_icb_(src_step_icb), ws_step_icb_(ws_step_icb)
+ , src_to_ws_(src_to_ws), typesize_(typesize)
+ {
+ using namespace Xbyak;
+ vlen_ = cpu_isa_traits<isa>::vlen;
+ vlen_shift_ = cpu_isa_traits<isa>::vlen_shift;
+ if (typesize_ == 2) {
+ vlen_ /= 2;
+ vlen_shift_--;
+ }
+
+ reg_zero = Vmm(0);
+ reg_v = Vmm(1);
+
+ generate();
+ }
+
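+ // Walk the reduced spatial range: with src_to_ws_ the strided src pixels are
+ // gathered into the dense ws buffer, otherwise ws is scattered back to src
+ // and the skipped stride positions are zero-filled.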
+ void loop_is() {
+ using namespace Xbyak;
+
+ mov(reg_cur_src, reg_src);
+ mov(reg_cur_iw, reg_iw_start);
+ mov(reg_cur_os, reg_os);
+
+ Label is_loop, skip_h_step;
+ L(is_loop);
+
+ if (src_to_ws_) {
+ vmovups(reg_v, ptr[reg_cur_src]);
+ vmovups(ptr[reg_ws], reg_v);
+ } else {
+ vmovups(reg_v, ptr[reg_ws]);
+ vmovups(ptr[reg_cur_src], reg_v);
+ for (int w = 1; w < stride_w_; ++w)
+ vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);
+ }
+
+ add(reg_ws, vlen_);
+
+ add(reg_cur_iw, stride_w_);
+ add(reg_cur_src, stride_w_ * vlen_);
+
+ cmp(reg_cur_iw, iw_);
+ jl(skip_h_step);
+ /* for 1d convolution the loop over h should be skipped */
+ if (src_step_icb_ == iw_) jmp(skip_h_step);
+
+ if (src_to_ws_) {
+ add(reg_cur_src, (src_step_h_ - iw_) * vlen_);
+ } else {
+ Xbyak::Reg64 reg_cur_src_fin = reg_cur_iw; /* just reuse */
+ mov(reg_cur_src_fin, reg_cur_src);
+ add(reg_cur_src_fin, (src_step_h_ - iw_) * vlen_);
+ Label ih_loop;
+ L(ih_loop);
+
+ for (int w = 0; w < stride_w_; ++w)
+ vmovups(ptr[reg_cur_src + w * vlen_], reg_zero);
+
+ add(reg_cur_src, stride_w_ * vlen_);
+ cmp(reg_cur_src, reg_cur_src_fin);
+ jl(ih_loop);
+ }
+ xor_(reg_cur_iw, reg_cur_iw);
+
+ L(skip_h_step);
+
+ sub(reg_cur_os, vlen_);
+ jnz(is_loop);
+
+ /* restore dst */
+ sub(reg_ws, reg_os);
+ }
+
+ void generate() {
+ using namespace Xbyak;
+ assert(isa == avx2 || isa == avx512_common
+ || isa == avx512_core || isa == avx512_mic);
+
+#if defined(_WIN32)
+ assert(reg_src == abi_not_param1 && abi_not_param1 == rdi);
+ push(rdi);
+#endif
+
+#define READ_PARAM(what) \
+ mov(reg_ ## what, ptr[abi_param1 + offsetof(call_params_t, what)])
+ READ_PARAM(src);
+ READ_PARAM(icb);
+ READ_PARAM(os);
+ READ_PARAM(iw_start);
+
+ assert(reg_ws == abi_param1);
+ READ_PARAM(ws); /* reg_ws should always be read last */
+#undef READ_PARAM
+
+ shl(reg_os, vlen_shift_);
+
+ if (!src_to_ws_)
+ uni_vpxor(reg_zero, reg_zero, reg_zero);
+
+ Label icb_loop;
+ L(icb_loop);
+
+ loop_is();
+
+ add(reg_ws, ws_step_icb_ * vlen_);
+ add(reg_src, src_step_icb_ * vlen_);
+
+ dec(reg_icb);
+ jnz(icb_loop, T_NEAR);
+
+#if defined(_WIN32)
+ pop(rdi);
+#endif
+
+ uni_vzeroupper();
+ ret();
+ this->ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
+ this->getCode()));
+ }
+};
+
+template <cpu_isa_t isa, typename conv_t>
+inline void init_rtus_driver(conv_t *self) {
+ const auto &conf = *self->pd();
+ if (!conf.rtus_.reduce_src_) return;
+
+ const auto &cd = *conf.desc();
+ const int ndims = conf.ndims();
+ const int stride_h = (conf.ndims() == 3) ? 1 : cd.strides[0];
+ const int stride_w = cd.strides[ndims - 3];
+
+ const bool is_bwd_data = cd.prop_kind == prop_kind::backward_data;
+ const auto &src_d = is_bwd_data ? *conf.diff_src_md() : *conf.src_md();
+
+ const int ih = ndims == 3 ? 1 : src_d.dims[2];
+ const int iw = src_d.dims[ndims - 1];
+
+ const int src_step_h = stride_h * iw;
+ const int src_step_icb = ih * iw;
+ const int ws_step_icb = conf.jcp_.is;
+ const bool src_to_ws = !is_bwd_data;
+ const size_t typesize = types::data_type_size(
+ conv_prop_invariant_src_d(self->pd()->desc())->data_type);
+
+ self->rtus_driver_ = new rtus_driver_t<isa>(iw, stride_w, src_step_h,
+ src_step_icb, ws_step_icb, src_to_ws, typesize);
+}
+
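+// Picks a chunk size in [min_divider, max_divider] minimizing the relative
+// rounding loss of 'value'; ties go to the largest candidate when find_max is
+// set, to the smallest otherwise.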
+inline int best_divider(int value, int min_divider, int max_divider,
+ bool find_max, int step = 1)
+{
+ max_divider = nstl::max(1, nstl::min(max_divider, value));
+ min_divider = nstl::max(1, nstl::min(min_divider, max_divider));
+
+ auto loss_ratio = [](int total, int chunk)
+ { return float(rnd_up(total, chunk) - total) / rnd_up(total, chunk); };
+
+ float min_loss = FLT_MAX;
+ int x_divider = max_divider;
+ for (int divider = max_divider; divider >= min_divider; divider -= step) {
+ const float loss = loss_ratio(value, divider);
+ if ((find_max && loss < min_loss) || (!find_max && loss <= min_loss)) {
+ min_loss = loss;
+ x_divider = divider;
+ }
+ }
+ return x_divider;
+}
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp
new file mode 100644
index 0000000000..72fe3a8109
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.cpp
@@ -0,0 +1,1407 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "math_utils.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_barrier.hpp"
+#include "cpu_batch_normalization_utils.hpp"
+#include "jit_generator.hpp"
+
+#include "jit_uni_batch_normalization.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace {
+
+using namespace memory_tracking::names;
+
+using namespace Xbyak;
+namespace barrier = simple_barrier;
+
+typedef float data_t;
+
+template <cpu_isa_t isa>
+struct jit_bnorm_t: public jit_generator {
+ struct call_params_t {
+ // keep all sizes at 8 bytes -- jit code expects this
+ size_t N_ithr, N_nthr;
+ size_t coff_max, soff_max;
+ size_t mb_stride_Bc, spat_size, spat_size_loc;
+ size_t S_s, S_tail;
+ size_t is_cblk_tail;
+ data_t chan_size, eps, one;
+ const data_t *scale_shift;
+ const data_t *mean, *var;
+ const data_t *diff_scale_shift;
+ const data_t *src, *dst;
+ const data_t *diff_src, *diff_dst;
+ const data_t *rbuf1, *rbuf2;
+ const uint8_t *ws;
+ barrier::ctx_t *barrier;
+ };
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_bnorm_t)
+
+ /* cpu specific part */
+ using Vmm = typename utils::conditional3<isa == sse42, Xmm,
+ isa == avx2, Ymm, Zmm>::type;
+ const AddressFrame &vmmword = (isa == sse42) ? xword :
+ (isa == avx2) ? yword : zword;
+
+ const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;
+
+ const batch_normalization_pd_t *bdesc_;
+ bool is_spatial_thr_;
+
+ void (*ker)(const call_params_t *);
+ void operator()(const call_params_t *p) { (*ker)(p); }
+
+ Reg64 reg_param = abi_param1;
+
+ Reg64 reg_scale_shift = rbx;
+ Reg64 reg_rbuf1 = abi_not_param1;
+ Reg64 reg_rbuf2 = rdx;
+
+ Reg64 reg_mean = rbp;
+ Reg64 reg_var = reg_param;
+ Reg64 reg_diff_scale_shift = rax;
+
+ Reg64 reg_coff = r8;
+ Reg64 reg_coff_max = r9;
+ Reg64 reg_soff = r10;
+ Reg64 reg_soff_max = r11;
+ Reg64 reg_ctr = r12;
+ Reg64 reg_roff = r13;
+
+ Reg64 reg_mb_stride_Bc = r14;
+
+ Reg64 reg_src = r15;
+ Reg64 reg_diff_src = reg_rbuf1;
+ Reg64 reg_dst = rsi;
+ Reg64 reg_diff_dst = reg_dst;
+
+ Reg64 reg_tmp_off = reg_roff;
+
+ // Reuse loop counters
+ Reg64 reg_bar = reg_coff;
+ Reg64 reg_nnthr = reg_soff; // must be usable w/ loops over coff
+ Reg64 reg_tmp = reg_ctr;
+
+ // Relu section
+ bool with_relu, with_relu_inf_only;
+ Vmm vzero; // is_fwd() ? vdiff_beta : vbeta
+ Reg64 reg_ws = reg_roff;
+ Label l_relu_mask_avx2;
+ Opmask kstore_mask = Opmask(1);
+
+ // channel tail processing
+ Opmask ktail_mask = Opmask(2);
+
+ size_t unroll_blocks;
+ size_t unroll_regs;
+ Vmm vbuf = Vmm(isa == avx512_common ? 20 : 5);
+ Vmm vdiff_beta = Vmm(isa == avx512_common ? 21 : 6);
+ Vmm vdiff_gamma = Vmm(isa == avx512_common ? 22 : 7);
+ Vmm vsqrtvar = Vmm(isa == avx512_common ? 23 : 8);
+ Vmm vone = Vmm(isa == avx512_common ? 24 : 9);
+ Vmm vmean = Vmm(isa == avx512_common ? 25 : 10);
+ Vmm vgamma = Vmm(isa == avx512_common ? 26 : 11);
+ Vmm vbeta = Vmm(isa == avx512_common ? 27 : 12);
+ Vmm veps = Vmm(isa == avx512_common ? 28 : 13);
+ Vmm vchan_size = Vmm(isa == avx512_common ? 29 : 14);
+ Vmm vtail_mask = Vmm(isa == avx512_common ? 30 : 15);
+
+ size_t t0_pf_offt;
+ size_t t1_pf_offt;
+ size_t spat_size;
+ size_t chan_data_offt;
+
+ enum {
+ stack_off_N_nthr = 0,
+ stack_off_N_ithr = 8,
+ stack_off_src = 16,
+ stack_off_dst = 24,
+ stack_off_diff_src = 32,
+ stack_off_diff_dst = 40,
+ stack_off_diff_scale_shift = 48,
+ stack_off_ws = 56,
+ stack_off_barrier = 64,
+ stack_off_spat_size_loc = 72,
+ stack_off_s_s = 80,
+ stack_off_s_tail = 88,
+ stack_off_is_cblk_tail = 96,
+ stack_size_required = 104,
+ };
+
+ bool is_c_padded() const {
+ const memory_desc_wrapper data_d(bdesc_->src_md());
+ return bdesc_->C() != data_d.padded_dims()[1];
+ }
+
+ void compute_static_strides() {
+ spat_size = bdesc_->D() * bdesc_->W() * bdesc_->H();
+ chan_data_offt = bdesc_->C() * sizeof(data_t);
+
+ if (isa == avx512_mic) {
+ t0_pf_offt = 4096;
+ t1_pf_offt = 0;
+ } else {
+ t0_pf_offt = 0;
+ t1_pf_offt = 0;
+ }
+ }
+
+ void load_common_params() {
+# define PARAM_OFF(x) offsetof(call_params_t, x)
+ mov(reg_rbuf1, ptr[reg_param + PARAM_OFF(rbuf1)]);
+ if (bdesc_->is_bwd())
+ mov(reg_rbuf2, ptr[reg_param + PARAM_OFF(rbuf2)]);
+ mov(reg_coff_max, ptr[reg_param + PARAM_OFF(coff_max)]);
+ mov(reg_soff_max, ptr[reg_param + PARAM_OFF(soff_max)]);
+ mov(reg_mb_stride_Bc, ptr[reg_param + PARAM_OFF(mb_stride_Bc)]);
+ shl(reg_coff_max, 2);
+ shl(reg_soff_max, 2);
+ shl(reg_mb_stride_Bc, 2);
+
+ mov(reg_mean, ptr[reg_param + PARAM_OFF(mean)]);
+ mov(reg_scale_shift, ptr[reg_param + PARAM_OFF(scale_shift)]);
+
+ uni_vbroadcastss(vchan_size, vmmword[reg_param + PARAM_OFF(chan_size)]);
+ uni_vbroadcastss(vone, vmmword[reg_param + PARAM_OFF(one)]);
+ uni_vbroadcastss(veps, vmmword[reg_param + PARAM_OFF(eps)]);
+
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_nthr)]);
+ mov(ptr[rsp + stack_off_N_nthr], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(N_ithr)]);
+ mov(ptr[rsp + stack_off_N_ithr], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(src)]);
+ mov(ptr[rsp + stack_off_src], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(dst)]);
+ mov(ptr[rsp + stack_off_dst], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_src)]);
+ mov(ptr[rsp + stack_off_diff_src], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_dst)]);
+ mov(ptr[rsp + stack_off_diff_dst], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(ws)]);
+ mov(ptr[rsp + stack_off_ws], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(barrier)]);
+ mov(ptr[rsp + stack_off_barrier], reg_tmp);
+ if (is_spatial_thr_) {
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(spat_size_loc)]);
+ mov(ptr[rsp + stack_off_spat_size_loc], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_s)]);
+ mov(ptr[rsp + stack_off_s_s], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(S_tail)]);
+ mov(ptr[rsp + stack_off_s_tail], reg_tmp);
+ }
+ if (is_c_padded()) {
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(is_cblk_tail)]);
+ mov(ptr[rsp + stack_off_is_cblk_tail], reg_tmp);
+ }
+
+ if (bdesc_->is_fwd()) {
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
+ mov(reg_var, reg_tmp);
+ } else {
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(diff_scale_shift)]);
+ mov(ptr[rsp + stack_off_diff_scale_shift], reg_tmp);
+ mov(reg_tmp, ptr[reg_param + PARAM_OFF(var)]);
+ mov(reg_var, reg_tmp);
+ }
+# undef PARAM_OFF
+ }
+
+ void prepare_tail_mask_avx512_common() {
+ if (!is_c_padded()) return;
+
+ const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
+ const int mask = (1 << tail) - 1;
+
+ Reg32 regw_tmp = reg_tmp.cvt32();
+ mov(regw_tmp, mask);
+ kmovw(ktail_mask, regw_tmp);
+ }
+
+ void prepare_tail_mask_avx2_common() {
+ if (!is_c_padded()) return;
+
+ const int tail = bdesc_->C() % (int)(vlen / sizeof(float));
+ static const uint32_t mask[16] = {0xffffffff, 0xffffffff, 0xffffffff,
+ 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
+ 0, 0, 0, 0, 0, 0, 0, 0};
+
+ mov(reg_tmp, reinterpret_cast<size_t>(&mask[8 - tail]));
+ vmovups(vtail_mask, ptr[reg_tmp]);
+ }
+
+ void prepare_relu() {
+ with_relu = bdesc_->is_fwd()
+ ? bdesc_->with_relu_post_op() || bdesc_->fuse_bn_relu()
+ : bdesc_->fuse_bn_relu();
+ with_relu_inf_only = with_relu && bdesc_->is_fwd()
+ && !(bdesc_->fuse_bn_relu() && bdesc_->is_training());
+
+ vzero = bdesc_->is_fwd() ? vdiff_beta : vbeta;
+ if (with_relu) {
+ uni_vpxor(vzero, vzero, vzero);
+ if (!bdesc_->is_fwd() && isa == avx2)
+ prepare_l_relu_mask_avx2();
+ }
+ }
+
+ void prepare_l_relu_mask_avx2() {
+ Label l_mask_after;
+ jmp(l_mask_after);
+ align(32);
+ L(l_relu_mask_avx2); /* [0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01] */
+ for (int i = 0; i < 8; ++i) dd(1<<i);
+ L(l_mask_after);
+ }
+
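+    /* The ReLU workspace keeps one bit per element (8 f32 values per byte),
+     * so the data byte offsets (reg_soff, offt) are scaled down by 32 before
+     * addressing it. */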
+ void fwd_process_relu_avx2(Vmm vdst, int offt, Vmm vstore_mask) {
+ Reg64 reg_store_mask = reg_diff_scale_shift;
+ shr(reg_soff, 5);
+ vcmpps(vstore_mask, vzero, vdst, _cmp_lt_os);
+ vmovmskps(reg_store_mask, vstore_mask);
+ mov(ptr[reg_ws + reg_soff + offt / (1 << 5)], reg_store_mask.cvt8());
+ vblendvps(vdst, vzero, vdst, vstore_mask);
+ shl(reg_soff, 5);
+ }
+
+ void fwd_process_relu_avx512_common(Vmm vdst, int offt) {
+ shr(reg_soff, 5);
+ vcmpps(kstore_mask, vzero, vdst, _cmp_lt_os);
+ kmovw(ptr[reg_ws + reg_soff + offt / (1 << 5)], kstore_mask);
+ vblendmps(vdst | kstore_mask, vzero, vdst);
+ shl(reg_soff, 5);
+ }
+
+ void bwd_process_relu_avx2(Vmm vdiff_dst, int offt, Vmm vstore_mask) {
+ shr(reg_soff, 5);
+ vpbroadcastb(vstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
+ vpand(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
+ vpcmpeqd(vstore_mask, vstore_mask, ptr[rip + l_relu_mask_avx2]);
+ vblendvps(vdiff_dst, vzero, vdiff_dst, vstore_mask);
+ shl(reg_soff, 5);
+ }
+
+ void bwd_process_relu_avx512_common(Vmm vdiff_dst, int offt) {
+ shr(reg_soff, 5);
+ kmovw(kstore_mask, ptr[reg_ws + reg_soff + offt / (1 << 5)]);
+ vmovups(vdiff_dst | kstore_mask | T_z, vdiff_dst);
+ shl(reg_soff, 5);
+ }
+
+ void uni_vmovups_tail_avx2_common(const Operand &dst,
+ const Operand &src, Label &l_ret) {
+ if (dst.isMEM()) {
+ vmaskmovps(dst.getAddress(), vtail_mask, Vmm(src.getIdx()));
+ } else {
+ vmaskmovps(Vmm(dst.getIdx()), vtail_mask, src.getAddress());
+ }
+ jmp(l_ret);
+ }
+
+ void uni_vmovups_tail_avx512_common(const Operand &dst,
+ const Operand &src, Label &l_ret) {
+ if (dst.isMEM())
+ uni_vmovups(dst.getAddress() | ktail_mask | T_z, Vmm(src.getIdx()));
+ else
+ uni_vmovups(Vmm(dst.getIdx()) | ktail_mask | T_z, src.getAddress());
+
+ jmp(l_ret);
+ }
+
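+    /* Plain load/store, except for the last channel block when C is padded,
+     * in which case a masked (tail) move is emitted. */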
+ void uni_vmovups_maybe_tail(const Operand &dst, const Operand &src) {
+ Label l_no_mask, l_ret;
+
+ if (is_c_padded()) {
+ mov(reg_tmp, ptr[rsp + stack_off_is_cblk_tail]);
+ cmp(reg_tmp, 0);
+ jz(l_no_mask);
+
+ lea(reg_tmp, ptr[reg_coff + vlen]);
+ cmp(reg_tmp, reg_coff_max);
+ jl(l_no_mask);
+ assert(isa == avx512_common || isa == avx2);
+ if (isa == avx512_common)
+ uni_vmovups_tail_avx512_common(dst, src, l_ret);
+ else if (isa == avx2)
+ uni_vmovups_tail_avx2_common(dst, src, l_ret);
+ }
+ L(l_no_mask);
+ if (dst.isMEM())
+ uni_vmovups(dst.getAddress(), Vmm(src.getIdx()));
+ else
+ uni_vmovups(Vmm(dst.getIdx()), src.getAddress());
+
+ L(l_ret);
+ }
+
+ void barrier() {
+ mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
+ mov(reg_bar, ptr[rsp + stack_off_barrier]);
+ simple_barrier::generate(*this, reg_bar, reg_nnthr);
+ }
+
+ Address mean_ptr(size_t offt = 0) {
+ return vmmword[reg_mean + reg_coff + offt + 0 * chan_data_offt];
+ }
+
+ Address var_ptr(size_t offt = 0) {
+ return vmmword[reg_var + reg_coff + offt + 0 * chan_data_offt];
+ }
+
+ Address diff_gamma_ptr(size_t offt = 0) {
+ return vmmword[reg_diff_scale_shift + reg_coff + offt
+ + 0 * chan_data_offt];
+ }
+
+ Address diff_beta_ptr(size_t offt = 0) {
+ return vmmword[reg_diff_scale_shift + reg_coff + offt
+ + 1 * chan_data_offt];
+ }
+
+ Address gamma_ptr(size_t offt = 0) {
+ return vmmword[reg_scale_shift + reg_coff + offt + 0 * chan_data_offt];
+ }
+
+ Address beta_ptr(size_t offt = 0) {
+ return vmmword[reg_scale_shift + reg_coff + offt + 1 * chan_data_offt];
+ }
+
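+    /* Generic spatial loop: `init` is called once per active register,
+     * `body` once per element of the unroll factor (regs * blocks), and
+     * `fini` once per active register to reduce the partial accumulators.
+     * The main loop covers len rounded down to the factor; the remainder is
+     * emitted as an inline tail. */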
+ template <typename init_t, typename body_t, typename fini_t>
+ void spat_loop(size_t len, size_t blocks, size_t regs,
+ init_t init, body_t body, fini_t fini) {
+ size_t factor = regs * blocks;
+ size_t loop_unroll = len / factor * factor;
+ size_t loop_tail = len - loop_unroll;
+ size_t num_active_regs = (len < regs) ? len : regs;
+ for (size_t i = 0; i < num_active_regs; i++)
+ init(i);
+ if (loop_unroll) {
+ if (is_spatial_thr_) {
+ mov(reg_ctr, ptr[rsp + stack_off_spat_size_loc]);
+ add(reg_soff, ptr[rsp + stack_off_s_s]);
+ } else {
+ mov(reg_ctr, loop_unroll);
+ }
+ Label label;
+ L(label); {
+ for (size_t i = 0; i < factor; i++) {
+ size_t base_reg = i % regs;
+ body(base_reg, i);
+ }
+ add(reg_soff, factor * vlen);
+ sub(reg_ctr, factor);
+ jnz(label);
+ }
+ if (is_spatial_thr_) {
+ add(reg_soff, ptr[rsp + stack_off_s_tail]);
+ }
+ }
+
+ for (size_t i = 0; i < loop_tail; i++) {
+ size_t base_reg = i % regs;
+ body(base_reg, i);
+ }
+ if (loop_tail)
+ add(reg_soff, loop_tail * vlen);
+
+ for (size_t i = 0; i < num_active_regs; i++)
+ fini(i);
+ }
+
+ void mean_channels() {
+ Label ch_label;
+ L(ch_label); {
+ uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
+ spat_loop(spat_size, unroll_blocks,
+ unroll_regs,
+ [=](size_t base_reg) {
+ Vmm v = Vmm(base_reg * 2);
+ if (base_reg)
+ uni_vpxor(v, v, v);
+ },
+ [=](size_t base_reg, size_t i) {
+ Vmm v0 = Vmm(base_reg * 2 + 0);
+ Vmm v1 = Vmm(base_reg * 2 + 1);
+ size_t offt = i * vlen;
+ uni_vmovups(v1,
+ vmmword[reg_src + reg_soff + offt]);
+ uni_vaddps(v0, v0, v1);
+ mic_prefetcht0(ptr[reg_src + reg_soff + offt
+ + t0_pf_offt]);
+ mic_prefetcht1(ptr[reg_src + reg_soff + offt
+ + t1_pf_offt]);
+ },
+ [=](size_t base_reg) {
+ Vmm b = Vmm(0);
+ Vmm v = Vmm(base_reg * 2);
+ if (base_reg)
+ uni_vaddps(b, b, v);
+ });
+ uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
+
+ add(reg_coff, vlen);
+ cmp(reg_coff, reg_coff_max);
+ jl(ch_label);
+ }
+ }
+
+ void var_channels() {
+ Label ch_label;
+ L(ch_label); {
+ uni_vmovups_maybe_tail(vmean, mean_ptr());
+ uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
+ spat_loop(spat_size, unroll_blocks, unroll_regs,
+ [=](size_t base_reg) {
+ Vmm v = Vmm(base_reg * 3);
+ if (base_reg > 0)
+ uni_vpxor(v, v, v);
+ },
+ [=](size_t base_reg, size_t i) {
+ Vmm v = Vmm(3 * base_reg);
+ Vmm vtmp0 = Vmm(3 * base_reg + 1);
+ Vmm vtmp1 = Vmm(3 * base_reg + 2);
+ size_t offt = i * vlen;
+ uni_vmovups(vtmp0,
+ vmmword[reg_src + reg_soff + offt]);
+ if (isa == sse42) {
+ movups(vtmp1, vmean);
+ subps(vtmp1, vtmp0);
+ } else {
+ vsubps(vtmp1, vmean, vtmp0);
+ }
+ uni_vfmadd231ps(v, vtmp1, vtmp1);
+
+ mic_prefetcht0(ptr[reg_src + reg_soff + offt
+ + t0_pf_offt]);
+ mic_prefetcht1(ptr[reg_src + reg_soff + offt
+ + t1_pf_offt]);
+ },
+ [=](size_t base_reg) {
+ Vmm b = Vmm(0);
+ Vmm v = Vmm(base_reg * 3);
+ if (base_reg)
+ uni_vaddps(b, b, v);
+ });
+ uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
+ add(reg_coff, vlen);
+ cmp(reg_coff, reg_coff_max);
+ jl(ch_label);
+ }
+ }
+
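+    /* Two-pass statistics: every thread accumulates per-channel partial sums
+     * (mean pass) and squared deviations (variance pass) into rbuf1; after a
+     * barrier the thread with N_ithr == 0 reduces them across threads and
+     * divides by chan_size. */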
+ void compute_mean_variance() {
+ uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
+ xor_(reg_coff, reg_coff);
+ Label zero_rbuf;
+ L(zero_rbuf); {
+ uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
+ add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
+ cmp(reg_coff, reg_coff_max);
+ jne(zero_rbuf);
+ }
+
+ mov(reg_src, ptr[rsp + stack_off_src]);
+
+ xor_(reg_soff, reg_soff);
+ Label mean_spatial;
+ L(mean_spatial); {
+ xor_(reg_coff, reg_coff);
+
+ if (isa == sse42)
+ mov(reg_tmp_off, reg_soff);
+
+ mean_channels();
+
+ if (isa == sse42) {
+ mov(reg_soff, reg_tmp_off);
+ add(reg_src, vlen / 2);
+ mov(reg_coff, vlen / 2);
+
+ mean_channels();
+
+ sub(reg_src, vlen / 2);
+ }
+
+ add(reg_soff, reg_mb_stride_Bc);
+ cmp(reg_soff, reg_soff_max);
+ jne(mean_spatial);
+ }
+
+ Label no_mean_reduction;
+ barrier(); {
+ mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
+ cmp(reg_tmp, 0);
+ jne(no_mean_reduction);
+ mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
+ xor_(reg_coff, reg_coff);
+ Label mean_reduction_channels;
+ L(mean_reduction_channels); {
+ mov(reg_roff, reg_coff);
+ uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
+ uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
+ mov(reg_ctr, reg_nnthr);
+ Label mean_reduction_thrs;
+ L(mean_reduction_thrs); {
+ uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
+ uni_vmovups(vmmword[reg_rbuf1 + reg_roff], Vmm(0));
+ add(reg_roff, reg_coff_max);
+ sub(reg_ctr, 1);
+ jnz(mean_reduction_thrs);
+ }
+ uni_vdivps(Vmm(1), Vmm(1), vchan_size);
+ uni_vmovups_maybe_tail(mean_ptr(), Vmm(1));
+
+ add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
+
+ cmp(reg_coff, reg_coff_max);
+ jne(mean_reduction_channels);
+ }
+ }
+ L(no_mean_reduction);
+ barrier();
+
+ xor_(reg_soff, reg_soff);
+ Label var_spatial;
+ L(var_spatial); {
+ xor_(reg_coff, reg_coff);
+
+ if (isa == sse42)
+ mov(reg_tmp_off, reg_soff);
+
+ var_channels();
+
+ if (isa == sse42) {
+ mov(reg_soff, reg_tmp_off);
+ add(reg_src, vlen / 2);
+ mov(reg_coff, vlen / 2);
+
+ var_channels();
+
+ sub(reg_src, vlen / 2);
+ }
+
+ add(reg_soff, reg_mb_stride_Bc);
+ cmp(reg_soff, reg_soff_max);
+ jne(var_spatial);
+ }
+
+ Label no_var_reduction;
+ barrier(); {
+ mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
+ cmp(reg_tmp, 0);
+ jne(no_var_reduction);
+
+ mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
+ xor_(reg_coff, reg_coff);
+ Label var_reduction_channels;
+ L(var_reduction_channels); {
+ mov(reg_roff, reg_coff);
+ uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
+ mov(reg_ctr, reg_nnthr);
+ Label var_reduction_thrs;
+ L(var_reduction_thrs); { // TODO: unroll (?)
+ uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf1 + reg_roff]);
+ add(reg_roff, reg_coff_max);
+ sub(reg_ctr, 1);
+ jnz(var_reduction_thrs);
+ }
+ uni_vdivps(Vmm(1), Vmm(1), vchan_size);
+ uni_vmovups_maybe_tail(var_ptr(), Vmm(1));
+ add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
+
+ cmp(reg_coff, reg_coff_max);
+ jne(var_reduction_channels);
+ }
+ }
+ L(no_var_reduction);
+ barrier();
+ }
+
+ void forward_channels() {
+ Label ch_label;
+ L(ch_label); {
+ uni_vmovups_maybe_tail(vmean, mean_ptr());
+ uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
+ uni_vaddps(vsqrtvar, vsqrtvar, veps);
+ uni_vsqrtps(vsqrtvar, vsqrtvar);
+
+ if (bdesc_->use_scaleshift()) {
+ uni_vmovups_maybe_tail(vgamma, gamma_ptr());
+ uni_vmovups_maybe_tail(vbeta, beta_ptr());
+ }
+
+ Vmm vscale = bdesc_->use_scaleshift() ? vgamma : vone;
+ Vmm vdiv = bdesc_->use_scaleshift() ? vgamma : vsqrtvar;
+
+ if (isa == sse42) {
+ movups(vbuf, vscale);
+ divps(vbuf, vsqrtvar);
+ movups(vdiv, vbuf);
+ } else {
+ vdivps(vdiv, vscale, vsqrtvar);
+ }
+
+ auto compute = [=](bool output_is_aligned) {
+ spat_loop(spat_size, unroll_blocks, unroll_regs,
+ [](size_t base_reg) {UNUSED(base_reg);},
+ [=](size_t base_reg, size_t i) {
+ Vmm v = Vmm(base_reg);
+ size_t offt = i * vlen;
+ uni_vmovups(v,
+ vmmword[reg_src + reg_soff + offt]);
+ mic_prefetcht0(ptr[reg_src + reg_soff + offt
+ + t0_pf_offt]);
+ mic_prefetcht1(ptr[reg_src + reg_soff + offt
+ + t1_pf_offt]);
+ uni_vsubps(v, v, vmean);
+ if (bdesc_->use_scaleshift()) {
+ uni_vfmadd213ps(v, vgamma, vbeta);
+ } else {
+ uni_vmulps(v, v, vsqrtvar);
+ }
+ if (with_relu_inf_only) {
+ uni_vmaxps(v, v, vzero);
+ } else if (with_relu) {
+ if (isa == avx512_common)
+ fwd_process_relu_avx512_common(v, offt);
+ else
+ fwd_process_relu_avx2(v, offt, Vmm(3));
+ }
+ if (output_is_aligned) {
+ uni_vmovntps(
+ vmmword[reg_dst + reg_soff + offt], v);
+ } else {
+ uni_vmovups(
+ vmmword[reg_dst + reg_soff + offt], v);
+ }
+ },
+ [](size_t base_reg) {UNUSED(base_reg);});
+ };
+
+ Label unaligned_store, end_store;
+ test(reg_dst, vlen - 1);
+ jnz(unaligned_store, T_NEAR);
+ compute(true);
+ jmp(end_store, T_NEAR);
+ L(unaligned_store); {
+ compute(false);
+ }
+ L(end_store);
+
+ add(reg_coff, vlen);
+ cmp(reg_coff, reg_coff_max);
+ jl(ch_label);
+ }
+ }
+
+ void forward() {
+ mov(reg_src, ptr[rsp + stack_off_src]);
+ mov(reg_dst, ptr[rsp + stack_off_dst]);
+ mov(reg_ws, ptr[rsp + stack_off_ws]);
+
+ xor_(reg_soff, reg_soff);
+ Label dst_spatial;
+ L(dst_spatial); {
+ xor_(reg_coff, reg_coff);
+ if (isa == sse42)
+ mov(reg_tmp_off, reg_soff);
+
+ forward_channels();
+
+ if (isa == sse42) {
+ mov(reg_soff, reg_tmp_off);
+ add(reg_src, vlen / 2);
+ add(reg_dst, vlen / 2);
+ mov(reg_coff, vlen / 2);
+
+ forward_channels();
+
+ sub(reg_src, vlen / 2);
+ sub(reg_dst, vlen / 2);
+ }
+
+ add(reg_soff, reg_mb_stride_Bc);
+ cmp(reg_soff, reg_soff_max);
+ jnz(dst_spatial);
+ }
+ }
+
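+    /* Accumulates per-channel partial sums of (src - mean) * diff_dst into
+     * rbuf1 (for diff_gamma) and of diff_dst into rbuf2 (for diff_beta) over
+     * this thread's slice of the data. */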
+ void backward_sh_channels() {
+ Label sh_channels;
+ L(sh_channels); {
+ uni_vmovups_maybe_tail(vmean, mean_ptr());
+ uni_vmovups(Vmm(0), vmmword[reg_rbuf1 + reg_coff]);
+ uni_vmovups(Vmm(1), vmmword[reg_rbuf2 + reg_coff]);
+ spat_loop(spat_size, 1, 1,
+ [=](size_t base_reg) {
+ if (base_reg > 0) {
+ for (int i = 0; i < 2; i++) {
+ Vmm v(base_reg * 5 + i);
+ uni_vpxor(v, v, v);
+ }
+ }
+ },
+ [=](size_t base_reg, size_t i) {
+ Vmm o0 = Vmm(base_reg * 5 + 0);
+ Vmm o1 = Vmm(base_reg * 5 + 1);
+ Vmm t1 = Vmm(base_reg * 5 + 2);
+ Vmm t2 = Vmm(base_reg * 5 + 3);
+ Vmm t3 = Vmm(base_reg * 5 + 4);
+ size_t offt = i * vlen;
+ uni_vmovups(t1, vmmword[reg_src + reg_soff + offt]);
+ uni_vmovups(t2, vmmword[reg_diff_dst + reg_soff
+ + offt]);
+ if (with_relu) {
+ if (isa == avx512_common)
+ bwd_process_relu_avx512_common(t2, offt);
+ else if (isa == avx2)
+ bwd_process_relu_avx2(t2, offt, t3);
+ else
+ assert(false);
+ }
+ uni_vsubps(t3, vmean, t1, t3);
+ if (isa == sse42) {
+ mulps(t3, t2);
+ subps(o0, t3);
+ } else {
+ vfnmadd231ps(o0, t3, t2);
+ }
+ uni_vaddps(o1, o1, t2);
+ mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
+ + t0_pf_offt]);
+ mic_prefetcht0(ptr[reg_src + reg_soff + offt
+ + t0_pf_offt]);
+ mic_prefetcht1(ptr[reg_diff_dst + reg_soff + offt
+ + t1_pf_offt]);
+ mic_prefetcht1(ptr[reg_src + reg_soff + offt
+ + t1_pf_offt]);
+ },
+ [=](size_t base_reg) {
+ Vmm b0 = Vmm(0);
+ Vmm b1 = Vmm(1);
+ if (base_reg) {
+ uni_vaddps(b0, b0, Vmm(base_reg * 5 + 0));
+ uni_vaddps(b1, b1, Vmm(base_reg * 5 + 1));
+ }
+ });
+ uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
+ uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(1));
+ add(reg_coff, vlen);
+ cmp(reg_coff, reg_coff_max);
+ jl(sh_channels);
+ }
+ }
+
+ void backward_diff_channels() {
+ Label diff_channels;
+ L(diff_channels); {
+ uni_vmovups_maybe_tail(vmean, mean_ptr());
+ uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
+ uni_vaddps(vsqrtvar, vsqrtvar, veps);
+ uni_vsqrtps(vsqrtvar, vsqrtvar);
+ uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
+ if (bdesc_->use_scaleshift())
+ uni_vmovups_maybe_tail(vgamma, gamma_ptr());
+ uni_vmovups_maybe_tail(vdiff_gamma, diff_gamma_ptr());
+ uni_vmovups_maybe_tail(vdiff_beta, diff_beta_ptr());
+ uni_vmulps(vdiff_gamma, vdiff_gamma, vsqrtvar);
+ uni_vdivps(vdiff_beta, vdiff_beta, vchan_size);
+ uni_vdivps(vdiff_gamma, vdiff_gamma, vchan_size);
+
+ auto compute = [=](bool output_is_aligned) {
+ spat_loop(spat_size, unroll_blocks, unroll_regs,
+ [=](size_t base_reg) {UNUSED(base_reg);},
+ [=](size_t base_reg, size_t i) {
+ Vmm v(base_reg * 2 + 0);
+ Vmm t(base_reg * 2 + 1);
+ Vmm t1(base_reg * 2 + 2);
+ size_t offt = i * vlen;
+ uni_vmovups(v, vmmword[reg_diff_dst + reg_soff
+ + offt]);
+ if (with_relu) {
+ if (isa == avx512_common)
+ bwd_process_relu_avx512_common(v, offt);
+ else if (isa == avx2)
+ bwd_process_relu_avx2(v, offt, t);
+ else
+ assert(false);
+ }
+ if (!bdesc_->use_global_stats()) {
+ uni_vsubps(v, v, vdiff_beta);
+ uni_vmovups(t, vmmword[reg_src + reg_soff
+ + offt]);
+ uni_vsubps(t, vmean, t, t1);
+ uni_vmulps(t, t, vdiff_gamma);
+ uni_vaddps(v, v, t);
+ }
+ uni_vmulps(v, v, vsqrtvar);
+ if (bdesc_->use_scaleshift()) {
+ uni_vmulps(v, v, vgamma);
+ }
+ if (output_is_aligned) {
+ uni_vmovntps(
+ vmmword[reg_diff_src + reg_soff + offt],
+ v);
+ } else {
+ uni_vmovups(
+ vmmword[reg_diff_src + reg_soff + offt],
+ v);
+ }
+ mic_prefetcht0(ptr[reg_diff_dst + reg_soff + offt
+ + t0_pf_offt]);
+ mic_prefetcht0(ptr[reg_src + reg_soff + offt
+ + t0_pf_offt]);
+ mic_prefetcht1(ptr[reg_diff_dst + reg_soff
+ + offt + t1_pf_offt]);
+ mic_prefetcht1(ptr[reg_src + reg_soff + offt
+ + t1_pf_offt]);
+ },
+ [=](size_t base_reg) {UNUSED(base_reg);});
+ };
+
+ Label unaligned_store, end_store;
+ test(reg_diff_src, vlen - 1);
+ jnz(unaligned_store, T_NEAR);
+ compute(true);
+ jmp(end_store, T_NEAR);
+ L(unaligned_store); {
+ compute(false);
+ }
+ L(end_store);
+
+ add(reg_coff, vlen);
+ cmp(reg_coff, reg_coff_max);
+ jl(diff_channels);
+ }
+ }
+
+ void backward() {
+ uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
+ xor_(reg_coff, reg_coff);
+ Label zero_rbuf, sh_spatial;
+
+ L(zero_rbuf); {
+ uni_vmovups(vmmword[reg_rbuf1 + reg_coff], Vmm(0));
+ uni_vmovups(vmmword[reg_rbuf2 + reg_coff], Vmm(0));
+ add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
+ cmp(reg_coff, reg_coff_max);
+ jne(zero_rbuf);
+ }
+
+ mov(reg_src, ptr[rsp + stack_off_src]);
+ mov(reg_diff_dst, ptr[rsp + stack_off_diff_dst]);
+ if (with_relu) {
+ assert(isa == avx2 || isa == avx512_common);
+ mov(reg_ws, ptr[rsp + stack_off_ws]);
+ }
+
+ xor_(reg_soff, reg_soff);
+ L(sh_spatial); {
+ xor_(reg_coff, reg_coff);
+ if (isa == sse42) {
+ mov(reg_tmp_off, reg_soff);
+ }
+ backward_sh_channels();
+ if (isa == sse42) {
+ mov(reg_soff, reg_tmp_off);
+ add(reg_diff_dst, vlen / 2);
+ add(reg_src, vlen / 2);
+ mov(reg_coff, vlen / 2);
+ backward_sh_channels();
+ sub(reg_diff_dst, vlen / 2);
+ sub(reg_src, vlen / 2);
+ }
+ add(reg_soff, reg_mb_stride_Bc);
+ cmp(reg_soff, reg_soff_max);
+ jne(sh_spatial);
+ }
+
+ mov(reg_diff_scale_shift, ptr[rsp + stack_off_diff_scale_shift]);
+
+ Label no_sh_reduction;
+ barrier(); {
+ mov(reg_tmp, ptr[rsp + stack_off_N_ithr]);
+ cmp(reg_tmp, 0);
+ Label sh_reduction_channels;
+ jne(no_sh_reduction, T_NEAR);
+
+ mov(reg_nnthr, ptr[rsp + stack_off_N_nthr]);
+ xor_(reg_coff, reg_coff);
+ L(sh_reduction_channels); {
+ mov(reg_roff, reg_coff);
+ uni_vpxor(Vmm(0), Vmm(0), Vmm(0));
+ uni_vpxor(Vmm(1), Vmm(1), Vmm(1));
+ uni_vmovups_maybe_tail(vsqrtvar, var_ptr());
+ uni_vaddps(vsqrtvar, vsqrtvar, veps);
+ uni_vsqrtps(vsqrtvar, vsqrtvar);
+ uni_vdivps(vsqrtvar, vone, vsqrtvar, vbuf);
+ mov(reg_ctr, reg_nnthr);
+ Label sh_reduction_thrs;
+ L(sh_reduction_thrs); { // TODO: unroll (?)
+ uni_vaddps(Vmm(0), Vmm(0), vmmword[reg_rbuf1 + reg_roff]);
+ uni_vaddps(Vmm(1), Vmm(1), vmmword[reg_rbuf2 + reg_roff]);
+ add(reg_roff, reg_coff_max);
+ sub(reg_ctr, 1);
+ jnz(sh_reduction_thrs);
+ }
+ uni_vmulps(Vmm(0), Vmm(0), vsqrtvar);
+ uni_vmovups_maybe_tail(diff_gamma_ptr(), Vmm(0));
+ uni_vmovups_maybe_tail(diff_beta_ptr(), Vmm(1));
+ add(reg_coff, isa == sse42 ? vlen / 2 : vlen);
+ cmp(reg_coff, reg_coff_max);
+ jne(sh_reduction_channels);
+ }
+ }
+ L(no_sh_reduction);
+ barrier();
+
+ mov(reg_diff_src, ptr[rsp + stack_off_diff_src]);
+ if (with_relu) {
+ assert(isa == avx2 || isa == avx512_common);
+ mov(reg_ws, ptr[rsp + stack_off_ws]);
+ }
+
+ xor_(reg_soff, reg_soff);
+ Label diff_spatial;
+ L(diff_spatial); {
+ xor_(reg_coff, reg_coff);
+ if (isa == sse42) {
+ mov(reg_tmp_off, reg_soff);
+ }
+ backward_diff_channels();
+ if (isa == sse42) {
+ mov(reg_soff, reg_tmp_off);
+ add(reg_diff_dst, vlen / 2);
+ add(reg_diff_src, vlen / 2);
+ add(reg_src, vlen / 2);
+ mov(reg_coff, vlen / 2);
+ backward_diff_channels();
+ sub(reg_diff_dst, vlen / 2);
+ sub(reg_diff_src, vlen / 2);
+ sub(reg_src, vlen / 2);
+ }
+ add(reg_soff, reg_mb_stride_Bc);
+ cmp(reg_soff, reg_soff_max);
+ jne(diff_spatial);
+ }
+ }
+
+ jit_bnorm_t(const batch_normalization_pd_t *bdesc): bdesc_(bdesc) {
+ static_assert(isa == sse42 || isa == avx2 || isa == avx512_common
+ || isa == avx512_mic, "unsupported isa");
+
+ const int simd_w = isa == sse42 ? 8 :
+ cpu_isa_traits<isa>::vlen / sizeof(data_t);
+ is_spatial_thr_ =
+ bnorm_utils::is_spatial_thr(bdesc_, simd_w, sizeof(data_t));
+
+ unroll_blocks = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
+ unroll_regs = isa == avx512_common && !is_spatial_thr_ ? 4 : 1;
+
+ preamble();
+
+ if (isa == avx512_common)
+ prepare_tail_mask_avx512_common();
+ else if (isa == avx2)
+ prepare_tail_mask_avx2_common();
+
+ compute_static_strides();
+ sub(rsp, stack_size_required);
+ load_common_params();
+ prepare_relu();
+
+ if (bdesc_->is_fwd()) {
+ if (!bdesc_->stats_is_src()) {
+ compute_mean_variance();
+ }
+ forward();
+ } else {
+ backward();
+ }
+ add(rsp, stack_size_required);
+ postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+ }
+};
+
+template <cpu_isa_t isa>
+struct uni_bnorm_driver_t: public c_compatible {
+ uni_bnorm_driver_t(const batch_normalization_pd_t *bdesc)
+ : bdesc_(bdesc), ker_(bdesc_)
+ {
+ const int nthrs = mkldnn_get_max_threads();
+ const dim_t C_PADDED = get_c_padded(bdesc_);
+
+ size_t data_size = sizeof(data_t) * bdesc_->MB() * C_PADDED
+ * bdesc_->D() * bdesc_->H() * bdesc_->W();
+ l3_size_ = get_cache_size(3, true) * nthrs / 2;
+ do_blocking_ = (data_size >= l3_size_ / 2 && l3_size_ > 0);
+ }
+
+ ~uni_bnorm_driver_t() {}
+
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const batch_normalization_pd_t *bdesc) {
+ int nthrs = mkldnn_get_max_threads();
+ dim_t C_PADDED = get_c_padded(bdesc);
+
+ int sbuf_sz = use_tmp_stats(bdesc) * 2 * C_PADDED;
+ int pbuf_sz = use_tmp_diff_scale_shift(bdesc) * 2 * C_PADDED;
+ int rbuf_sz = (bdesc->is_fwd() ? 1 : 2) * C_PADDED * nthrs;
+
+ scratchpad.book(key_bnorm_tmp_stats, sizeof(data_t) * sbuf_sz);
+ scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * pbuf_sz);
+ scratchpad.book(key_bnorm_reduction, sizeof(data_t) * rbuf_sz);
+
+ if (mkldnn_thr_syncable()) {
+ int n_barriers = C_PADDED / simd_w;
+ scratchpad.book(key_barrier, sizeof(barrier::ctx_t) * n_barriers);
+ }
+ }
+
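+    /* Runs the jit kernel for this thread: channels may be processed in
+     * cache-sized chunks (do_blocking_); within each chunk the work is split
+     * across threads over N, channel blocks and the spatial dimension, and
+     * the resulting slice is described in call_params_t. */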
+ void exec(int ithr, int nthr, const data_t *src, data_t *diff_src,
+ data_t *dst, const data_t *diff_dst, const data_t *scale_shift,
+ data_t *diff_scale_shift, const data_t *mean, const data_t *var,
+ const uint8_t *ws, const memory_tracking::grantor_t &scratchpad) {
+ auto sbuf = scratchpad.get<data_t>(key_bnorm_tmp_stats);
+ auto pbuf = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
+ auto rbuf = scratchpad.get<data_t>(key_bnorm_reduction);
+ auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
+
+ dim_t N = bdesc_->MB();
+ dim_t C = bdesc_->C();
+ dim_t C_PADDED = get_c_padded(bdesc_);
+ dim_t D = bdesc_->D();
+ dim_t H = bdesc_->H();
+ dim_t W = bdesc_->W();
+ dim_t SP = D * H * W;
+ dim_t img_size = C_PADDED * D * H * W;
+ const int vlen = isa == sse42 ? 32 : cpu_isa_traits<isa>::vlen;
+
+ typename jit_bnorm_t<isa>::call_params_t p;
+
+ p.eps = bdesc_->desc()->batch_norm_epsilon;
+ p.one = 1.0f;
+ p.spat_size = D * H * W;
+ p.chan_size = 1.0f * N * p.spat_size;
+
+ dim_t C_blks = C_PADDED / simd_w;
+
+ int C_ithr{0}, C_nthr{0}, N_ithr{0}, N_nthr{0}, S_ithr{0}, S_nthr{0};
+ dim_t C_blk_s{0}, C_blk_e{0}, N_s{0}, N_e{0}, S_s{0}, S_e{0};
+
+ dim_t C_blks_per_iter{ 1 };
+ int64_t iters{ 1 };
+ if (do_blocking_) {
+ int num_tensors = bdesc_->is_fwd() ? 1 : 2;
+ size_t working_set_size
+ = (N * D * H * W * simd_w * sizeof(data_t)) * num_tensors;
+ bnorm_utils::cache_balance(working_set_size, C_blks,
+ C_blks_per_iter, iters);
+ }
+
+ bool spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
+ true, ithr, nthr, N, do_blocking_ ? C_blks_per_iter : C_blks,
+ SP, C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s, N_e,
+ S_ithr, S_nthr, S_s, S_e);
+
+ int SP_N_ithr = N_ithr * S_nthr + S_ithr;
+ int SP_N_nthr = N_nthr * S_nthr;
+ assert(IMPLICATION(!mkldnn_thr_syncable(), SP_N_nthr == 1));
+
+ p.N_ithr = SP_N_ithr;
+ p.N_nthr = SP_N_nthr;
+
+ int last_iter_blks = C_blks - (iters - 1) * C_blks_per_iter;
+ int global_C_blk_s;
+ int global_barriers_per_iter = C_nthr;
+
+ for (int64_t it = 0; it < iters; it++) {
+ if (it == iters - 1 && iters > 1) {
+ C_blk_s = C_blk_e = N_s = N_e = 0;
+ spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking_,
+ spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
+ C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
+ N_e, S_ithr, S_nthr, S_s, S_e);
+
+ // Update call parameters for JIT, last iteration
+ p.N_ithr = N_ithr * S_nthr + S_ithr;
+ p.N_nthr = N_nthr * S_nthr;
+ }
+
+ global_C_blk_s = do_blocking_ ?
+ (C_blk_s == -1) ? -1 : it * C_blks_per_iter + C_blk_s :
+ C_blk_s;
+
+ int C_blks_thr = C_blk_e - C_blk_s;
+ int N_thr = N_e - N_s;
+
+ size_t coff_base = global_C_blk_s * simd_w;
+ size_t soff_base
+ = global_C_blk_s * p.spat_size * simd_w + N_s * img_size;
+
+ p.spat_size_loc = S_e - S_s;
+ p.S_s = S_s * vlen;
+ p.S_tail = (p.spat_size - S_e) * vlen;
+ p.coff_max = C_blks_thr * simd_w;
+ p.mean = (use_tmp_stats(bdesc_) ? sbuf : mean) + coff_base;
+ p.var = (use_tmp_stats(bdesc_) ? sbuf + C_PADDED : var) + coff_base;
+ p.scale_shift = scale_shift + coff_base;
+ p.diff_scale_shift = (use_tmp_diff_scale_shift(bdesc_)
+ ? pbuf : diff_scale_shift) + coff_base;
+
+ p.soff_max = N_thr * img_size;
+ p.src = src + soff_base;
+ p.dst = dst + soff_base;
+ p.diff_src = diff_src + soff_base;
+ p.diff_dst = diff_dst + soff_base;
+ p.ws = ws + soff_base / 8;
+
+ p.mb_stride_Bc = img_size - p.coff_max * p.spat_size;
+
+ // use SP_N_nthr which is the same as p.N_nthr except maybe for
+ // the last iteration.
+ p.rbuf1 = rbuf + ((it * C_blks_per_iter) * SP_N_nthr
+ + C_blk_s * p.N_nthr + p.N_ithr * C_blks_thr) * simd_w;
+ // rbuf1 and rbuf2 have to be disjoint
+ p.rbuf2 = p.rbuf1 + C_PADDED * nthr;
+ p.is_cblk_tail = (it * C_blks_per_iter + C_blk_e) * simd_w > C;
+
+            size_t iter_barriers
+                    = do_blocking_ ? it * global_barriers_per_iter : 0;
+            p.barrier = barriers + C_ithr + iter_barriers;
+ if (p.soff_max != 0 && p.coff_max != 0)
+ ker_(&p);
+ }
+ }
+
+ void init_barriers(const memory_tracking::grantor_t &scratchpad) {
+ auto barriers = scratchpad.get<barrier::ctx_t>(key_barrier);
+ if (barriers) {
+ const int n_barriers = get_c_padded(bdesc_) / simd_w;
+ for (int i = 0; i < n_barriers; ++i)
+ barrier::ctx_init(&barriers[i]);
+ }
+ }
+
+private:
+ enum {
+ simd_w = isa == sse42 ? 8 : cpu_isa_traits<isa>::vlen / sizeof(data_t)
+ };
+
+ static bool use_tmp_stats(const batch_normalization_pd_t *bdesc) {
+ return true
+ && !bdesc->stats_is_src()
+ && bdesc->desc()->prop_kind == prop_kind::forward_inference;
+ }
+
+ static bool use_tmp_diff_scale_shift(const batch_normalization_pd_t *bdesc)
+ {
+ return false
+ || (bdesc->is_bwd() && !bdesc->use_scaleshift())
+ || bdesc->desc()->prop_kind == prop_kind::backward_data;
+ }
+
+ static dim_t get_c_padded(const batch_normalization_pd_t *bdesc)
+ { return bdesc->src_md()->padded_dims[1]; }
+
+ const batch_normalization_pd_t *bdesc_;
+ bool do_blocking_;
+ size_t l3_size_;
+
+ jit_bnorm_t<isa> ker_;
+};
+
+}
+
+using namespace data_type;
+using namespace format_tag;
+using namespace utils;
+
+/* fwd */
+
+template <cpu_isa_t isa>
+status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init() {
+ auto desired_fmt_tag = (ndims() == 4)
+ ? isa == avx512_common ? nChw16c : nChw8c
+ : isa == avx512_common ? nCdhw16c : nCdhw8c;
+
+ bool ok = true
+ && mayiuse(isa)
+ && is_fwd()
+ && !has_zero_dim_memory()
+ && one_of(ndims(), 4, 5)
+ && src_md()->data_type == f32
+ && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32)
+ && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
+ && (attr()->has_default_values() || this->with_relu_post_op());
+ if (!ok) return status::unimplemented;
+
+ if (is_training() && fuse_bn_relu()) {
+ if (isa < avx2) return status::unimplemented;
+ init_default_ws(1);
+ }
+
+ if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
+ && isa < avx2)
+ return status::unimplemented;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);
+
+ return status::success;
+}
+
+template <cpu_isa_t isa>
+jit_uni_batch_normalization_fwd_t<isa>::jit_uni_batch_normalization_fwd_t(
+ const pd_t *apd): cpu_primitive_t(apd)
+{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
+
+template <cpu_isa_t isa>
+status_t jit_uni_batch_normalization_fwd_t<isa>::execute(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+
+ auto mean = pd()->stats_is_src()
+ ? const_cast<data_t *>(CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN))
+ : CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN);
+ auto var = pd()->stats_is_src()
+ ? const_cast<data_t *>(CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE))
+ : CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE);
+
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ bnorm_driver_->init_barriers(scratchpad);
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ bnorm_driver_->exec(ithr, nthr, src, nullptr, dst, nullptr,
+ scale_shift, nullptr, mean, var, ws, scratchpad);
+ });
+
+ return status::success;
+}
+
+template <cpu_isa_t isa>
+jit_uni_batch_normalization_fwd_t<isa>::~jit_uni_batch_normalization_fwd_t()
+{ delete bnorm_driver_; }
+
+/* bwd */
+
+template <cpu_isa_t isa>
+status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init() {
+ auto desired_fmt_tag = (ndims() == 4)
+ ? one_of(isa, sse42, avx2) ? nChw8c : nChw16c
+ : one_of(isa, sse42, avx2) ? nCdhw8c : nCdhw16c;
+
+ bool ok = true
+ && mayiuse(isa)
+ && is_bwd()
+ && !has_zero_dim_memory()
+ && one_of(ndims(), 4, 5)
+ && everyone_is(f32, src_md()->data_type, diff_src_md()->data_type)
+ && IMPLICATION(use_scaleshift(),
+ utils::everyone_is(f32,
+ weights_md()->data_type,
+ diff_weights_md()->data_type))
+ && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
+ && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag)
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ if (memory_desc_wrapper(src_md()).padded_dims()[1] != C()
+ && isa < avx2)
+ return status::unimplemented;
+
+ if (fuse_bn_relu()) {
+ if (isa < avx2) return status::unimplemented;
+ init_default_ws(1);
+ if (!compare_ws(hint_fwd_pd_))
+ return status::unimplemented;
+ }
+
+ /* TODO: extra checks required */
+
+ auto scratchpad = scratchpad_registry().registrar();
+ uni_bnorm_driver_t<isa>::init_scratchpad(scratchpad, this);
+
+ return status::success;
+}
+
+template <cpu_isa_t isa>
+jit_uni_batch_normalization_bwd_t<isa>::jit_uni_batch_normalization_bwd_t(
+ const pd_t *apd): cpu_primitive_t(apd)
+{ bnorm_driver_ = new uni_bnorm_driver_t<isa>(pd()); }
+
+template <cpu_isa_t isa>
+status_t jit_uni_batch_normalization_bwd_t<isa>::execute(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
+ auto var = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto scale_shift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+ auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+ auto diff_scale_shift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ bnorm_driver_->init_barriers(scratchpad);
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ bnorm_driver_->exec(ithr, nthr, src, diff_src, nullptr, diff_dst,
+ scale_shift, diff_scale_shift, mean, var, ws, scratchpad);
+ });
+
+ return status::success;
+}
+
+template <cpu_isa_t isa>
+jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t()
+{ delete bnorm_driver_; }
+
+/* struct instantiation */
+template struct jit_uni_batch_normalization_fwd_t<sse42>;
+template struct jit_uni_batch_normalization_bwd_t<sse42>;
+template struct jit_uni_batch_normalization_fwd_t<avx2>;
+template struct jit_uni_batch_normalization_bwd_t<avx2>;
+template struct jit_uni_batch_normalization_fwd_t<avx512_common>;
+template struct jit_uni_batch_normalization_bwd_t<avx512_common>;
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp
new file mode 100644
index 0000000000..96410ec84e
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_batch_normalization.hpp
@@ -0,0 +1,100 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_UNI_BATCH_NORMALIZATION_HPP
+#define JIT_UNI_BATCH_NORMALIZATION_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_batch_normalization_pd.hpp"
+#include "cpu_isa_traits.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace { template <cpu_isa_t isa> struct uni_bnorm_driver_t; }
+
+template <cpu_isa_t isa>
+struct jit_uni_batch_normalization_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_batch_normalization_fwd_pd_t {
+ pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const batch_normalization_fwd_pd_t *hint_fwd_pd)
+ : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+ jit_uni_batch_normalization_fwd_t<isa>);
+
+ status_t init();
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ jit_uni_batch_normalization_fwd_t(const pd_t *apd);
+ ~jit_uni_batch_normalization_fwd_t();
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override;
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ uni_bnorm_driver_t<isa> *bnorm_driver_;
+};
+
+template <cpu_isa_t isa>
+struct jit_uni_batch_normalization_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_batch_normalization_bwd_pd_t {
+ pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const batch_normalization_fwd_pd_t *hint_fwd_pd)
+ : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+ jit_uni_batch_normalization_bwd_t<isa>);
+
+ status_t init();
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ jit_uni_batch_normalization_bwd_t(const pd_t *apd);
+ ~jit_uni_batch_normalization_bwd_t();
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override;
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ uni_bnorm_driver_t<isa> *bnorm_driver_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp
new file mode 100644
index 0000000000..b7dc5f85c5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.cpp
@@ -0,0 +1,1302 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "cpu_memory.hpp"
+
+#include "jit_uni_dw_conv_kernel_f32.hpp"
+
+#define GET_OFF(field) offsetof(jit_conv_call_s, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+using namespace Xbyak;
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::load_src(int ur_ch_blocks, int ur_w) {
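+    // on sse42 each 8-float channel block is processed as two 4-float halves,
+    // hence the extra "repeats" dimension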
+ int repeats = isa == sse42 ? 2 : 1;
+ for (int i = 0; i < repeats; i++) {
+ for (int ch = 0; ch < ur_ch_blocks; ch++) {
+ for (int ow = 0; ow < ur_w; ow++) {
+ Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
+
+ int b_off = ch*jcp.ch_block + i*4;
+ if (this->jcp.with_bias)
+ uni_vmovups(vmm_acc,
+ vmmword[reg_bias + b_off*sizeof(float)]);
+ else
+ uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
+
+ int o_off = ch*jcp.oh*jcp.ow*jcp.ch_block
+ + ow*jcp.ch_block + i*4;
+ if (this->jcp.with_sum)
+ uni_vaddps(vmm_acc, vmm_acc,
+ vmmword[reg_output + o_off*sizeof(float)]);
+ }
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter(
+ int ur_ch_blocks, int ur_w) {
+ int ch_blk = jcp.ch_block;
+ int dilate_h = jcp.dilate_h + 1;
+ int dilate_w = jcp.dilate_w + 1;
+ int stride_w = jcp.stride_w;
+
+ Label iter_exit_label;
+
+ cmp(reg_kh, 0);
+ je(iter_exit_label, T_NEAR);
+ cmp(reg_kw, 0);
+ je(iter_exit_label, T_NEAR);
+
+ mov(iter_kh, reg_kh);
+ Label kh_label;
+ L(kh_label); {
+ mov(iter_kw, reg_kw);
+ mov(aux1_reg_input, aux_reg_input);
+ mov(aux1_reg_kernel, aux_reg_kernel);
+
+ Label kw_label;
+ L(kw_label); {
+ int repeats = isa == sse42 ? 2 : 1;
+ for (int i = 0; i < repeats; i++) {
+ for (int ch = 0; ch < ur_ch_blocks; ch++) {
+ int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*4;
+ Vmm vmm_ker = get_ker_reg(0);
+ uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
+ + ker_off*sizeof(float)]);
+
+ for (int ow = 0; ow < ur_w; ow++) {
+ int inp_off = ch*jcp.ih*jcp.iw*ch_blk
+ + ow*stride_w*ch_blk + i*4;
+ Vmm vmm_src = get_src_reg(0);
+ uni_vmovups(vmm_src, ptr[aux1_reg_input
+ + inp_off*sizeof(float)]);
+
+ Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
+ + ch*ur_w + ow);
+ uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+ }
+ }
+ }
+ add(aux1_reg_kernel, ch_blk*sizeof(float));
+ add(aux1_reg_input, ch_blk*dilate_w*sizeof(float));
+
+ dec(iter_kw);
+ cmp(iter_kw, 0);
+ jg(kw_label, T_NEAR);
+ }
+ add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
+ add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
+
+ dec(iter_kh);
+ cmp(iter_kh, 0);
+ jg(kh_label, T_NEAR);
+ }
+
+ L(iter_exit_label);
+}
+
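+/* Same as apply_filter(), but with the kernel-width loop unrolled at
+ * code-generation time; only the kernel-height counter (reg_kh) remains a
+ * runtime loop. */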
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_filter_unrolled(
+ int ur_ch_blocks, int ur_w) {
+ int ch_blk = jcp.ch_block;
+ int dilate_h = jcp.dilate_h + 1;
+ int dilate_w = jcp.dilate_w + 1;
+ int stride_w = jcp.stride_w;
+
+ Label iter_exit_label;
+
+ cmp(reg_kh, 0);
+ je(iter_exit_label, T_NEAR);
+
+ mov(iter_kh, reg_kh);
+ Label kh_label;
+ L(kh_label); {
+ int repeats = isa == sse42 ? 2 : 1;
+ for (int i = 0; i < repeats; i++) {
+ for (int ch = 0; ch < ur_ch_blocks; ch++) {
+ for (int kw = 0; kw < jcp.kw; kw++) {
+ int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*4;
+
+ Vmm vmm_ker = get_ker_reg(0);
+ uni_vmovups(vmm_ker, ptr[aux_reg_kernel
+ + ker_off*sizeof(float)]);
+
+ for (int ow = 0; ow < ur_w; ow++) {
+ int inp_off = ch*jcp.ih*jcp.iw*ch_blk
+ + ow*stride_w*ch_blk + kw*ch_blk*dilate_w + i*4;
+
+ Vmm vmm_src = get_src_reg(0);
+ uni_vmovups(vmm_src, ptr[aux_reg_input
+ + inp_off*sizeof(float)]);
+
+ Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w
+ + ch*ur_w + ow);
+ uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+ }
+ }
+ }
+ }
+
+ add(aux_reg_kernel, jcp.kw*ch_blk*sizeof(float));
+ add(aux_reg_input, jcp.iw*ch_blk*dilate_h*sizeof(float));
+
+ dec(iter_kh);
+ cmp(iter_kh, 0);
+ jg(kh_label, T_NEAR);
+ }
+
+ L(iter_exit_label);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::apply_activation(
+ int ur_ch_blocks, int ur_w) {
+ if (this->jcp.with_eltwise) {
+ int repeats = isa == sse42 ? 2 : 1;
+ eltwise_injector_->compute_vector_range(4, repeats * ur_w * ur_ch_blocks + 4);
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::store_dst(
+ int ur_ch_blocks, int ur_w) {
+ int ch_blk = jcp.ch_block;
+
+ int repeats = isa == sse42 ? 2 : 1;
+ for (int i = 0; i < repeats; i++) {
+ for (int ch = 0; ch < ur_ch_blocks; ch++) {
+ for (int ow = 0; ow < ur_w; ow++) {
+ int o_off = ch*jcp.oh*jcp.ow*ch_blk + ow*ch_blk + i*4;
+ Vmm vmm_dst = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
+
+ uni_vmovups(vmmword[reg_output + o_off*sizeof(float)], vmm_dst);
+ }
+ }
+ }
+}
+
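+/* Output-width loop: processes jcp.ur_w output columns per iteration while
+ * enough columns remain, then falls back to a one-column tail loop. */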
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::loop_body(int ur_ch_blocks) {
+ Label unrolled_w_label;
+ Label tail_w_label;
+ Label exit_label;
+
+ L(unrolled_w_label); {
+ int ur_w = jcp.ur_w;
+
+ cmp(reg_ur_w, ur_w);
+ jl(tail_w_label, T_NEAR);
+
+ mov(aux_reg_input, reg_input);
+ mov(aux_reg_kernel, reg_kernel);
+
+ load_src(ur_ch_blocks, ur_w);
+ apply_filter_unrolled(ur_ch_blocks, ur_w);
+ apply_activation(ur_ch_blocks, ur_w);
+ store_dst(ur_ch_blocks, ur_w);
+
+ add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
+
+ sub(reg_ur_w, ur_w);
+ jmp(unrolled_w_label);
+ }
+
+ L(tail_w_label); {
+ int ur_w = 1;
+
+ cmp(reg_ur_w, ur_w);
+ jl(exit_label, T_NEAR);
+
+ mov(aux_reg_input, reg_input);
+ mov(aux_reg_kernel, reg_kernel);
+
+ load_src(ur_ch_blocks, ur_w);
+ apply_filter(ur_ch_blocks, ur_w);
+ apply_activation(ur_ch_blocks, ur_w);
+ store_dst(ur_ch_blocks, ur_w);
+
+ add(reg_input, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_output, sizeof(float) * ur_w * jcp.ch_block);
+
+ sub(reg_ur_w, ur_w);
+ jmp(tail_w_label);
+ }
+
+ L(exit_label);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::generate() {
+ this->preamble();
+
+ mov(reg_input, ptr[this->param1 + GET_OFF(src)]);
+ mov(reg_output, ptr[this->param1 + GET_OFF(dst)]);
+ mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
+ if (jcp.with_bias)
+ mov(reg_bias, ptr[this->param1 + GET_OFF(bias)]);
+ mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
+ mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
+ mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
+ mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);
+
+ Label ch_blocks_tail_label;
+ Label exit_label;
+
+ int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
+
+ cmp(reg_ch_blocks, jcp.nb_ch_blocking);
+ jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
+
+ loop_body(jcp.nb_ch_blocking); // channel main loop
+
+ if (ch_blocks_tail) {
+ L(ch_blocks_tail_label);
+
+ cmp(reg_ch_blocks, ch_blocks_tail);
+ jne(exit_label, T_NEAR);
+
+ loop_body(ch_blocks_tail); // channel tail loop
+ }
+
+ L(exit_label);
+
+ this->postamble();
+
+ if (jcp.with_eltwise)
+ eltwise_injector_->prepare_table();
+}
+
+template <cpu_isa_t isa>
+bool jit_uni_dw_conv_fwd_kernel_f32<isa>::post_ops_ok(
+ jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
+ const auto &p = attr.post_ops_;
+
+ auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+
+ switch (p.len_) {
+ case 0: return true; // no post_ops
+ case 1: return is_eltwise(0) || is_sum(0); // sum OR eltwise
+ case 2: return is_sum(0) && is_eltwise(1); // sum -> eltwise
+ default: return false;
+ }
+
+ return false;
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
+ const primitive_attr_t &attr)
+{
+ if (!mayiuse(isa)) return status::unimplemented;
+
+ const int simd_w = isa == avx512_common ? 16 : 8;
+
+ jcp.prop_kind = cd.prop_kind;
+
+ const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
+ if (!with_groups) return status::unimplemented;
+
+ jcp.ngroups = weights_d.dims()[0];
+ jcp.mb = src_d.dims()[0];
+
+ jcp.oc = dst_d.dims()[1];
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = src_d.dims()[1];
+
+ jcp.ih = src_d.dims()[2];
+ jcp.iw = src_d.dims()[3];
+ jcp.oh = dst_d.dims()[2];
+ jcp.ow = dst_d.dims()[3];
+
+ jcp.kh = weights_d.dims()[3];
+ jcp.kw = weights_d.dims()[4];
+
+ jcp.t_pad = cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][1];
+ jcp.b_pad = cd.padding[1][0];
+ jcp.r_pad = cd.padding[1][1];
+
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+
+ jcp.dilate_h = cd.dilates[0];
+ jcp.dilate_w = cd.dilates[1];
+
+ jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
+
+ if (!post_ops_ok(jcp, attr))
+ return status::unimplemented;
+
+ const auto &p = attr.post_ops_;
+ jcp.with_sum = p.find(primitive_kind::sum) != -1;
+ const int eltwise_ind = p.find(primitive_kind::eltwise);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise)
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+
+ bool ok_to_pad_channels = true
+ && jcp.oc == jcp.ngroups
+ && jcp.ic == jcp.ngroups
+ && one_of(isa, avx512_common, avx2);
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+        jcp.ic = rnd_up(jcp.ic, simd_w);
+ jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
+ }
+
+ auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
+ auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
+
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+ jcp.dst_tag = dst_d.matches_one_of_tag(dat_tag);
+
+ bool args_ok = true
+ && jcp.oc == jcp.ngroups
+ && jcp.ic == jcp.ngroups
+ && jcp.ngroups % simd_w == 0
+ && jcp.src_tag == dat_tag
+ && jcp.wei_tag == wei_tag
+ && jcp.dst_tag == dat_tag
+ && jcp.ic <= src_d.padded_dims()[1]
+ && jcp.oc <= dst_d.padded_dims()[1]
+ && jcp.ngroups <= weights_d.padded_dims()[0];
+ if (!args_ok) return status::unimplemented;
+
+ jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
+
+ jcp.ch_block = simd_w;
+ jcp.nb_ch = jcp.oc / jcp.ch_block;
+ jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
+ if (jcp.nb_ch < jcp.nb_ch_blocking)
+ jcp.nb_ch_blocking = jcp.nb_ch;
+
+ return status::success;
+}
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ if (jcp.with_bias && jcp.oc_without_padding != jcp.oc)
+ scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+}
+
+template struct jit_uni_dw_conv_fwd_kernel_f32<avx512_common>;
+template struct jit_uni_dw_conv_fwd_kernel_f32<avx2>;
+template struct jit_uni_dw_conv_fwd_kernel_f32<sse42>;
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::load_ddst(
+ int ur_ch_blocks, int ur_str_w) {
+ int repeats = isa == sse42 ? 2 : 1;
+ for (int i = 0; i < repeats; i++) {
+ for (int ch = 0; ch < ur_ch_blocks; ch++) {
+ for (int w = 0; w < ur_str_w; w++) {
+ Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
+ + ch*ur_str_w + w);
+ uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
+ }
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::apply_filter(
+ int ur_ch_blocks, int ur_str_w) {
+ int kw = jcp.kw;
+ int kh = jcp.kh;
+ int ow = jcp.ow;
+ int oh = jcp.oh;
+
+ int ch_blk = jcp.ch_block;
+ int stride_h = jcp.stride_h;
+ int stride_w = jcp.stride_w;
+
+ Label iter_exit_label;
+
+ cmp(reg_kh, 0);
+ je(iter_exit_label, T_NEAR);
+
+ cmp(reg_kw, 0);
+ je(iter_exit_label, T_NEAR);
+
+ mov(iter_kh, reg_kh);
+ Label kh_label;
+ L(kh_label); {
+ mov(aux1_reg_ddst, aux_reg_ddst);
+ mov(aux1_reg_kernel, aux_reg_kernel);
+
+ mov(iter_kw, reg_kw);
+ Label kw_label;
+ L(kw_label); {
+ int repeats = isa == sse42 ? 2 : 1;
+ for (int i = 0; i < repeats; i++) {
+ for (int ch = 0; ch < ur_ch_blocks; ch++) {
+ int ker_off = ch*kh*kw*ch_blk + i*4;
+ Vmm vmm_ker = get_ker_reg(0);
+ uni_vmovups(vmm_ker, ptr[aux1_reg_kernel
+ + ker_off*sizeof(float)]);
+
+ for (int w = 0; w < ur_str_w; w++) {
+ int ddst_off = (ch*oh*ow + w)*ch_blk + i*4;
+
+ Vmm vmm_src = get_src_reg(0);
+ uni_vmovups(vmm_src, ptr[aux1_reg_ddst
+ + ddst_off*sizeof(float)]);
+
+ Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
+ + ch*ur_str_w + w);
+ uni_vfmadd231ps(vmm_acc, vmm_src, vmm_ker);
+ }
+ }
+ }
+
+ add(aux1_reg_kernel, ch_blk*stride_w*sizeof(float));
+ sub(aux1_reg_ddst, ch_blk*sizeof(float));
+
+ sub(iter_kw, stride_w);
+ cmp(iter_kw, 0);
+ jg(kw_label, T_NEAR);
+ }
+
+ add(aux_reg_kernel, kw*ch_blk*stride_h*sizeof(float));
+ sub(aux_reg_ddst, ow*ch_blk*sizeof(float));
+
+ sub(iter_kh, stride_h);
+ cmp(iter_kh, 0);
+ jg(kh_label, T_NEAR);
+ }
+
+ L(iter_exit_label);
+}
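+
+/* Note on apply_filter() above: the kw/kh loops advance the kernel pointer by
+ * ch_blk * stride_w (resp. kw * ch_blk * stride_h) floats per step while
+ * stepping the diff_dst pointer backwards, so only the taps that reach the
+ * current diff_src column under the given stride contribute. The repeat
+ * factor of 2 on sse42 processes the 8-channel block as two 4-lane halves. */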
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::store_dsrc(
+ int ur_ch_blocks, int ur_str_w) {
+ int ch_blk = jcp.ch_block;
+ int iw = jcp.iw;
+ int ih = jcp.ih;
+ int stride_w = jcp.stride_w;
+
+ int repeats = isa == sse42 ? 2 : 1;
+ for (int i = 0; i < repeats; i++) {
+ for (int ch = 0; ch < ur_ch_blocks; ch++) {
+ for (int w = 0; w < ur_str_w; w++) {
+ int dsrc_off = (ch*ih*iw + w*stride_w)*ch_blk + i*4;
+ Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_str_w
+ + ch*ur_str_w + w);
+
+ uni_vmovups(ptr[reg_dsrc + dsrc_off*sizeof(float)], vmm_acc);
+ }
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::loop_body(
+ int ur_ch_blocks) {
+ Label unrolled_w_label;
+ Label tail_w_label;
+ Label exit_label;
+
+ L(unrolled_w_label); {
+ int ur_w = jcp.ur_w;
+
+ cmp(reg_ur_str_w, ur_w);
+ jl(tail_w_label, T_NEAR);
+
+ mov(aux_reg_ddst, reg_ddst);
+ mov(aux_reg_kernel, reg_kernel);
+
+ load_ddst(ur_ch_blocks, ur_w);
+ apply_filter(ur_ch_blocks, ur_w);
+ store_dsrc(ur_ch_blocks, ur_w);
+
+ add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
+
+ sub(reg_ur_str_w, ur_w);
+ jmp(unrolled_w_label);
+ }
+
+ L(tail_w_label); {
+ int ur_w = 1;
+
+ cmp(reg_ur_str_w, ur_w);
+ jl(exit_label, T_NEAR);
+
+ mov(aux_reg_ddst, reg_ddst);
+ mov(aux_reg_kernel, reg_kernel);
+
+ load_ddst(ur_ch_blocks, ur_w);
+ apply_filter(ur_ch_blocks, ur_w);
+ store_dsrc(ur_ch_blocks, ur_w);
+
+ add(reg_dsrc, sizeof(float) * ur_w * jcp.ch_block * jcp.stride_w);
+ add(reg_ddst, sizeof(float) * ur_w * jcp.ch_block);
+
+ sub(reg_ur_str_w, ur_w);
+ jmp(tail_w_label);
+ }
+
+ L(exit_label);
+}
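+
+/* Note: loop_body() emits two width loops: a main loop consuming jcp.ur_w
+ * output columns per iteration and a tail loop that falls back to a single
+ * column once fewer than ur_w columns remain. */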
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::generate() {
+ preamble();
+
+ mov(reg_dsrc, ptr[this->param1 + GET_OFF(src)]);
+ mov(reg_ddst, ptr[this->param1 + GET_OFF(dst)]);
+ mov(reg_kernel, ptr[this->param1 + GET_OFF(filt)]);
+ mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
+ mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
+ mov(reg_ch_blocks, ptr[this->param1 + GET_OFF(ch_blocks)]);
+ mov(reg_ur_str_w, ptr[this->param1 + GET_OFF(ur_str_w)]);
+
+ Label ch_blocks_tail_label;
+ Label exit_label;
+
+ int ch_blocks_tail = jcp.nb_ch % jcp.nb_ch_blocking;
+
+ cmp(reg_ch_blocks, jcp.nb_ch_blocking);
+ jne(ch_blocks_tail ? ch_blocks_tail_label : exit_label, T_NEAR);
+
+ loop_body(jcp.nb_ch_blocking); // channel main loop
+
+ if (ch_blocks_tail) {
+ L(ch_blocks_tail_label);
+
+ cmp(reg_ch_blocks, ch_blocks_tail);
+ jne(exit_label, T_NEAR);
+
+ loop_body(ch_blocks_tail); // channel tail loop
+ }
+
+ L(exit_label);
+
+ this->postamble();
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_conf(
+ jit_conv_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d) {
+ if (!mayiuse(isa)) return status::unimplemented;
+
+ const int simd_w = isa == avx512_common ? 16 : 8;
+
+ const bool with_groups = weights_d.ndims() == diff_src_d.ndims() + 1;
+ if (!with_groups) return status::unimplemented;
+
+ jcp.ngroups = weights_d.dims()[0];
+ jcp.mb = diff_src_d.dims()[0];
+
+ jcp.oc = diff_dst_d.dims()[1];
+ jcp.oc_without_padding = jcp.oc;
+ jcp.ic = diff_src_d.dims()[1];
+
+ jcp.ih = diff_src_d.dims()[2];
+ jcp.iw = diff_src_d.dims()[3];
+ jcp.oh = diff_dst_d.dims()[2];
+ jcp.ow = diff_dst_d.dims()[3];
+
+ jcp.kh = weights_d.dims()[3];
+ jcp.kw = weights_d.dims()[4];
+
+ jcp.t_pad = cd.padding[0][0];
+ jcp.l_pad = cd.padding[0][1];
+ jcp.b_pad = cd.padding[1][0];
+ jcp.r_pad = cd.padding[1][1];
+
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+
+ jcp.dilate_h = cd.dilates[0];
+ jcp.dilate_w = cd.dilates[1];
+
+ jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
+ jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
+
+ bool ok_to_pad_channels = true
+ && jcp.oc == jcp.ngroups
+ && jcp.ic == jcp.ngroups
+ && one_of(isa, avx512_common, avx2);
+ if (ok_to_pad_channels) {
+ jcp.oc = rnd_up(jcp.oc, simd_w);
+        jcp.ic = rnd_up(jcp.ic, simd_w);
+ jcp.ngroups = rnd_up(jcp.ngroups, simd_w);
+ }
+
+ auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
+ auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
+
+ jcp.src_tag = diff_src_d.matches_one_of_tag(dat_tag);
+ jcp.wei_tag = weights_d.matches_one_of_tag(wei_tag);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
+
+ bool args_ok = true
+ && jcp.oc == jcp.ngroups
+ && jcp.ic == jcp.ngroups
+ && jcp.ngroups % simd_w == 0
+ && jcp.dilate_h == 0
+ && jcp.dilate_w == 0
+ && jcp.src_tag == dat_tag
+ && jcp.wei_tag == wei_tag
+ && jcp.dst_tag == dat_tag
+ && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
+ && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1
+ && jcp.ic <= diff_src_d.padded_dims()[1]
+ && jcp.oc <= diff_dst_d.padded_dims()[1]
+ && jcp.ngroups <= weights_d.padded_dims()[0];
+ if (!args_ok) return status::unimplemented;
+
+ jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
+
+ jcp.ch_block = simd_w;
+ jcp.nb_ch = jcp.ic / jcp.ch_block;
+ jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
+ if (jcp.nb_ch < jcp.nb_ch_blocking)
+ jcp.nb_ch_blocking = jcp.nb_ch;
+
+ return status::success;
+}
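+
+/* Note: this kernel only supports shapes where the forward convolution
+ * exactly tiles the padded input, i.e.
+ *     oh = (ih + t_pad + b_pad - kh) / stride_h + 1   (and likewise for ow),
+ * and it rejects dilation. E.g. (illustrative numbers) ih = 112, kh = 3,
+ * t_pad = b_pad = 1 and stride_h = 1 give oh = (114 - 3) / 1 + 1 = 112. */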
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+ UNUSED(scratchpad);
+ UNUSED(jcp);
+}
+
+template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx512_common>;
+template struct jit_uni_dw_conv_bwd_data_kernel_f32<avx2>;
+template struct jit_uni_dw_conv_bwd_data_kernel_f32<sse42>;
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_filter() {
+ for (int r = 0; r < reg_repeats; ++r) {
+ for (int i = 0; i < jcp.kw; ++i) {
+ Vmm vmm_acc = get_acc_reg(r * jcp.kw + i);
+ uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_filter() {
+ for (int r = 0; r < reg_repeats; ++r) {
+ const int reg_set = r * jcp.kw;
+ for (int i = 0; i < jcp.kw; ++i) {
+ int off_filter = (reg_set + i) * simd_w;
+ Vmm vmm_acc = get_acc_reg(reg_set + i);
+ uni_vmovups(vmm_acc,
+ vmmword[reg_tmp_filter + off_filter * sizeof(float)]);
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::zero_bias() {
+ for (int r = 0; r < reg_repeats; ++r) {
+ Vmm vmm_bias = get_bias_reg(r);
+ uni_vpxor(vmm_bias, vmm_bias, vmm_bias);
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::load_bias() {
+ for (int r = 0; r < reg_repeats; ++r) {
+ Vmm vmm_bias = get_bias_reg(r);
+ uni_vmovups(
+ vmm_bias, vmmword[reg_bias_baddr + r * simd_w * sizeof(float)]);
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_step_unroll(
+ int unroll_w, int l_pad, int pad_offset, int ow_block) {
+
+ const int iw_block = ow_block * jcp.stride_w;
+ const int right_border = jcp.iw - iw_block;
+
+ const int cascade_input = nstl::min(jcp.stride_w, jcp.kw);
+
+    /* preamble count for the number of cascaded LOAD + FMA operations */
+ const int input_overlap = nstl::max(jcp.kw - l_pad, 0);
+
+    /* LOAD initial input registers, then cascade LOADs and FMAs */
+ for (int r = 0; r < reg_repeats; ++r) {
+ for (int i_ur = 0; i_ur < unroll_w; ++i_ur) {
+ int off_output = (i_ur * reg_repeats + r) * simd_w;
+ Vmm vmm_output = get_output_reg(r);
+ uni_vmovups(vmm_output,
+ ptr[reg_tmp_output + off_output * sizeof(float)]);
+ if (i_ur == 0) {
+ for (int c = 0; c < input_overlap; ++c) {
+ int off_input
+ = ((c - pad_offset) * reg_repeats + r) * simd_w;
+ Vmm vmm_input
+ = get_input_reg((c % jcp.kw) * reg_repeats + r);
+ uni_vmovups(vmm_input,
+ ptr[reg_tmp_input + off_input * sizeof(float)]);
+ }
+ } else {
+ for (int c = 0; c < cascade_input; ++c) {
+ int overlap = (i_ur - 1) * jcp.stride_w + input_overlap;
+ int off_input
+ = ((overlap + c - pad_offset) * reg_repeats + r)
+ * simd_w;
+ Vmm vmm_input = get_input_reg(
+ ((overlap + c) % jcp.kw) * reg_repeats + r);
+ uni_vmovups(vmm_input,
+ ptr[reg_tmp_input + off_input * sizeof(float)]);
+ }
+ }
+
+ for (int i_kw = 0; i_kw < jcp.kw; ++i_kw) {
+ int io_overlap = i_kw + (i_ur * jcp.stride_w);
+
+ /* Don't apply FMAs that fall into the padded region */
+ if (io_overlap - l_pad < 0
+ || io_overlap - jcp.l_pad >= right_border)
+ continue;
+
+ Vmm vmm_input = get_input_reg(
+ ((io_overlap - l_pad) % jcp.kw) * reg_repeats + r);
+ Vmm vmm_acc = get_acc_reg(i_kw * reg_repeats + r);
+ Vmm vmm_aux = isa == sse42 ? get_aux_reg() : vmm_input;
+ if (isa == sse42)
+ uni_vmovups(vmm_aux, vmm_input);
+ uni_vfmadd231ps(vmm_acc, vmm_aux, vmm_output);
+ }
+ }
+ }
+}
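+
+/* Note on the cascade above: for the first unrolled output column,
+ * kw - l_pad input registers are loaded up front; every subsequent column
+ * loads only min(stride_w, kw) new inputs and reuses the rest through the
+ * modulo-kw indexing of the register window, so each input value is loaded
+ * once even though it feeds up to kw FMAs. */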
+
+template <cpu_isa_t isa>
+inline void
+jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_step_unroll(
+ const int unroll_w) {
+ for (int r = 0; r < reg_repeats; ++r) {
+ for (int i = 0; i < unroll_w; ++i) {
+ Vmm vmm_bias = get_bias_reg(r);
+ int off_output = (i * reg_repeats + r) * simd_w;
+ if (isa == sse42) {
+                /* Need to support unaligned address loads for SSE42 */
+ Vmm vmm_output = get_output_reg(1 + r);
+ uni_vmovups(vmm_output,
+ ptr[reg_tmp_output + off_output * sizeof(float)]);
+ uni_vaddps(vmm_bias, vmm_bias, vmm_output);
+ } else {
+ uni_vaddps(vmm_bias, vmm_bias,
+ vmmword[reg_tmp_output + off_output * sizeof(float)]);
+ }
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_filter() {
+ for (int r = 0; r < reg_repeats; ++r) {
+ const int reg_set = r * jcp.kw;
+ for (int i = 0; i < jcp.kw; ++i) {
+ int off_filter = (i + reg_set) * simd_w;
+ Vmm vmm_acc = get_acc_reg(i + reg_set);
+ uni_vmovups(vmmword[reg_tmp_filter + off_filter * sizeof(float)],
+ vmm_acc);
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::store_bias() {
+ for (int r = 0; r < reg_repeats; ++r) {
+ Vmm vmm_bias = get_bias_reg(r);
+ uni_vmovups(
+ vmmword[reg_bias_baddr + r * simd_w * sizeof(float)], vmm_bias);
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_bias_loop(
+ const int block_size) {
+ Label oh_label;
+ Label ow_blk_label;
+
+ const int unroll_w = nstl::min(block_size, jcp.ow);
+ const int unroll_w_trips = jcp.ow / unroll_w;
+ const int tail_w = jcp.ow > block_size ? jcp.ow % block_size : 0;
+
+ const int ch_offset = jcp.ch_block;
+
+ mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
+ mov(reg_oh_worksize,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
+
+ mov(reg_tmp_output, reg_output_baddr);
+ L(oh_label);
+ {
+
+ mov(iter_ow_blk, unroll_w_trips);
+ L(ow_blk_label);
+ {
+
+ compute_bias_step_unroll(unroll_w);
+ add(reg_tmp_output, unroll_w * ch_offset * sizeof(float));
+
+ dec(iter_ow_blk);
+ cmp(iter_ow_blk, 0);
+ jg(ow_blk_label, T_NEAR);
+ }
+
+ if (tail_w > 0) {
+ compute_bias_step_unroll(tail_w);
+ add(reg_tmp_output, tail_w * ch_offset * sizeof(float));
+ }
+
+ inc(reg_oh);
+ cmp(reg_oh, reg_oh_worksize);
+ jl(oh_label, T_NEAR);
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_zero_filter() {
+
+ const int ch_offset = jcp.ch_block;
+
+ Label kh_loop_label, skip_zeroing_label;
+
+ mov(reg_exec_flags,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
+ and_(reg_exec_flags, FLAG_ZERO_FILTER);
+ test(reg_exec_flags, reg_exec_flags);
+ je(skip_zeroing_label);
+
+ zero_filter();
+
+ mov(reg_tmp_filter, reg_filter_baddr);
+ mov(reg_kh, jcp.kh);
+ L(kh_loop_label);
+ {
+ store_filter();
+
+ add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
+ dec(reg_kh);
+ cmp(reg_kh, 0);
+ jg(kh_loop_label);
+ }
+
+    /* Restore the pointers */
+ sub(reg_tmp_filter, jcp.kh * jcp.kw * ch_offset * sizeof(float));
+
+ L(skip_zeroing_label);
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_step(
+ int unroll_w, int l_pad, int pad_offset, int ow_block) {
+
+ const int ch_offset = jcp.ch_block;
+
+ Label kh_loop_label, skip_loop_label;
+
+ cmp(reg_kh_count, 0);
+ je(skip_loop_label, T_NEAR);
+
+ mov(reg_kh, reg_kh_count);
+ L(kh_loop_label);
+ {
+ load_filter();
+ compute_ow_step_unroll(unroll_w, l_pad, pad_offset, ow_block);
+ store_filter();
+
+ add(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
+ add(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
+ dec(reg_kh);
+ cmp(reg_kh, 0);
+ jg(kh_loop_label);
+ }
+
+    /* Restore the pointers */
+ Label kh_comeback_label;
+ mov(reg_kh, reg_kh_count);
+ L(kh_comeback_label);
+ {
+ sub(reg_tmp_input, jcp.iw * ch_offset * sizeof(float));
+ sub(reg_tmp_filter, jcp.kw * ch_offset * sizeof(float));
+ dec(reg_kh);
+ cmp(reg_kh, 0);
+ jg(kh_comeback_label, T_NEAR);
+ }
+
+ L(skip_loop_label);
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_h_loop(
+ int unroll_w, int l_pad, int pad_offset, int ow_block) {
+
+ const size_t io_overlap = jcp.ih / jcp.stride_h < jcp.oh ?
+ jcp.ih / jcp.stride_h - 1 :
+ jcp.oh - jcp.b_pad - 1;
+ const int ch_offset = jcp.ch_block;
+ const int t_overlap_off = jcp.t_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
+ const int b_overlap_off = jcp.b_pad % jcp.stride_h == 0 ? jcp.stride_h : 1;
+
+ Label tpad_loop_label, h_loop_label, skip_tpad_label, skip_bpad_label,
+ end_h_loop_label;
+
+ mov(reg_oh, ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_index)]);
+ mov(reg_oh_worksize,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, oh_count)]);
+ mov(reg_kh_count,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, kh_count)]);
+
+ mov(reg_tmp_output, reg_output_baddr);
+ mov(reg_tmp_input, reg_input_baddr);
+ mov(reg_tmp_filter, reg_filter_baddr);
+
+ L(h_loop_label);
+ {
+
+ compute_h_step(unroll_w, l_pad, pad_offset, ow_block);
+
+ add(reg_tmp_output, jcp.ow * ch_offset * sizeof(float));
+
+ /* If within the top_pad region */
+ if (jcp.t_pad > 0) {
+ /* Skip t_pad area if no longer in initial h_block */
+ cmp(reg_oh, jcp.t_pad);
+ jg(skip_tpad_label, T_NEAR);
+
+ cmp(reg_kh_count, jcp.kh);
+ jge(skip_tpad_label, T_NEAR);
+
+ add(reg_kh_count, t_overlap_off);
+ sub(reg_tmp_filter,
+ t_overlap_off * jcp.kw * ch_offset * sizeof(float));
+
+ /* kernel has moved beyond padding (adjust for stride effects) */
+ if (jcp.t_pad % jcp.stride_h != 0) {
+ int inp_corr = jcp.stride_h - jcp.t_pad % jcp.stride_h;
+ add(reg_tmp_input,
+ inp_corr * jcp.iw * ch_offset * sizeof(float));
+ }
+ jmp(tpad_loop_label, T_NEAR);
+ }
+
+ L(skip_tpad_label);
+
+ cmp(reg_oh, io_overlap);
+ jl(skip_bpad_label, T_NEAR);
+ sub(reg_kh_count, b_overlap_off);
+
+ L(skip_bpad_label);
+ add(reg_tmp_input, jcp.stride_h * jcp.iw * ch_offset * sizeof(float));
+
+ L(tpad_loop_label);
+
+ cmp(reg_oh, jcp.ih / jcp.stride_h);
+ jge(end_h_loop_label, T_NEAR);
+
+ inc(reg_oh);
+
+ cmp(reg_oh, reg_oh_worksize);
+ jl(h_loop_label, T_NEAR);
+ }
+ L(end_h_loop_label);
+}
+
+template <cpu_isa_t isa>
+inline void
+jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::compute_ow_block_unroll() {
+
+ const int ch_offset = jcp.ch_block;
+ int ow = jcp.ow;
+ int pad_offset = 0;
+ int l_pad = jcp.l_pad;
+
+ /* Calculate effective padding */
+ int r_pad = nstl::max(0, (ow - 1) * jcp.stride_w
+ + (jcp.kw - 1) * (jcp.dilate_w + 1)
+ - (jcp.iw + jcp.l_pad - 1));
+
+    /* The unroll limits below are empirical; it is not clear whether they are
+     * strictly constrained by code size or by addressing range. */
+ const int max_unroll_w = 30;
+ const int block_size = 15;
+
+ int unroll_w_tail = 0;
+ int unroll_w = 0;
+ int unroll_w_trips = 0;
+
+ if (jcp.ow > max_unroll_w) {
+ unroll_w = nstl::min(block_size, jcp.ow);
+ unroll_w_trips = ow / unroll_w;
+ /* calculate tail */
+ unroll_w_tail = ow % unroll_w;
+        /* Perform some rebalancing if the tail is too small */
+ if ((unroll_w_tail == 0 && r_pad != 0)
+ || (r_pad > 0 && r_pad >= unroll_w_tail)) {
+ if (unroll_w_trips > 1) {
+ unroll_w_tail += unroll_w;
+ unroll_w_trips--;
+ } else {
+                /* Ideally, this case shouldn't happen */
+ unroll_w_tail += (unroll_w - unroll_w / 2);
+ unroll_w = unroll_w / 2;
+ }
+ }
+ } else {
+ unroll_w = jcp.ow;
+ unroll_w_trips = nstl::max(1, ow / unroll_w);
+ }
+ if (jcp.with_bias) {
+ Label skip_load_bias;
+ mov(reg_bias_baddr,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, bias)]);
+
+ zero_bias();
+
+ mov(reg_exec_flags,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, exec_flags)]);
+ and_(reg_exec_flags, FLAG_ZERO_BIAS);
+ test(reg_exec_flags, reg_exec_flags);
+ jne(skip_load_bias);
+
+ load_bias();
+
+ L(skip_load_bias);
+ compute_bias_loop(block_size);
+
+ store_bias();
+ }
+
+ /* Pass filter address, then offset for h_padding. */
+ compute_zero_filter();
+ mov(reg_kh_offset,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter_pad_off)]);
+ add(reg_filter_baddr, reg_kh_offset);
+
+ /* compute left padded block */
+ if (l_pad) {
+ compute_h_loop(unroll_w, l_pad, 0, 0);
+ add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
+ add(reg_input_baddr,
+ unroll_w * jcp.stride_w * ch_offset * sizeof(float));
+ unroll_w_trips--;
+ pad_offset = l_pad;
+ l_pad = 0;
+ }
+
+ /* compute middle block */
+ Label ow_blk_label;
+
+ /* Insert loop for 'ow' block when middle block needs to execute more
+ * than once */
+ bool do_ow_blk_loop = unroll_w_trips > 1;
+ if (do_ow_blk_loop) {
+ mov(iter_ow_blk, unroll_w_trips);
+ L(ow_blk_label);
+ }
+ if (unroll_w_trips > 0) {
+ compute_h_loop(unroll_w, l_pad, pad_offset, 0);
+ add(reg_output_baddr, unroll_w * ch_offset * sizeof(float));
+ add(reg_input_baddr,
+ unroll_w * jcp.stride_w * ch_offset * sizeof(float));
+ }
+ if (do_ow_blk_loop) {
+ dec(iter_ow_blk);
+ cmp(iter_ow_blk, 0);
+ jg(ow_blk_label, T_NEAR);
+ }
+
+ /* compute right padded block */
+ if (unroll_w_tail) {
+ compute_h_loop(unroll_w_tail, 0, pad_offset, jcp.ow - unroll_w_tail);
+ }
+}
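+
+/* Note: the output width is covered as an (optional) left-padded block,
+ * middle block(s) of unroll_w columns, and an (optional) right/tail block,
+ * with unroll_w capped at 15 columns. The rebalancing above keeps the block
+ * that sees the right padding non-empty: e.g. (illustrative) ow = 60 with
+ * r_pad = 1 initially gives unroll_w = 15, 4 trips and no tail, and is then
+ * rebalanced to 3 trips plus a 15-column tail so that the padded columns are
+ * handled by the tail call. */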
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::generate() {
+ preamble();
+
+ mov(reg_input_baddr,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, input)]);
+ mov(reg_output_baddr,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, output)]);
+ mov(reg_filter_baddr,
+ ptr[this->param1 + offsetof(jit_dw_conv_call_s, filter)]);
+
+ compute_ow_block_unroll();
+
+ this->postamble();
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_conf(
+ jit_conv_conf_t &jcp, const convolution_desc_t &cd,
+ const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &diff_weights_d,
+ const memory_desc_wrapper &diff_dst_d, int nthreads) {
+ if (!mayiuse(isa))
+ return status::unimplemented;
+
+ jcp.ngroups = diff_weights_d.dims()[0];
+ jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
+ jcp.ic = src_d.dims()[1] / jcp.ngroups;
+
+ const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
+
+ jcp.is_depthwise = true && with_groups && everyone_is(1, jcp.oc, jcp.ic);
+
+ if (!jcp.is_depthwise)
+ return status::unimplemented;
+
+ jcp.ch_block = isa == avx512_common ? 16 : 8;
+
+ jcp.mb = src_d.dims()[0];
+
+ jcp.ih = src_d.dims()[2];
+ jcp.iw = src_d.dims()[3];
+ jcp.oh = diff_dst_d.dims()[2];
+ jcp.ow = diff_dst_d.dims()[3];
+
+ jcp.kh = diff_weights_d.dims()[3];
+ jcp.kw = diff_weights_d.dims()[4];
+
+ jcp.stride_h = cd.strides[0];
+ jcp.stride_w = cd.strides[1];
+
+ jcp.t_pad = cd.padding[0][0];
+ jcp.b_pad = cd.padding[1][0];
+
+ jcp.l_pad = cd.padding[0][1];
+ jcp.r_pad = cd.padding[1][1];
+
+ jcp.dilate_h = cd.dilates[0];
+ jcp.dilate_w = cd.dilates[1];
+
+ jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
+ jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
+
+ jcp.with_bias = cd.diff_bias_desc.format_kind != format_kind::undef;
+
+ auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
+ auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
+
+ jcp.src_tag = src_d.matches_one_of_tag(dat_tag);
+ jcp.wei_tag = diff_weights_d.matches_one_of_tag(wei_tag);
+ jcp.dst_tag = diff_dst_d.matches_one_of_tag(dat_tag);
+
+ bool args_ok = true
+ && jcp.src_tag == dat_tag
+ && jcp.wei_tag == wei_tag
+ && jcp.dst_tag == dat_tag
+ && jcp.ngroups % jcp.ch_block == 0 && jcp.dilate_h == 0
+ && jcp.dilate_w == 0 && jcp.kw <= 3
+ && jcp.oh == (jcp.ihp - jcp.kh) / jcp.stride_h + 1
+ && jcp.ow == (jcp.iwp - jcp.kw) / jcp.stride_w + 1;
+ if (!args_ok)
+ return status::unimplemented;
+
+ jcp.nb_ch = jcp.ngroups / jcp.ch_block;
+
+    /* Kernel applicability check w.r.t. boundaries: the conditions are quite
+     * general across the kernels we have, but ideally the check should belong
+     * to a specific kernel... */
+ const int max_hpad = (jcp.kh - 1 + 1) / 2;
+ const int max_wpad = (jcp.kw - 1 + 1) / 2;
+ const bool boundaries_ok = true && jcp.t_pad <= max_hpad
+ && jcp.b_pad <= max_hpad && jcp.l_pad <= max_wpad
+ && jcp.r_pad <= max_wpad;
+ if (!boundaries_ok)
+ return status::unimplemented;
+
+ balance(jcp, nthreads);
+
+ return status::success;
+}
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad(
+ memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp) {
+    /* Note: if the thread work is split over 'mb', a reduction has to take
+     * place. Hence, book a per-thread, local weights buffer for the
+     * reduction. */
+ if (jcp.nthr_mb > 1) {
+ const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
+ scratchpad.book(key_conv_wei_reduction,
+ sizeof(float) * wei_size * (jcp.nthr_mb - 1));
+
+ if (jcp.with_bias)
+ scratchpad.book(key_conv_bia_reduction,
+ sizeof(float) * jcp.ngroups * (jcp.nthr_mb - 1));
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::balance(jit_conv_conf_t &jcp,
+ int nthreads) {
+ jcp.nthr = nthreads;
+ jcp.nthr_g = jcp.nthr_mb = 1;
+
+    /* Basic heuristics for the parallel strategy:
+     * 1) Try to parallelize over the number of groups (g), where the tasks
+     *    are independent. Otherwise,
+     * 2) Try to split the work across g and the minibatch (mb).
+     * Parallelizing over mb requires computing a reduction for the weights.
+     *
+     * NOTE: because of the 'task partitioning' scheme, the per-thread load
+     *       will be unbalanced when the number of threads is high (e.g. > 16).
+     */
+ jcp.nthr_g = nstl::min(jcp.nb_ch, jcp.nthr);
+ jcp.nthr_mb = nstl::min(nstl::max(1, jcp.nthr / jcp.nthr_g), jcp.mb);
+
+ jcp.nthr = jcp.nthr_g * jcp.nthr_mb;
+}
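+
+/* Illustrative example for balance() (numbers are not from the source): with
+ * 8 threads, nb_ch = 4 and mb = 32 it picks nthr_g = 4 and nthr_mb = 2, so
+ * all 8 threads are used at the cost of a weights reduction over the two
+ * 'mb' slices; with mb = 1 only nthr_g = 4 threads would do useful work. */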
+
+template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx512_common>;
+template struct jit_uni_dw_conv_bwd_weights_kernel_f32<avx2>;
+template struct jit_uni_dw_conv_bwd_weights_kernel_f32<sse42>;
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp
new file mode 100644
index 0000000000..9c08fc4a09
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_conv_kernel_f32.hpp
@@ -0,0 +1,253 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_UNI_DW_CONV_KERNEL_F32_HPP
+#define JIT_UNI_DW_CONV_KERNEL_F32_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+#include "jit_uni_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <cpu_isa_t isa>
+struct jit_uni_dw_conv_fwd_kernel_f32: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_fwd_kernel_f32)
+
+ jit_uni_dw_conv_fwd_kernel_f32(jit_conv_conf_t ajcp)
+ : jcp(ajcp), eltwise_injector_(nullptr)
+ {
+ if (jcp.with_eltwise)
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
+ jcp.eltwise);
+
+ this->generate();
+ jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
+ }
+
+ ~jit_uni_dw_conv_fwd_kernel_f32() {
+ delete eltwise_injector_;
+ }
+
+ static bool post_ops_ok(jit_conv_conf_t &jcp,
+ const primitive_attr_t &attr);
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &dst_d, const primitive_attr_t &attr);
+
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ jit_conv_conf_t jcp;
+ void (*jit_ker)(jit_conv_call_s *);
+
+private:
+ using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
+ isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+ using reg64_t = const Xbyak::Reg64;
+ const Xbyak::AddressFrame &vmmword = (isa == sse42)
+ ? xword : (isa == avx2) ? yword : zword;
+ const int vlen = cpu_isa_traits<isa>::vlen;
+
+ // dw convolution
+ reg64_t reg_input = r8;
+ reg64_t aux_reg_input = r9;
+ reg64_t aux1_reg_input = r10;
+ reg64_t reg_kernel = r11;
+ reg64_t aux_reg_kernel = r12;
+ reg64_t aux1_reg_kernel = r13;
+ reg64_t reg_output = r14;
+ reg64_t reg_bias = r15;
+ reg64_t reg_kh = rax;
+ reg64_t reg_kw = rbx;
+ reg64_t iter_kh = rdx;
+ reg64_t iter_kw = rsi;
+ reg64_t reg_ur_w = rbp;
+ reg64_t reg_ch_blocks = aux1_reg_input;
+ reg64_t imm_addr64 = aux1_reg_input;
+
+ inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
+ inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
+ inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
+
+ inline void load_src(int ur_ch_blocks, int ur_w);
+ inline void apply_filter(int ur_ch_blocks, int ur_w);
+ inline void apply_filter_unrolled(int ur_ch_blocks, int ur_w);
+ inline void apply_activation(int ur_ch_blocks, int ur_w);
+ inline void store_dst(int ur_ch_blocks, int ur_w);
+ inline void loop_body(int ur_ch_blocks);
+
+ jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
+
+ void generate();
+};
+
+template <cpu_isa_t isa>
+struct jit_uni_dw_conv_bwd_data_kernel_f32: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_data_kernel_f32)
+
+ jit_uni_dw_conv_bwd_data_kernel_f32(jit_conv_conf_t ajcp): jcp(ajcp) {
+ this->generate();
+ jit_ker = (void (*)(jit_conv_call_s *))this->getCode();
+ }
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd,
+ const memory_desc_wrapper &diff_src_d,
+ const memory_desc_wrapper &weights_d,
+ const memory_desc_wrapper &diff_dst_d);
+
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ jit_conv_conf_t jcp;
+ void (*jit_ker)(jit_conv_call_s *);
+
+private:
+ using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
+ isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+ using reg64_t = const Xbyak::Reg64;
+
+ inline Vmm get_ker_reg(int idx) { return Vmm(idx + 0); }
+ inline Vmm get_src_reg(int idx) { return Vmm(idx + 1); }
+ inline Vmm get_acc_reg(int idx) { return Vmm(idx + 4); }
+
+ reg64_t reg_ddst = rax;
+ reg64_t aux_reg_ddst = r8;
+ reg64_t aux1_reg_ddst = abi_not_param1;
+ reg64_t reg_kernel = rdx;
+ reg64_t aux_reg_kernel = r10;
+ reg64_t aux1_reg_kernel = rbp;
+ reg64_t reg_dsrc = rsi;
+
+ reg64_t reg_ur_str_w = r9;
+ reg64_t reg_ch_blocks = rbx;
+
+ reg64_t iter_kh = r11;
+ reg64_t iter_kw = r12;
+ reg64_t reg_kh = r13;
+ reg64_t reg_kw = r14;
+
+ inline void loop_body(int ur_ch_blocks);
+ inline void load_ddst(int ur_ch_blocks, int ur_str_w);
+ inline void apply_filter(int ur_ch_blocks, int ur_str_w);
+ inline void store_dsrc(int ur_ch_blocks, int ur_str_w);
+
+ void generate();
+};
+
+template <cpu_isa_t isa>
+struct jit_uni_dw_conv_bwd_weights_kernel_f32 : public jit_generator {
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_dw_conv_bwd_weights_kernel_f32)
+
+ jit_uni_dw_conv_bwd_weights_kernel_f32(jit_conv_conf_t ajcp) : jcp(ajcp) {
+ this->generate();
+ jit_ker = (void (*)(jit_dw_conv_call_s *)) this->getCode();
+ }
+
+ static status_t init_conf(jit_conv_conf_t &jcp,
+ const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
+ const memory_desc_wrapper &diff_weights_d,
+ const memory_desc_wrapper &diff_dst_d, int nthreads);
+
+ static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
+ const jit_conv_conf_t &jcp);
+
+ static void balance(jit_conv_conf_t &jcp, int nthreads);
+
+ jit_conv_conf_t jcp;
+ void (*jit_ker)(jit_dw_conv_call_s *);
+
+private:
+ using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
+ isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+ using reg64_t = const Xbyak::Reg64;
+ const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
+ const int reg_repeats = (isa == sse42) ? 2 : 1;
+
+ const Xbyak::AddressFrame &vmmword
+ = (isa == sse42) ? xword : (isa == avx2) ? yword : zword;
+
+    /* XXX: the offset between the input and accumulator registers is 3;
+     * therefore, assume 'kw' is no larger than 3. */
+ inline Vmm get_bias_reg(int idx = 0) { return Vmm(idx); }
+ inline Vmm get_output_reg(int idx) { return Vmm(idx + 1); }
+ inline Vmm get_input_reg(int idx) { return Vmm(idx + 4 * reg_repeats + 1); }
+ inline Vmm get_acc_reg(int idx) { return Vmm(idx + 1 * reg_repeats + 1); }
+ inline Vmm get_aux_reg() { return Vmm(0); }
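+    /* Illustrative register layout for reg_repeats = 1 (avx2/avx512_common):
+     * vmm0 holds the bias accumulator, vmm1 the loaded diff_dst value,
+     * vmm2..vmm4 the per-tap accumulators (kw <= 3), and vmm5..vmm7 the
+     * cascaded input window. On sse42 (reg_repeats = 2) the 8-channel block
+     * is processed as two 4-lane halves, roughly doubling each group, and
+     * vmm0 additionally serves as the aux register. */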
+
+ reg64_t reg_tmp_input = r9;
+ reg64_t reg_tmp_output = r10;
+ reg64_t reg_tmp_filter = r13;
+ reg64_t reg_kh_offset = rax;
+
+ /* parameter passed by driver into kernel */
+ Xbyak::Reg8 reg_exec_flags = bl;
+
+ reg64_t reg_oh_worksize = r14;
+ reg64_t reg_oh = rax;
+
+ reg64_t iter_ow_blk = r11;
+
+ reg64_t reg_kh = rsi;
+ reg64_t reg_kh_count = rdx;
+
+ /* Base addresses for convolution parameters. */
+ reg64_t reg_input_baddr = r15;
+ reg64_t reg_output_baddr = r12;
+ reg64_t reg_filter_baddr = abi_not_param1;
+ reg64_t reg_bias_baddr = r13;
+
+ /* Micro-kernel JIT'ing, fusing 'kw' and 'ow_block' loops into unrolled FMAs
+ */
+ inline void compute_ow_step_unroll(
+ int unroll_w, int l_pad, int pad_offset, int ow_block);
+
+ /* JIT'ing the outer loops for the micro-kernel -> {kh, oh_block} */
+ inline void compute_h_step(
+ int unroll_w, int l_pad, int pad_offset, int ow_block);
+ inline void compute_h_loop(
+ int unroll_w, int l_pad, int pad_offset, int ow_block);
+
+    /* Write 'width' micro-kernel JITs; depending on the padding and convolution
+     * size, write a micro-kernel for the left ow-block, middle ow-block(s), and
+     * right ow-block. */
+ inline void compute_ow_block_unroll();
+
+ inline void compute_zero_filter();
+ inline void load_filter();
+ inline void zero_filter();
+ inline void load_bias();
+ inline void zero_bias();
+ inline void compute_bias_step_unroll(const int unroll_w);
+ inline void compute_bias_loop(const int block_size);
+ inline void store_filter();
+ inline void store_bias();
+
+ void generate();
+};
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp
new file mode 100644
index 0000000000..58449601a3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp
@@ -0,0 +1,427 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "jit_uni_dw_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace mkldnn::impl::utils;
+
+template <cpu_isa_t isa>
+void _jit_uni_dw_convolution_fwd_t<isa>::execute_forward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const auto &jcp = kernel_->jcp;
+
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = this->scratchpad(ctx).template get<data_t>(
+ key_conv_padded_bias);
+ utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bias = padded_bias;
+ }
+
+ int dil_h = jcp.dilate_h + 1;
+ int dil_w = jcp.dilate_w + 1;
+ int str_h = jcp.stride_h;
+ int str_w = jcp.stride_w;
+
+ auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh,
+ int kh_padding, int ch, int ch_num, int n) {
+ auto par_conv = jit_conv_call_s();
+
+ const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
+ const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
+ + (jcp.kw - 1)*dil_w - jcp.l_pad + 1)) - jcp.iw;
+
+ const int iw = nstl::max((ow*str_w - jcp.l_pad
+ + div_up(i_l_overflow, dil_w)*dil_w), 0);
+ const int kw = div_up(i_l_overflow, dil_w);
+
+ const int kw_padding = jcp.kw - div_up(i_l_overflow, dil_w)
+ - div_up(i_r_overflow, dil_w);
+
+ par_conv.src = &src[src_d.blk_off(n, ch, ih, iw)];
+ par_conv.dst = &dst[dst_d.blk_off(n, ch, oh, ow)];
+
+ par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, kh, kw)];
+ if (bias) par_conv.bias = &bias[bias_d.blk_off(ch*jcp.ch_block)];
+
+ par_conv.kh_padding = (size_t)nstl::max(0, kh_padding);
+ par_conv.kw_padding = (size_t)nstl::max(0, kw_padding);
+
+ par_conv.ur_w = (size_t)ur_w_step;
+
+ par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch;
+
+ return par_conv;
+ };
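+
+    /* Note: kernel_params() clips the filter against the image border. E.g.
+     * (illustrative) with l_pad = 1, kw = 3, stride_w = 1 and no dilation
+     * (dil_w = 1), the first output column has i_l_overflow = 1, so the call
+     * starts at input column 0 and filter column 1 with kw_padding = 2 taps
+     * applied. */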
+
+ const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
+ parallel_nd(jcp.mb, chb_work, jcp.oh,
+ [&](int n, int chb, int oh) {
+ int ch = chb * jcp.nb_ch_blocking;
+ int ch_num = jcp.nb_ch_blocking;
+
+ const int i_t_overflow = nstl::max(0, (int)(jcp.t_pad - oh*str_h));
+ const int i_b_overflow = nstl::max(jcp.ih,
+ (int)(oh*str_h + (jcp.kh - 1)*dil_h - jcp.t_pad + 1)) - jcp.ih;
+
+ const int ih = nstl::max((int)(oh*str_h - jcp.t_pad
+ + div_up(i_t_overflow, dil_h)*dil_h), 0);
+ const int kh = div_up(i_t_overflow, dil_h);
+ const int kh_padding = jcp.kh - div_up(i_t_overflow, dil_h)
+ - div_up(i_b_overflow, dil_h);
+
+ // left border
+ int ow = 0;
+ int l_border = nstl::min(div_up(jcp.l_pad, str_w), jcp.ow);
+ int ur_w_step = 1;
+ for (; ow < l_border; ow++) {
+ jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
+ kh, kh_padding, ch, ch_num, n);
+
+ kernel_->jit_ker(&par_conv);
+ }
+
+ // main loop
+ ur_w_step = (jcp.iw - (jcp.kw - 1)*dil_w + jcp.l_pad - 1)
+ / jcp.stride_w - ow + 1;
+ if (ur_w_step > 0) {
+ jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
+ kh, kh_padding, ch, ch_num, n);
+
+ kernel_->jit_ker(&par_conv);
+
+ ow += ur_w_step;
+ }
+
+ // right border
+ ur_w_step = 1;
+ for (; ow < jcp.ow; ow++) {
+ jit_conv_call_s par_conv = kernel_params(ur_w_step, ow, oh, ih,
+ kh, kh_padding, ch, ch_num, n);
+
+ kernel_->jit_ker(&par_conv);
+ }
+ });
+
+ if (pd()->wants_zero_pad_dst())
+ ctx.memory(MKLDNN_ARG_DST)->zero_pad();
+}
+
+template struct _jit_uni_dw_convolution_fwd_t<avx512_common>;
+template struct _jit_uni_dw_convolution_fwd_t<avx2>;
+template struct _jit_uni_dw_convolution_fwd_t<sse42>;
+
+template <cpu_isa_t isa>
+void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+
+ const auto &jcp = kernel_->jcp;
+
+ auto kernel_params = [&](int ur_str_w, int iw, int oh, int ih,
+ int i_t_overflow, int i_b_overflow, int stride_off_h,
+ int ch, int ch_num, int n) {
+ auto par_conv = jit_conv_call_s();
+
+ const int i_l_overflow = nstl::max(0, (jcp.kw - 1 - iw - jcp.l_pad));
+ const int i_r_overflow = nstl::max(0, (jcp.kw - 1 - (jcp.iw - 1 - iw)
+ - jcp.r_pad));
+
+ int ow = iw + jcp.l_pad - i_r_overflow;
+ int stride_off_w = ow % jcp.stride_w;
+ ow /= jcp.stride_w;
+
+ par_conv.src = &diff_src[diff_src_d.blk_off(n, ch, ih, iw)];
+ par_conv.dst = &diff_dst[diff_dst_d.blk_off(n, ch, oh, ow)];
+ par_conv.filt = &weights[weights_d.blk_off(ch, 0, 0, i_b_overflow
+ + stride_off_h, i_r_overflow + stride_off_w)];
+
+ par_conv.kh_padding = nstl::max(0, jcp.kh - i_t_overflow - i_b_overflow
+ - stride_off_h);
+ par_conv.kw_padding = nstl::max(0, jcp.kw - i_l_overflow - i_r_overflow
+ - stride_off_w);
+
+ par_conv.ur_str_w = ur_str_w;
+
+ par_conv.ch_blocks = nstl::min(ch + ch_num, jcp.nb_ch) - ch;
+
+ return par_conv;
+ };
+
+ const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
+ parallel_nd(jcp.mb, chb_work, jcp.ih,
+ [&](int n, int chb, int ih) {
+ int ch = chb * jcp.nb_ch_blocking;
+ int ch_num = jcp.nb_ch_blocking;
+
+ const int i_t_overflow = nstl::max(0, (int)(jcp.kh - 1 - ih
+ - jcp.t_pad));
+ const int i_b_overflow = nstl::max(0, (int)(jcp.kh - 1
+ - (jcp.ih - 1 - ih) - jcp.b_pad));
+
+ int oh = ih + jcp.t_pad - i_b_overflow;
+ int stride_off_h = oh % jcp.stride_h;
+ oh /= jcp.stride_h;
+
+ for (int i_str_w = 0; i_str_w < jcp.stride_w; i_str_w++) {
+ // left border
+ int iw = i_str_w;
+ int l_border = nstl::min(jcp.kw - 1 - jcp.l_pad, jcp.iw);
+ int ur_str_w = 1;
+ for (; iw < l_border; iw += jcp.stride_w) {
+ jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
+ ih, i_t_overflow, i_b_overflow,
+ stride_off_h, ch, ch_num, n);
+
+ kernel_->jit_ker(&par_conv);
+ }
+
+ // main loop
+ ur_str_w = nstl::min((jcp.iw - jcp.kw + jcp.r_pad - iw)
+ / jcp.stride_w, jcp.iw);
+ if (ur_str_w > 0) {
+ jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
+ ih, i_t_overflow, i_b_overflow,
+ stride_off_h, ch, ch_num, n);
+
+ kernel_->jit_ker(&par_conv);
+
+ iw += ur_str_w * jcp.stride_w;
+ }
+
+ // right border
+ ur_str_w = 1;
+ for (; iw < jcp.iw; iw += jcp.stride_w) {
+ jit_conv_call_s par_conv = kernel_params(ur_str_w, iw, oh,
+ ih, i_t_overflow, i_b_overflow,
+ stride_off_h, ch, ch_num, n);
+
+ kernel_->jit_ker(&par_conv);
+ }
+ }
+ });
+}
+
+template struct _jit_uni_dw_convolution_bwd_data_t<avx512_common>;
+template struct _jit_uni_dw_convolution_bwd_data_t<avx2>;
+template struct _jit_uni_dw_convolution_bwd_data_t<sse42>;
+
+template <cpu_isa_t isa>
+_jit_uni_dw_convolution_bwd_weights_t<isa>::
+_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , kernel_(nullptr), acc_ker_(nullptr)
+{
+ kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32<isa>(pd()->jcp_);
+ if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction())
+ acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
+}
+
+template <cpu_isa_t isa>
+void _jit_uni_dw_convolution_bwd_weights_t<isa>::execute_backward_weights(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ auto diff_wei_reduction_buf =
+ scratchpad(ctx).template get<data_t>(key_conv_wei_reduction);
+ auto diff_bia_reduction_buf =
+ scratchpad(ctx).template get<data_t>(key_conv_bia_reduction);
+
+ const auto &jcp = kernel_->jcp;
+
+ /* Used when executing a parallel reduction */
+ simple_barrier::ctx_t reduction_bctx;
+ simple_barrier::ctx_init(&reduction_bctx);
+
+ const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
+ const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0;
+
+ const int ch_block = jcp.ch_block;
+
+ auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params,
+ const int batch, const int group, const int oh_start,
+ const int work_size, const unsigned char exec_flag,
+ const size_t kh_padding, const size_t filter_off) {
+ const int tpad_underflow_off = jcp.t_pad - filter_off;
+
+ conv_params->exec_flags = exec_flag;
+ conv_params->kh_count = jcp.kh - kh_padding;
+
+ const int oh_s = oh_start;
+ const int oh_e = oh_start + work_size;
+ const int ih_s = oh_s * jcp.stride_h;
+
+ conv_params->filter_pad_off
+ = filter_off * jcp.kw * ch_block * sizeof(float);
+ conv_params->oh_index = oh_s;
+ conv_params->oh_count = oh_e;
+
+ size_t diff_dst_off
+ = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh
+ + oh_start)
+ * jcp.ow;
+
+ size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih
+ + ih_s - tpad_underflow_off) * jcp.iw;
+
+ conv_params->output = &diff_dst[diff_dst_off * ch_block];
+ conv_params->input = &src[src_off * ch_block];
+ };
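+
+    /* Note: threads with ithr_mb == 0 accumulate straight into the user's
+     * diff_weights / diff_bias buffers; the other jcp.nthr_mb - 1 minibatch
+     * slices go to the scratchpad reduction buffers booked in
+     * init_scratchpad() and are summed into the user buffers afterwards. */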
+
+ parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+ assert(nthr == jcp.nthr);
+
+ auto conv_params = jit_dw_conv_call_s();
+ const int h_block_size = 15;
+
+ /* assign iteration space to thread */
+ const int ithr_g = ithr % jcp.nthr_g;
+ const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb;
+
+ /* split dimensions */
+ int g_start{ 0 }, g_end{ 0 };
+ balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end);
+
+ int mb_start{ 0 }, mb_end{ 0 };
+ balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end);
+
+ auto diff_wei = ithr_mb == 0
+ ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size;
+ auto diff_bia = ithr_mb == 0
+ ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size;
+
+ for (int g = g_start; g < g_end; ++g) {
+ unsigned char zero_filter_flag = FLAG_ZERO_FILTER;
+ unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0;
+
+ size_t diff_wei_off = g * jcp.kh * jcp.kw;
+ conv_params.filter = &diff_wei[diff_wei_off * ch_block];
+
+ if (jcp.with_bias)
+ conv_params.bias = &diff_bia[g * ch_block];
+
+ for (int mb = mb_start; mb < mb_end; ++mb) {
+ int oh = 0;
+ while (oh < jcp.oh) {
+ const int h_work = nstl::min(h_block_size, jcp.oh - oh);
+ auto kh_t_padding = nstl::max(0, jcp.t_pad - oh);
+ auto kh_b_padding
+ = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ?
+ jcp.b_pad - (h_work - 1) :
+ 0;
+
+ set_kernel_params(&conv_params, mb, g, oh, h_work,
+ zero_filter_flag | zero_bias_flag,
+ kh_t_padding + kh_b_padding, kh_t_padding);
+ kernel_->jit_ker(&conv_params);
+
+ zero_bias_flag &= ~FLAG_ZERO_BIAS;
+ zero_filter_flag &= ~FLAG_ZERO_FILTER;
+ oh += h_work;
+ }
+ }
+ }
+
+ if (do_parallel_reduction() && jcp.nthr_mb > 1) {
+ size_t reduct_start{ 0 }, reduct_end{ 0 };
+ balance211(wei_size, nthr, ithr, reduct_start, reduct_end);
+
+ const int acc_size = reduct_end - reduct_start;
+ const size_t reduct_off = reduct_start;
+ auto *acc_data = diff_weights + reduct_off;
+
+ simple_barrier::barrier(&reduction_bctx, nthr);
+
+ for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
+ auto *src_data = diff_wei_reduction_buf
+ + (thr_mb - 1) * wei_size + reduct_off;
+ acc_ker_->accumulate(acc_data, src_data, acc_size);
+ }
+ }
+ });
+
+ if (jcp.nthr_mb <= 1) return;
+
+ /* Apply single-threaded 'mb' reduction */
+ for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
+ size_t mb_accum_offset = (thr_mb - 1) * wei_size;
+ size_t b_accum_offset = (thr_mb - 1) * bias_size;
+
+ for (int g = 0; g < jcp.nb_ch; ++g) {
+ /* Reduction on Bias */
+ if (jcp.with_bias) {
+ PRAGMA_OMP_SIMD()
+ for (int g_block = 0; g_block < ch_block; ++g_block) {
+ size_t bias_offset = g * ch_block + g_block;
+ diff_bias[bias_offset] += diff_bia_reduction_buf[
+ b_accum_offset + bias_offset];
+ }
+ }
+
+ if (do_parallel_reduction()) continue;
+
+ for (int kh = 0; kh < jcp.kh; ++kh)
+ for (int kw = 0; kw < jcp.kw; ++kw)
+ {
+ size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw;
+ PRAGMA_OMP_SIMD()
+ for (int g_block = 0; g_block < ch_block; ++g_block) {
+ const size_t off = wei_offset * ch_block + g_block;
+ diff_weights[off] +=
+ diff_wei_reduction_buf[mb_accum_offset + off];
+ }
+ }
+ }
+ }
+}
+
+template struct _jit_uni_dw_convolution_bwd_weights_t<avx512_common>;
+template struct _jit_uni_dw_convolution_bwd_weights_t<avx2>;
+template struct _jit_uni_dw_convolution_bwd_weights_t<sse42>;
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp
new file mode 100644
index 0000000000..ca53749ec2
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_dw_convolution.hpp
@@ -0,0 +1,266 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_UNI_DW_CONVOLUTION_HPP
+#define CPU_JIT_UNI_DW_CONVOLUTION_HPP
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+
+#include "cpu_barrier.hpp"
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+#include "cpu_reducer.hpp"
+
+#include "jit_uni_dw_conv_kernel_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <cpu_isa_t isa>
+struct _jit_uni_dw_convolution_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const typename pd_t::base_class *hint_fwd_pd)
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
+ _jit_uni_dw_convolution_fwd_t<isa>);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_uni_dw_conv_fwd_kernel_f32<isa>::init_conf(
+ jcp_, *desc(), src_md(), *weights_md(), *dst_md(), *attr());
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_uni_dw_conv_fwd_kernel_f32<isa>::init_scratchpad(scratchpad,
+ jcp_);
+
+ return status::success;
+ }
+
+ jit_conv_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
+ auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ _jit_uni_dw_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
+ { kernel_ = new jit_uni_dw_conv_fwd_kernel_f32<isa>(pd()->jcp_); }
+
+ ~_jit_uni_dw_convolution_fwd_t() { delete kernel_; }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_uni_dw_conv_fwd_kernel_f32<isa> *kernel_;
+};
+
+using jit_avx512_common_dw_convolution_fwd_t =
+ _jit_uni_dw_convolution_fwd_t<avx512_common>;
+using jit_avx2_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t<avx2>;
+using jit_sse42_dw_convolution_fwd_t = _jit_uni_dw_convolution_fwd_t<sse42>;
+
+template <cpu_isa_t isa>
+struct _jit_uni_dw_convolution_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_data_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_()
+ {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
+ _jit_uni_dw_convolution_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::undef, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+
+ if (!ok) return status::unimplemented;
+
+ status_t status = jit_uni_dw_conv_bwd_data_kernel_f32<isa>::
+ init_conf(jcp_, *desc(), *diff_src_md(), *weights_md(),
+ *diff_dst_md());
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_uni_dw_conv_bwd_data_kernel_f32<isa>::init_scratchpad(
+ scratchpad, jcp_);
+
+ return status::success;
+ }
+
+ jit_conv_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
+ auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ _jit_uni_dw_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd)
+ { kernel_ = new jit_uni_dw_conv_bwd_data_kernel_f32<isa>(pd()->jcp_); }
+ ~_jit_uni_dw_convolution_bwd_data_t() { delete kernel_; };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_uni_dw_conv_bwd_data_kernel_f32<isa> *kernel_;
+};
+
+using jit_avx512_common_dw_convolution_bwd_data_t =
+ _jit_uni_dw_convolution_bwd_data_t<avx512_common>;
+using jit_avx2_dw_convolution_bwd_data_t =
+ _jit_uni_dw_convolution_bwd_data_t<avx2>;
+using jit_sse42_dw_convolution_bwd_data_t =
+ _jit_uni_dw_convolution_bwd_data_t<sse42>;
+
+template <cpu_isa_t isa>
+struct _jit_uni_dw_convolution_bwd_weights_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_weights_pd_t {
+ pd_t(engine_t *engine,
+ const convolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const convolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_convolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , jcp_() {}
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
+ _jit_uni_dw_convolution_bwd_weights_t<isa>);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_weights
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(data_type::f32, data_type::f32,
+ data_type::f32, data_type::f32, data_type::f32)
+ && !has_zero_dim_memory()
+ && set_default_formats();
+ if (!ok) return status::unimplemented;
+
+ const int max_threads = mkldnn_in_parallel()
+ ? 1 : mkldnn_get_max_threads();
+
+ status_t status = jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::
+ init_conf(jcp_, *desc(), *src_md(), *diff_weights_md(),
+ *diff_dst_md(), max_threads);
+ if (status != status::success) return status;
+
+ auto scratchpad = scratchpad_registry().registrar();
+ jit_uni_dw_conv_bwd_weights_kernel_f32<isa>::init_scratchpad(
+ scratchpad, jcp_);
+
+ return status::success;
+ }
+
+ jit_conv_conf_t jcp_;
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+
+ auto dat_tag = isa == avx512_common ? nChw16c : nChw8c;
+ auto wei_tag = isa == avx512_common ? Goihw16g : Goihw8g;
+
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ _jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd);
+ ~_jit_uni_dw_convolution_bwd_weights_t() {
+ delete kernel_;
+ delete acc_ker_;
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_weights(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_weights(const exec_ctx_t &ctx) const;
+ bool do_parallel_reduction() const { return false; }
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_uni_dw_conv_bwd_weights_kernel_f32<isa> *kernel_;
+ cpu_accumulator_1d_t<data_type::f32> *acc_ker_;
+};
+
+using jit_avx512_common_dw_convolution_bwd_weights_t =
+ _jit_uni_dw_convolution_bwd_weights_t<avx512_common>;
+using jit_avx2_dw_convolution_bwd_weights_t =
+ _jit_uni_dw_convolution_bwd_weights_t<avx2>;
+using jit_sse42_dw_convolution_bwd_weights_t =
+ _jit_uni_dw_convolution_bwd_weights_t<sse42>;
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp
new file mode 100644
index 0000000000..2af6435871
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.cpp
@@ -0,0 +1,1142 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "jit_uni_eltwise.hpp"
+
+#define GET_OFF(field) offsetof(jit_args, field)
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace Xbyak;
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::injector_preamble(size_t start_idx,
+ size_t end_idx) {
+ preserved_vecs_count = 0;
+ vecs_to_preserve = (size_t)aux_vecs_count(alg_);
+ start_idx_tail = start_idx;
+
+    // For sse42 the mask register has to be Xmm(0)
+ if (isa == sse42 && vecs_to_preserve > 0) {
+ size_t idx = 0;
+ assert(idx < start_idx);
+ preserved_vec_idxs[preserved_vecs_count++] = idx;
+ }
+
+ for (size_t idx = preserved_vecs_count; idx < vecs_count; idx++) {
+ if (preserved_vecs_count >= vecs_to_preserve) break;
+ if (start_idx <= idx && idx < end_idx) continue;
+
+ preserved_vec_idxs[preserved_vecs_count++] = idx;
+ }
+
+ size_t preserved_vecs_count_tail = vecs_to_preserve - preserved_vecs_count;
+ for (size_t i = 0; i < preserved_vecs_count_tail; i++) {
+ preserved_vec_idxs[preserved_vecs_count++] = start_idx_tail++;
+ }
+
+ assert(preserved_vecs_count == vecs_to_preserve);
+
+ if (save_state_) {
+ h->push(p_table);
+
+ if (preserved_vecs_count)
+ h->sub(h->rsp, preserved_vecs_count * vlen);
+
+ for (size_t i = 0; i < preserved_vecs_count; ++i)
+ h->uni_vmovups(h->ptr[h->rsp + i * vlen],
+ Vmm(preserved_vec_idxs[i]));
+
+ load_table_addr();
+ }
+
+ assign_regs();
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::injector_preamble_tail(size_t start_idx)
+{
+ size_t tail_vecs_to_preserve = start_idx_tail - start_idx;
+ if (tail_vecs_to_preserve == 0) return;
+
+ const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;
+
+ if (save_state_) {
+ if (idx_off)
+ h->add(h->rsp, idx_off * vlen);
+
+ for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
+ h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
+ h->ptr[h->rsp + i * vlen]);
+ }
+
+ for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
+ preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;
+
+ if (save_state_) {
+ for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
+ h->uni_vmovups(h->ptr[h->rsp + i * vlen],
+ Vmm(preserved_vec_idxs[idx_off + i]));
+
+ if (idx_off)
+ h->sub(h->rsp, idx_off * vlen);
+ }
+
+ assign_regs();
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::injector_postamble() {
+ if (!save_state_) return;
+
+ for (size_t i = 0; i < preserved_vecs_count; ++i)
+ h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
+ h->ptr[h->rsp + i * vlen]);
+
+ if (preserved_vecs_count)
+ h->add(h->rsp, preserved_vecs_count * vlen);
+
+ h->pop(p_table);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::assign_regs() {
+ vmm_mask = Vmm(preserved_vec_idxs[0]);
+ vmm_aux0 = Vmm(preserved_vec_idxs[0]);
+ vmm_aux1 = Vmm(preserved_vec_idxs[1]);
+ vmm_aux2 = Vmm(preserved_vec_idxs[2]);
+ vmm_aux3 = Vmm(preserved_vec_idxs[3]);
+ vmm_aux4 = Vmm(preserved_vec_idxs[4]);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::exp_compute_vector(const Vmm &vmm_src) {
+ h->uni_vminps(vmm_src, vmm_src, table_val(10));
+ h->uni_vmaxps(vmm_src, vmm_src, table_val(11));
+ h->uni_vmovups(vmm_aux0, vmm_src);
+ //calculate exp(x)
+ // fx = x * log2ef + 0.5
+ h->uni_vmulps(vmm_src, vmm_src, table_val(2));
+ h->uni_vaddps(vmm_src, vmm_src, table_val(1));
+
+ // tmp = floorf(fx)
+ if (isa == avx512_common) {
+ h->vcvtps2dq(vmm_aux1 | h->T_rd_sae, vmm_src);
+ h->vcvtdq2ps(vmm_aux1, vmm_aux1);
+
+ h->vcmpps(k_mask, vmm_aux1, vmm_src, _cmp_nle_us);
+ h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
+
+ h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux3);
+ } else {
+ h->uni_vroundps(vmm_aux1, vmm_src, _op_floor);
+ }
+
+ //keep fx for further computations
+ h->uni_vmovups(vmm_src, vmm_aux1); //vmm_src = fx
+
+ //x = x - fx * ln2
+ h->uni_vfnmadd231ps(vmm_aux0, vmm_aux1, table_val(3));
+
+ // compute 2^n
+ h->uni_vcvtps2dq(vmm_aux1, vmm_src);
+ h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
+ h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //Vmm(6) = 2^-fx
+
+ // y = p5
+ h->uni_vmovups(vmm_src, table_val(9));
+ // y = y * x + p4
+ h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(8));
+ // y = y * x + p3
+ h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(7));
+ // y = y * x + p2
+ h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(6));
+ // y = y * x + p1
+ h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(0));
+ // y = y * x + p0
+ h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(5)); //exp(q)
+ // y = y * 2^n
+ h->uni_vmulps(vmm_src, vmm_src, vmm_aux1);
+}
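For reference, the vector sequence above follows the standard scalar range-reduction scheme: clamp x, split it as x = n*ln2 + r with n from round(x*log2e), evaluate a degree-5 polynomial for exp(r), and rebuild 2^n through the float exponent bits. A scalar sketch of that scheme (illustrative only, not part of the patch; the constants are the decoded elu_prepare_table entries):

// Scalar model of exp_compute_vector (illustrative sketch, not patch content).
#include <cmath>
#include <cstdint>
#include <cstring>

static float exp_approx(float x) {
    x = std::fmin(x, 88.3762589f);                  // max logf, table entry [10]
    x = std::fmax(x, -14.5f);                       // min logf, table entry [11]
    float fx = std::floor(x * 1.44269502f + 0.5f);  // n = floor(x * log2ef + 0.5)
    float r  = x - fx * 0.69314718f;                // reduced argument
    // Horner evaluation of the degree-5 exp(r) polynomial, entries [9]..[5] and [0]
    float y = 0.008369149f;
    y = y * r + 0.041917507f;
    y = y * r + 0.16666505f;
    y = y * r + 0.4999887f;
    y = y * r + 1.0f;
    y = y * r + 1.0000001f;
    // rebuild 2^n by placing (n + 127) into the float exponent field
    uint32_t bits = (uint32_t)((int32_t)fx + 127) << 23;
    float two_n;
    std::memcpy(&two_n, &bits, sizeof(two_n));
    return y * two_n;
}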
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::relu_compute_vector(const Vmm &vmm_src)
+{
+ const int alpha_off = 0, zero_off = 1;
+
+ h->uni_vmovups(vmm_aux1, vmm_src);
+ if (isa == sse42) {
+ h->movups(vmm_mask, vmm_src);
+ h->mulps(vmm_src, table_val(alpha_off));
+ h->cmpps(vmm_mask, table_val(zero_off), _cmp_nle_us);
+ h->blendvps(vmm_src, vmm_aux1);
+ } else if (isa == avx2) {
+ h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
+ h->vcmpgtps(vmm_mask, vmm_aux1, table_val(zero_off));
+ h->vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
+ } else if (isa == avx512_common) {
+ h->vmulps(vmm_src, vmm_src, table_val(alpha_off));
+ h->vcmpps(k_mask, vmm_aux1, table_val(zero_off), _cmp_nle_us);
+ h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::relu_zero_ns_compute_vector(
+ const Vmm &vmm_src) {
+ const int zero_off = 1;
+ h->uni_vmaxps(vmm_src, vmm_src, table_val(zero_off));
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::elu_compute_vector(const Vmm &vmm_src) {
+ const int alpha_off = 23, zero_off = 24;
+
+ // compute exponent
+ h->uni_vmovups(vmm_aux2, vmm_src);
+ exp_compute_vector(vmm_src);
+
+ // alpha * (exp(x) - 1)
+ h->uni_vsubps(vmm_src, vmm_src, table_val(0));
+ h->uni_vmulps(vmm_src, vmm_src, table_val(alpha_off));
+
+ // combine with mask
+ if (isa == sse42) {
+ h->pxor(vmm_mask, vmm_mask);
+ h->cmpps(vmm_mask, vmm_aux2, _cmp_le_os);
+ h->blendvps(vmm_src, vmm_aux2);
+ } else if (isa == avx2) {
+ h->uni_vcmpgtps(vmm_mask, vmm_aux2, table_val(zero_off));
+ h->uni_vblendvps(vmm_src, vmm_src, vmm_aux2, vmm_mask);
+ } else if (isa == avx512_common) {
+ h->vcmpps(k_mask, vmm_aux2, table_val(zero_off), _cmp_nle_us);
+ h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux2);
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::tanh_compute_vector(const Vmm &vmm_src)
+{
+ // # comes from Taylor expansion error bound
+ // > linear_sat_point = single(sqrt(3) * 1b-12);
+ // # comes from the exp formula cancellation
+ // > exp_bound_point = (single(log(3)/2));
+ // # comes from rounding accuracy in float
+ // > one_sat_point = round(atanh(1 - 1b-25), single, RU);
+ // > P = fpminimax(f, [|1, 3, 5, 7, 9|], [|24... |],
+ // [linear_sat_point, exp_bound_point], relative, floating);
+ // > err_bound = D(sup(supnorm(P, tanh(x),
+ // [linear_sat_point, exp_bound_point], relative, theta)));
+ // 0x1.fffd6f00b9539p-25
+ // > P;
+ // x * (0x1.fffffep-1 + x^0x1p1 * (-0x1.55539ep-2 + x^0x1p1 *
+ // (0x1.10be3ep-3 + x^0x1p1 * (-0x1.ae57b4p-5
+ // + x^0x1p1 * 0x1.09fa1p-6))))
+
+ // register mapping
+ // vmm_src contains input
+    // vmm_aux0 contains the mask of currently valid results:
+    //     1 means the lane still needs computation, 0 means it is already computed
+    // vmm_aux1 contains the current output
+    // vmm_aux2, vmm_aux3 contain auxiliary values
+ // vmm_aux4 contains the original sign of inputs
+
+ Label end_tanh_label;
+
+ auto test_exit =[&](Xbyak::Address threshold){
+        // the copy is not necessary for >AVX, but should not matter for performance
+ h->uni_vmovups(vmm_aux0, vmm_src);
+ if (isa == avx512_common){
+ h->vcmpps(k_mask, vmm_aux0, threshold, 0x5);
+ h->kortestw(k_mask, k_mask);
+ } else {
+ h->uni_vcmpgeps(vmm_aux0, vmm_aux0, threshold);
+ h->uni_vtestps(vmm_aux0, vmm_aux0);
+ }
+ h->jz(end_tanh_label, Xbyak::CodeGenerator::T_NEAR);
+ };
+
+ auto blend_results=[&](Vmm vmm_partial_res){
+ if (isa == avx512_common)
+ h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_partial_res);
+ else
+ h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_partial_res, vmm_aux0);
+ };
+
+    // because tanh(x) = -tanh(-x), we extract the sign to make x positive
+ // and reapply sign at the end
+ // mov is not necessary for >AVX, but should not matter for performance
+ h->uni_vmovups(vmm_aux4, vmm_src);
+ h->uni_vandps(vmm_aux4, vmm_aux4, table_val(12));
+ h->uni_vandps(vmm_src, vmm_src, table_val(17));
+
+ // if x < linear_sat_point for all inputs, we just return the input
+ h->uni_vmovups(vmm_aux1, vmm_src);
+ test_exit(table_val(13));
+
+    // if any mask lane is set, we have to compute a better approximation
+ h->uni_vmovups(vmm_aux2, vmm_src);
+ h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux2);
+ h->uni_vmovups(vmm_aux3, table_val(22));
+ h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(21));
+ h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(20));
+ h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(19));
+ h->uni_vfmadd213ps(vmm_aux3, vmm_aux2, table_val(18));
+ h->uni_vmulps(vmm_aux3, vmm_aux3, vmm_src);
+
+    // we blend only the results that need an update
+ blend_results(vmm_aux3);
+
+ // if x < exp_bound_point, we go to return point
+ test_exit(table_val(14));
+
+ // if not we use a better approx 1 - 2 / (1 + exp(2x))
+ // compute 2x
+ h->uni_vmovups(vmm_aux3, vmm_src);
+ h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux3);
+
+ // Compute exp(2x)
+    // We need to save k_mask, vmm_aux0, vmm_aux1 and vmm_src as exp can use them
+    // vmm_src is no longer read afterwards, so we do not have to save it
+ auto stack_size = 3 * vlen + (isa == avx512_common) * 4;
+ h->sub(h->rsp, stack_size);
+ h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_aux0);
+ h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_aux1);
+ h->uni_vmovups(h->ptr[h->rsp + 2 * vlen], vmm_src);
+ if (isa == avx512_common)
+ h->kmovw(h->ptr[h->rsp + 3 * vlen], k_mask);
+
+ exp_compute_vector(vmm_aux3);
+
+ h->uni_vmovups(vmm_aux0, h->ptr[h->rsp + 0 * vlen]);
+ h->uni_vmovups(vmm_aux1, h->ptr[h->rsp + 1 * vlen]);
+ h->uni_vmovups(vmm_src, h->ptr[h->rsp + 2 * vlen]);
+ if (isa == avx512_common)
+ h->kmovw(k_mask, h->ptr[h->rsp + 3 * vlen]);
+ h->add(h->rsp, stack_size);
+
+ // 1 + exp(2x)
+ h->uni_vaddps(vmm_aux3, vmm_aux3, table_val(0));
+
+ // 1 - 2 / (1 + exp(2x))
+ h->uni_vmovups(vmm_aux2, table_val(16));
+ h->uni_vdivps(vmm_aux2, vmm_aux2, vmm_aux3);
+ h->uni_vaddps(vmm_aux2, vmm_aux2, table_val(0));
+
+    // we blend only the results that need an update
+ blend_results(vmm_aux2);
+
+ // finally, we saturate to 1 if needed
+ // TODO: maybe move that up if most inputs saturate in practice
+ if (isa == avx512_common)
+ h->vcmpps(k_mask, vmm_aux0, table_val(15), 0x5);
+ else {
+ h->uni_vmovups(vmm_aux0, vmm_src);
+ h->uni_vcmpgeps(vmm_aux0, vmm_aux0, table_val(15));
+ }
+ h->uni_vmovups(vmm_aux2, table_val(0));
+ blend_results(vmm_aux2);
+
+ h->L(end_tanh_label);
+ {
+ // we apply the sign of x to the result and we are done
+ h->uni_vmovups(vmm_src, vmm_aux1);
+ h->uni_vpxor(vmm_src, vmm_src, vmm_aux4);
+ }
+}
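The tanh routine above is piecewise: for tiny |x| it returns x, for moderate |x| it uses an odd polynomial in x (table entries [18]..[22]), above that it switches to 1 - 2/(1 + exp(2|x|)), and past a final threshold it saturates to 1; the original sign is restored at the end. A scalar sketch of the branch structure (illustrative only; std::tanh stands in for the polynomial branch, and the thresholds are the decoded table constants [13]..[15]):

// Scalar model of the branch structure in tanh_compute_vector (sketch only).
#include <cmath>

static float tanh_branches(float x) {
    const float linear_sat = 4.2286e-4f;   // below this, tanh(x) ~= x
    const float pol_bound  = 0.5493f;      // below this, the odd polynomial is used
    const float one_bound  = 9.01f;        // above this, tanh(x) ~= 1
    float a = std::fabs(x), r;
    if (a < linear_sat)      r = a;                                  // identity branch
    else if (a < pol_bound)  r = std::tanh(a);                       // stand-in for P(a)
    else if (a < one_bound)  r = 1.f - 2.f / (1.f + std::exp(2.f * a));
    else                     r = 1.f;                                // saturation
    return std::copysign(r, x);                                      // reapply the sign
}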
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::square_compute_vector(
+ const Vmm &vmm_src) {
+ h->uni_vmulps(vmm_src, vmm_src, vmm_src);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::abs_compute_vector(const Vmm &vmm_src) {
+ // compute abs(x) = _mm_and_ps(x, 01111..111));
+ h->uni_vandps(vmm_src, vmm_src, table_val(0));
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::sqrt_compute_vector(const Vmm &vmm_src)
+{
+ if (isa == avx512_common) {
+ h->vcmpps(k_mask, vmm_src, table_val(0), _cmp_nle_us);
+ h->uni_vsqrtps(vmm_aux1, vmm_src);
+ h->uni_vmovups(vmm_src, table_val(0));
+ h->vblendmps(vmm_src | k_mask, vmm_src, vmm_aux1);
+ } else {
+ h->uni_vmovups(vmm_mask, vmm_src);
+ h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(0));
+ h->uni_vsqrtps(vmm_aux1, vmm_src);
+ h->uni_vmovups(vmm_src, table_val(0));
+ h->uni_vblendvps(vmm_src, vmm_src, vmm_aux1, vmm_mask);
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::linear_compute_vector(
+ const Vmm &vmm_src) {
+ // compute x = alpha * x + beta;
+ h->uni_vmovups(vmm_aux0, table_val(0));
+ h->uni_vfmadd213ps(vmm_src, vmm_aux0, table_val(1));
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::bounded_relu_compute_vector(
+ const Vmm &vmm_src) {
+    // compute bounded relu
+ h->uni_vmaxps(vmm_src, vmm_src, table_val(1));
+ h->uni_vminps(vmm_src, vmm_src, table_val(0));
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::soft_relu_compute_vector(
+ const Vmm &vmm_src) {
+ // duplicate src
+ h->uni_vmovups(vmm_aux2, vmm_src);
+
+ h->uni_vminps(vmm_src, vmm_src, table_val(24));
+ h->uni_vmaxps(vmm_src, vmm_src, table_val(25));
+ h->uni_vmovups(vmm_aux1, vmm_src);
+ // calculate exp(x)
+ // fx = x * log2ef + 0.5
+ h->uni_vmulps(vmm_src, vmm_src, table_val(2));
+ h->uni_vaddps(vmm_src, vmm_src, table_val(1));
+
+ // tmp = floorf(fx)
+ if (isa == avx512_common) {
+ h->vcvtps2dq(vmm_aux0 | h->T_rd_sae, vmm_src);
+ h->vcvtdq2ps(vmm_aux0, vmm_aux0);
+
+ h->vcmpps(k_mask, vmm_aux0, vmm_src, _cmp_nle_us);
+ h->vmovups(vmm_aux3 | k_mask | h->T_z, table_val(0));
+
+ h->vsubps(vmm_aux0, vmm_aux0, vmm_aux3);
+ } else {
+ h->uni_vroundps(vmm_aux0, vmm_src, _op_floor);
+ }
+
+ // keep fx for further computations
+ h->uni_vmovups(vmm_src, vmm_aux0); //vmm_src = fx
+ // calculation fx * ln2
+ h->uni_vmulps(vmm_aux0, vmm_aux0, table_val(3));
+ // x = x - fx * ln2
+ h->uni_vsubps(vmm_aux1, vmm_aux1, vmm_aux0);
+ // y = p5
+ h->uni_vmovups(vmm_aux3, table_val(22));
+ // y = y * x + p4
+ h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(21));
+ // y = y * x + p3
+ h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(20));
+ // y = y * x + p2
+ h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(19));
+ // y = y * x + p1
+ h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(0));
+ // y = y * x + p0
+ h->uni_vfmadd213ps(vmm_aux3, vmm_aux1, table_val(17));
+
+ // compute 2^(-n)
+ if (isa == avx512_common) {
+ h->vmulps(vmm_aux1, vmm_src, table_val(23));
+ h->vcvtps2dq(vmm_aux1, vmm_aux1);
+ } else {
+ h->uni_vcvtps2dq(vmm_aux1, vmm_src);
+ h->uni_vpsignd(vmm_aux1, vmm_aux1, table_val(23));
+ }
+
+ h->uni_vpaddd(vmm_aux1, vmm_aux1, table_val(4));
+ h->uni_vpslld(vmm_aux1, vmm_aux1, 23); //vmm_aux1 = 2^-fx
+ // calculate ln(1 + y)
+ h->uni_vaddps(vmm_aux3, vmm_aux3, vmm_aux1);
+ // x = y; y is free; keep x for further computations
+ h->uni_vmovups(vmm_src, vmm_aux3);
+ // frexp()
+ h->uni_vpsrld(vmm_src, vmm_src, 23);
+ h->uni_vcvtdq2ps(vmm_src, vmm_src);
+    // got n, where x = 2^n * y and y is in [0.5, 1)
+ h->uni_vsubps(vmm_src, vmm_src, table_val(5));
+
+ h->uni_vandps(vmm_aux3, vmm_aux3, table_val(6));
+    // got y (mantissa), 0.5 <= y < 1
+ h->uni_vorps(vmm_aux3, vmm_aux3, table_val(7));
+ // y = y - 1
+ h->uni_vsubps(vmm_aux3, vmm_aux3, table_val(0));
+ // y = p8
+ h->uni_vmovups(vmm_aux1, table_val(16));
+ // y = y * x + p7
+ h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(15));
+ // y = y * x + p6
+ h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(14));
+ // y = y * x + p5
+ h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(13));
+ // y = y * x + p4
+ h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(12));
+ // y = y * x + p3
+ h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(11));
+ // y = y * x + p2
+ h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(10));
+ // y = y * x + p1
+ h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(9));
+ // y = y * x + p0 ; p0 = 0
+ h->uni_vfmadd213ps(vmm_aux1, vmm_aux3, table_val(8));
+ //calculate ln(2) * n
+ h->uni_vmulps(vmm_src, vmm_src, table_val(3));
+ h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_src);
+ h->uni_vaddps(vmm_aux1, vmm_aux1, vmm_aux0);
+
+ // get vmm_mask = src > max logf
+ h->uni_vmovups(vmm_mask, vmm_aux2);
+ if (isa == avx512_common) {
+ // y = (x < max log f) ? soft_relu(x) : x
+ h->vcmpps(k_mask, vmm_mask, table_val(24), _cmp_nle_us);
+ h->vblendmps(vmm_aux1 | k_mask, vmm_aux1, vmm_aux2);
+ } else {
+ // y = (x < max log f) ? soft_relu(x) : x
+ h->uni_vcmpgtps(vmm_mask, vmm_mask, table_val(24));
+ h->uni_vblendvps(vmm_aux1, vmm_aux1, vmm_aux2, vmm_mask);
+ }
+
+ h->uni_vmovups(vmm_src, vmm_aux1);
+}
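Numerically, the routine above computes softplus(x) = log(1 + exp(x)) by reusing the exp range reduction and then taking the log through a frexp-style split of 1 + exp(x), with a pass-through to x for large inputs where log(1 + exp(x)) ~= x. A minimal scalar model (illustrative only; the kernel replaces exp and log1p with the table polynomials):

// Scalar model of soft_relu_compute_vector (sketch only).
#include <algorithm>
#include <cmath>

static float soft_relu_approx(float x) {
    const float max_logf = 88.3762589f;   // table entry [24]
    const float min_logf = -14.5f;        // table entry [25]
    if (x > max_logf) return x;           // softplus(x) -> x, avoids exp overflow
    float xc = std::max(x, min_logf);     // clamp before the exp polynomial
    return std::log1p(std::exp(xc));      // log(1 + exp(x))
}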
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::logistic_compute_vector(
+ const Vmm &vmm_src) {
+ // we store the original sign and make x negative
+    // IMPORTANT: we assume vmm_aux0 to be xmm0, as the sse4.2 path requires it
+ // IMPORTANT: we use vmm_aux2 for the mask as exp_compute does not use it.
+ h->uni_vmovups(vmm_aux2, vmm_src);
+ h->uni_vandps(vmm_aux2, vmm_aux2, table_val(12));
+ h->uni_vorps(vmm_src, vmm_src, table_val(12));
+
+ exp_compute_vector(vmm_src);
+ // dup exp(x)
+ h->uni_vmovups(vmm_aux1, vmm_src);
+ // (exp(x) + 1)
+ h->uni_vaddps(vmm_aux1, vmm_aux1, table_val(0));
+ // y = exp(x) / (exp(x) + 1)
+ h->uni_vdivps(vmm_src, vmm_src, vmm_aux1);
+
+ // Now we have to apply the "symmetry" based on original sign
+ h->uni_vmovups(vmm_aux3, table_val(0));
+ h->uni_vsubps(vmm_aux3, vmm_aux3, vmm_src);
+ if (isa == avx512_common) {
+ h->vptestmd(k_mask, vmm_aux2, vmm_aux2);
+ h->vblendmps(vmm_aux3 | k_mask, vmm_aux3, vmm_src);
+ } else {
+ h->uni_vmovups(vmm_aux0, vmm_aux2);// The mask should be xmm0 for sse4.2
+ h->uni_vblendvps(vmm_aux3, vmm_aux3, vmm_src, vmm_aux0);
+ }
+ h->uni_vmovups(vmm_src, vmm_aux3);
+}
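The logistic routine forces the argument negative so the exp polynomial never overflows, computes sigmoid(-|x|) = exp(-|x|) / (exp(-|x|) + 1), and then uses sigmoid(x) = 1 - sigmoid(-x) to recover the result for positive inputs; the stored sign bits drive the final blend. A scalar model of that trick (illustrative only):

// Scalar model of logistic_compute_vector's sign handling (sketch only).
#include <cmath>

static float logistic_approx(float x) {
    float e = std::exp(-std::fabs(x));    // exp of the negated magnitude, never overflows
    float s = e / (e + 1.f);              // sigmoid(-|x|)
    return std::signbit(x) ? s : 1.f - s; // symmetry: sigmoid(x) = 1 - sigmoid(-x)
}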
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::relu_prepare_table() {
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::elu_prepare_table() {
+ const unsigned int cvals[] = {
+ 0x3f800000, // [0] 1.0f
+ 0x3f000000, // [1] 0.5f
+ 0x3fb8aa3b, // [2] log2ef = 1.44269502f
+ 0x3f317218, // [3] ln2f = 0.69314718f
+ 0x0000007f, // [4] 0x7f
+        // exp(x) polynomial
+ 0x3f800001, // [5] p0 = 1.0000001f
+ 0x3efffe85, // [6] p2 = 0.4999887f
+ 0x3e2aaa3e, // [7] p3 = 0.16666505f
+ 0x3d2bb1b1, // [8] p4 = 0.041917507f
+ 0x3c091ec1, // [9] p5 = 0.008369149f
+ 0x42b0c0a5, //[10] max logf = 88.3762589f
+ 0xc1766666, //[11] min logf = -14.5f
+ // tanh(x) constants,
+ 0x80000000, //[12] mask to extract sign
+ 0x39ddb3d7, //[13] arg below which tanh(x) = x
+ 0x3f0c9f54, //[14] arg below which pol approx is valid
+ 0x41102cb4, //[15] arg after which tanh(x) = 1
+ 0xc0000000, //[16] -2.0f
+ 0x7fffffff, //[17] mask to make positive
+ // tanh pol approx
+ 0x3f7fffff, //[18] p0
+ 0xbeaaa9cf, //[19] p1
+ 0x3e085f1f, //[20] p2
+ 0xbd572bda, //[21] p3
+ 0x3c84fd08, //[22] p4
+ };
+
+ for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(cvals[i]);
+ }
+
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::soft_relu_prepare_table() {
+ const unsigned int cvals[] = {
+ 0x3f800000, // [0] 1.0f
+ 0x3f000000, // [1] 0.5f
+ 0x3fb8aa3b, // [2] log2ef = 1.44269502f
+ 0x3f317218, // [3] ln2f = 0.69314718f
+ 0x0000007f, // [4] 0x7f
+ 0x42fc0000, // [5] 126
+ 0x807fffff, // [6] and with (to get 0.5 * mantissa)
+ 0x3f000000, // [7] or with (to get 0.5 * mantissa)
+ // ln(1 + x) polynomial
+ 0xb2b4637d, // [8] p0 = 0.0000000244f
+ 0x3f7fff8e, // [9] p1 = 0.9999976971f
+ 0xbf001759, //[10] p2 = -0.5002478215f
+ 0x3ea70608, //[11] p3 = 0.3272714505f
+ 0xbea3d7bf, //[12] p4 = -0.3153830071f
+ 0xbe361d04, //[13] p5 = -0.1701777461f
+ 0xbfa8f1e6, //[14] p6 = -1.3254635147f
+ 0xbfe1e812, //[15] p7 = -1.7971917960f
+ 0xbfc4d30e, //[16] p8 = -1.5652673123f
+ // exp(x) polynomial
+ 0x3f800001, //[17] p0 = 1.0000001f
+ 0x3f800000, //[18] p1 = 1.0f
+ 0x3efffe85, //[19] p2 = 0.4999887f
+ 0x3e2aaa3e, //[20] p3 = 0.16666505f
+ 0x3d2bb1b1, //[21] p4 = 0.041917507f
+ 0x3c091ec1, //[22] p5 = 0.008369149f
+ 0xbf800000, //[23] is required for sign changing
+ 0x42b0c0a5, //[24] max logf = 88.3762589f
+ 0xc1766666 //[25] min logf = -14.5f
+ };
+
+ for (size_t i = 0; i < sizeof(cvals) / sizeof(cvals[0]); ++i) {
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) {
+ h->dd(cvals[i]);
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::abs_prepare_table() {
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0x7fffffff);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::sqrt_prepare_table() {
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::linear_prepare_table() {
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(beta_));
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::bounded_relu_prepare_table() {
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(float2int(alpha_));
+ for (size_t d = 0; d < vlen / sizeof(float); ++d) h->dd(0);
+}
+
+template <cpu_isa_t isa>
+int jit_uni_eltwise_injector_f32<isa>::aux_vecs_count(alg_kind_t alg_) {
+ switch (alg_) {
+ case alg_kind::eltwise_relu: return (alpha_ == 0.f) ? 0 : 2;
+ case alg_kind::eltwise_elu: return 4;
+ case alg_kind::eltwise_tanh: return 5;
+ case alg_kind::eltwise_square: return 0;
+ case alg_kind::eltwise_abs: return 0;
+ case alg_kind::eltwise_sqrt: return 2;
+ case alg_kind::eltwise_linear: return 1;
+ case alg_kind::eltwise_bounded_relu: return 0;
+ case alg_kind::eltwise_soft_relu: return 4;
+ case alg_kind::eltwise_logistic: return 4;
+ default: assert(!"unsupported eltwise algorithm");
+ }
+
+ return 0;
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::compute_body(size_t start_idx,
+ size_t end_idx) {
+ using namespace alg_kind;
+ for (size_t idx = start_idx; idx < end_idx; idx++) {
+ switch (alg_) {
+ case eltwise_relu:
+ if (alpha_ == 0.f) relu_zero_ns_compute_vector(Vmm(idx));
+ else relu_compute_vector(Vmm(idx));
+ break;
+ case eltwise_elu: elu_compute_vector(Vmm(idx)); break;
+ case eltwise_tanh: tanh_compute_vector(Vmm(idx)); break;
+ case eltwise_square: square_compute_vector(Vmm(idx)); break;
+ case eltwise_abs: abs_compute_vector(Vmm(idx)); break;
+ case eltwise_sqrt: sqrt_compute_vector(Vmm(idx)); break;
+ case eltwise_linear: linear_compute_vector(Vmm(idx)); break;
+ case eltwise_bounded_relu: bounded_relu_compute_vector(Vmm(idx)); break;
+ case eltwise_soft_relu: soft_relu_compute_vector(Vmm(idx)); break;
+ case eltwise_logistic: logistic_compute_vector(Vmm(idx)); break;
+ default: assert(!"unsupported eltwise algorithm");
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::compute_vector_range(size_t start_idx,
+ size_t end_idx) {
+ assert(start_idx < end_idx && end_idx <= vecs_count);
+
+ injector_preamble(start_idx, end_idx);
+ compute_body(start_idx_tail, end_idx);
+ injector_preamble_tail(start_idx);
+ compute_body(start_idx, start_idx_tail);
+ injector_postamble();
+}
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_injector_f32<isa>::prepare_table(bool gen_table) {
+ using namespace alg_kind;
+
+ h->align(64);
+ h->L(l_table);
+
+ if (gen_table) {
+ switch (alg_) {
+ case eltwise_relu: relu_prepare_table(); break;
+ case eltwise_elu:
+ case eltwise_tanh:
+ case eltwise_logistic:
+ elu_prepare_table(); break;
+ case eltwise_soft_relu: soft_relu_prepare_table(); break;
+ case eltwise_abs: abs_prepare_table(); break;
+ case eltwise_sqrt: sqrt_prepare_table(); break;
+ case eltwise_linear: linear_prepare_table(); break;
+ case eltwise_bounded_relu: bounded_relu_prepare_table(); break;
+ case eltwise_square: break;
+ default: assert(!"unsupported eltwise algorithm");
+ }
+ }
+}
+
+template struct jit_uni_eltwise_injector_f32<avx512_common>;
+template struct jit_uni_eltwise_injector_f32<avx2>;
+template struct jit_uni_eltwise_injector_f32<sse42>;
+
+
+struct jit_args {
+ const float *from;
+ const float *for_comparison;
+ const float *to;
+ size_t work_amount;
+};
+
+struct jit_uni_eltwise_kernel_f32 : public c_compatible {
+ const eltwise_desc_t &desc_;
+
+ void (*ker_)(const jit_args *);
+ void operator()(const jit_args *args) { assert(ker_); ker_(args); }
+
+ jit_uni_eltwise_kernel_f32(const eltwise_desc_t &desc)
+ : desc_(desc), ker_(nullptr) {}
+ virtual ~jit_uni_eltwise_kernel_f32() {}
+
+protected:
+ bool is_bwd() const { return desc_.prop_kind == prop_kind::backward_data; }
+};
+
+/* jit kernels */
+namespace {
+
+template <cpu_isa_t isa>
+struct jit_uni_relu_kernel_f32 : public jit_uni_eltwise_kernel_f32,
+ public jit_generator
+{
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_relu_kernel_f32)
+
+ void compute_step(bool vectorize, const int uf, const int shift) {
+ for (int i = 0; i < uf; i++) {
+ if (vectorize) {
+ uni_vmovups(Vmm(i + 1), ptr[reg_from + i * shift]);
+ if (is_bwd())
+ uni_vmovups(Vmm(uf + i + 1),
+ ptr[reg_for_comparison + i * shift]);
+ } else {
+ movss(Xmm(i + 1), ptr[reg_from + i * shift]);
+ if (is_bwd())
+ movss(Xmm(uf + i + 1),
+ ptr[reg_for_comparison + i * shift]);
+ }
+ }
+
+ if (isa == sse42) {
+ for (int i = 0; i < uf; i++) {
+ movups(Vmm(2 * uf + i + 1), Vmm(i + 1));
+ mulps(Vmm(2 * uf + i + 1), vmm_ns);
+
+ Vmm mask = Vmm(0);
+ if (is_bwd()) {
+ movups(mask, Vmm(uf + i + 1));
+ cmpps(mask, vmm_zero, _cmp_nle_us);
+ } else {
+ movups(mask, Vmm(i + 1));
+ cmpps(mask, vmm_zero, _cmp_nle_us);
+ }
+ blendvps(Vmm(2 * uf + i + 1), Vmm(i + 1));
+ }
+ } else {
+ for (int i = 0; i < uf; i++) {
+ vmulps(Vmm(2 * uf + i + 1), Vmm(i + 1), vmm_ns);
+ if (isa == avx2) {
+ if (is_bwd())
+ vcmpgtps(vmm_mask, Vmm(uf + i + 1), vmm_zero);
+ else
+ vcmpgtps(vmm_mask, Vmm(i + 1), vmm_zero);
+
+ vblendvps(Vmm(2 * uf + i + 1), Vmm(2 * uf + i + 1),
+ Vmm(i + 1), vmm_mask);
+
+ } else {
+ if (is_bwd())
+ vcmpps(k_mask, Vmm(uf + i + 1), vmm_zero, _cmp_nle_us);
+ else
+ vcmpps(k_mask, Vmm(i + 1), vmm_zero, _cmp_nle_us);
+ vblendmps(Vmm(2 * uf + i + 1) | k_mask, Vmm(2 * uf + i + 1),
+ Vmm(i + 1));
+ }
+ }
+ }
+
+ for (int i = 0; i < uf; i++) {
+ if (vectorize) {
+ uni_vmovups(ptr[reg_to + i * shift], Vmm(2 * uf + i + 1));
+ } else {
+ movss(ptr[reg_to + i * shift], Xmm(2 * uf + i + 1));
+ }
+ }
+ }
+
+ jit_uni_relu_kernel_f32(const eltwise_desc_t &desc)
+ : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
+ assert(desc.alg_kind == alg_kind::eltwise_relu);
+ assert(isa == sse42 || isa == avx2 || isa == avx512_common);
+
+ Reg64 param = abi_param1;
+
+ const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
+ const int loop_dec[] = {simd_w, 1};
+ const int uf[] = {1, 1};
+ const int shift[] = {cpu_isa_traits<isa>::vlen, sizeof(float)};
+ const bool loop_vectorize[] = {true, false};
+
+ this->preamble();
+
+ mov(reg_from, ptr[param + GET_OFF(from)]);
+ if (is_bwd())
+ mov(reg_for_comparison, ptr[param + GET_OFF(for_comparison)]);
+ mov(reg_to, ptr[param + GET_OFF(to)]);
+ mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
+
+ mov(imm_addr64, float2int(desc.alpha));
+ movq(xmm_ns, imm_addr64);
+ uni_vbroadcastss(vmm_ns, xmm_ns);
+
+ uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
+
+ Label loop_label[3];
+
+ for (int id = 0; id < 2; id++) {
+ L(loop_label[id]);
+ cmp(reg_work_amount, uf[id] * loop_dec[id] - 1);
+ jle(loop_label[id + 1], T_NEAR);
+
+ compute_step(loop_vectorize[id], uf[id], shift[id]);
+
+ add(reg_from, uf[id] * shift[id]);
+ add(reg_to, uf[id] * shift[id]);
+ if (is_bwd())
+ add(reg_for_comparison, uf[id] * shift[id]);
+
+ sub(reg_work_amount, uf[id] * loop_dec[id]);
+ jmp(loop_label[id]);
+ }
+
+ L(loop_label[2]);
+ this->postamble();
+
+ ker_ = (decltype(ker_))this->getCode();
+ }
+
+private:
+ using Vmm = typename utils::conditional3<isa == sse42, Xmm,
+ isa == avx2, Ymm, Zmm>::type;
+
+ Reg64 reg_from = rax;
+ Reg64 reg_for_comparison = is_bwd() ? rdx : reg_from;
+ Reg64 reg_to = r8;
+ Reg64 reg_work_amount = rsi;
+ Reg64 imm_addr64 = rbx;
+
+ Xmm xmm_ns = Xmm(14);
+
+ Vmm vmm_ns = Vmm(isa == avx512_common ? 30 : 14);
+ Vmm vmm_zero = Vmm(isa == avx512_common ? 31 : 15);
+
+ Vmm vmm_mask = Vmm(isa == avx512_common ? 28 : 12);
+ Opmask k_mask = Opmask(1);
+};
+
+template <cpu_isa_t isa>
+struct jit_uni_kernel_fwd_f32: public jit_uni_eltwise_kernel_f32,
+ public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_kernel_fwd_f32)
+
+ jit_uni_kernel_fwd_f32(const eltwise_desc_t &desc)
+ : jit_uni_eltwise_kernel_f32(desc), jit_generator() {
+
+ eltwise_injector_ = new jit_uni_eltwise_injector_f32<isa>(this,
+ desc.alg_kind, desc.alpha, desc.beta, false, r9, Opmask(1));
+
+ using namespace alg_kind;
+
+ assert(is_bwd() == false);
+ assert(utils::one_of(desc.alg_kind, eltwise_tanh, eltwise_elu,
+ eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+ eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
+
+ preamble();
+
+ Reg64 param = abi_param1;
+ mov(reg_from, ptr[param + GET_OFF(from)]);
+ mov(reg_to, ptr[param + GET_OFF(to)]);
+ mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]);
+ eltwise_injector_->load_table_addr();
+
+ Label reminder_loop_start, reminder_loop_end;
+ Label vectorized_loop_start, vectorized_loop_end;
+
+ cmp(reg_work_amount, simd_w);
+ jl(reminder_loop_start, T_NEAR);
+
+ L(vectorized_loop_start);
+
+ uni_vmovups(vmm_src, ptr[reg_from]);
+ eltwise_injector_->compute_vector(vmm_src.getIdx());
+ uni_vmovups(ptr[reg_to], vmm_src);
+
+ add(reg_from, vlen);
+ add(reg_to, vlen);
+
+ sub(reg_work_amount, simd_w);
+ cmp(reg_work_amount, simd_w);
+ jge(vectorized_loop_start, T_NEAR);
+
+ L(vectorized_loop_end);
+
+ L(reminder_loop_start);
+
+ cmp(reg_work_amount, 0);
+ jle(reminder_loop_end, T_NEAR);
+
+ movss(xmm_src, ptr[reg_from]);
+ eltwise_injector_->compute_vector(xmm_src.getIdx());
+ movss(ptr[reg_to], xmm_src);
+
+ add(reg_from, sizeof(float));
+ add(reg_to, sizeof(float));
+
+ dec(reg_work_amount);
+ jmp(reminder_loop_start, T_NEAR);
+
+ L(reminder_loop_end);
+
+ postamble();
+
+ eltwise_injector_->prepare_table();
+
+ ker_ = (decltype(ker_))this->getCode();
+ }
+
+ ~jit_uni_kernel_fwd_f32() { delete eltwise_injector_; }
+
+private:
+ using Vmm = typename utils::conditional3<isa == sse42, Xmm,
+ isa == avx2, Ymm, Zmm>::type;
+
+ const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
+ const int vlen = cpu_isa_traits<isa>::vlen;
+
+ Reg64 reg_from = rax;
+ Reg64 reg_to = r8;
+ Reg64 reg_work_amount = rsi;
+ Reg64 imm_addr64 = rbx;
+
+ Xmm xmm_src = Xmm(1);
+ Vmm vmm_src = Vmm(1);
+
+ jit_uni_eltwise_injector_f32<isa> *eltwise_injector_;
+};
+
+} /* namespace */
+
+template <cpu_isa_t isa>
+status_t jit_uni_eltwise_fwd_t<isa>::pd_t::init() {
+ using namespace alg_kind;
+
+ bool ok = true
+ && mayiuse(isa)
+ && is_fwd()
+ && utils::everyone_is(data_type::f32, desc()->data_desc.data_type)
+ && !has_zero_dim_memory()
+ && utils::one_of(desc()->alg_kind, eltwise_relu, eltwise_tanh,
+ eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
+ eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
+ eltwise_logistic)
+ && memory_desc_wrapper(src_md()).is_dense(true)
+ && IMPLICATION(!memory_desc_wrapper(src_md()).is_dense(false),
+ math::eltwise_fwd_preserves_zero(desc()->alg_kind, true))
+ && attr()->has_default_values();
+
+ return ok ? status::success : status::unimplemented;
+}
+
+template <cpu_isa_t isa>
+jit_uni_eltwise_fwd_t<isa>::jit_uni_eltwise_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd), kernel_(nullptr) {
+ const auto &desc = *pd()->desc();
+ switch (desc.alg_kind) {
+ case alg_kind::eltwise_relu:
+ kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
+ default:
+ kernel_ = new jit_uni_kernel_fwd_f32<isa>(desc);
+ }
+}
+
+template <cpu_isa_t isa>
+jit_uni_eltwise_fwd_t<isa>::~jit_uni_eltwise_fwd_t()
+{ delete kernel_; }
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_fwd_t<isa>::execute_forward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+
+ const size_t nelems = data_d.nelems(true);
+
+ src += data_d.offset0();
+ dst += data_d.offset0();
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ size_t start{0}, end{0};
+
+ const int cache_line = 16;
+
+ balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
+ start = nstl::min(nelems, start * cache_line);
+ end = nstl::min(nelems, end * cache_line);
+
+ auto arg = jit_args();
+ arg.from = &src[start];
+ arg.for_comparison = &src[start];
+ arg.to = &dst[start];
+ arg.work_amount = end - start;
+ if (arg.work_amount)
+ (*kernel_)(&arg);
+ });
+}
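The driver splits the flattened tensor into 16-float blocks and hands each thread a contiguous run of whole blocks, so threads never share a block; balance211 distributes ceil(nelems/16) blocks, and the resulting block range is scaled back to elements and clamped. A hedged sketch of that partitioning (the even-split-plus-remainder policy below is an assumption; the real balance211 may assign the remainder differently):

// Illustrative per-thread partitioning as used in execute_forward (sketch only).
#include <algorithm>
#include <cstddef>

static void partition_elems(size_t nelems, int nthr, int ithr,
                            size_t &start, size_t &end) {
    const size_t cache_line = 16;                            // floats per block
    const size_t nblk = (nelems + cache_line - 1) / cache_line;
    const size_t base = nblk / nthr, rem = nblk % nthr;      // even split + remainder
    const size_t b_start = (size_t)ithr * base + std::min((size_t)ithr, rem);
    const size_t b_end = b_start + base + ((size_t)ithr < rem ? 1 : 0);
    start = std::min(nelems, b_start * cache_line);          // back to element offsets
    end   = std::min(nelems, b_end * cache_line);            // last block is clamped
}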
+
+template <cpu_isa_t isa>
+status_t jit_uni_eltwise_bwd_t<isa>::pd_t::init() {
+ bool ok = true
+ && !is_fwd()
+ && utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu)
+ && src_md()->data_type == data_type::f32
+ && !has_zero_dim_memory()
+ && mayiuse(isa)
+ && memory_desc_wrapper(src_md()).is_dense()
+ && memory_desc_wrapper(diff_dst_md()) == memory_desc_wrapper(src_md())
+ && attr()->has_default_values();
+
+ return ok ? status::success : status::unimplemented;
+}
+
+template <cpu_isa_t isa>
+jit_uni_eltwise_bwd_t<isa>::jit_uni_eltwise_bwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd), kernel_(nullptr) {
+ const auto &desc = *pd()->desc();
+ switch (desc.alg_kind) {
+ case alg_kind::eltwise_relu:
+ kernel_ = new jit_uni_relu_kernel_f32<isa>(desc); break;
+ default: assert(!"unknown eltwise alg_kind");
+ }
+}
+
+template <cpu_isa_t isa>
+jit_uni_eltwise_bwd_t<isa>::~jit_uni_eltwise_bwd_t()
+{ delete kernel_; }
+
+template <cpu_isa_t isa>
+void jit_uni_eltwise_bwd_t<isa>::execute_backward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+ const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
+
+ const size_t nelems = data_d.nelems();
+
+ src += data_d.offset0();
+ diff_dst += diff_data_d.offset0();
+ diff_src += diff_data_d.offset0();
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ size_t start{0}, end{0};
+
+ const int cache_line = 16;
+
+ balance211(utils::div_up(nelems, cache_line), nthr, ithr, start, end);
+ start = nstl::min(nelems, start * cache_line);
+ end = nstl::min(nelems, end * cache_line);
+
+ auto arg = jit_args();
+ arg.from = &diff_dst[start];
+ arg.to = &diff_src[start];
+ arg.for_comparison = &src[start];
+ arg.work_amount = end - start;
+ if (arg.work_amount)
+ (*kernel_)(&arg);
+ });
+}
+
+template struct jit_uni_eltwise_fwd_t<sse42>;
+template struct jit_uni_eltwise_bwd_t<sse42>;
+template struct jit_uni_eltwise_fwd_t<avx2>;
+template struct jit_uni_eltwise_bwd_t<avx2>;
+template struct jit_uni_eltwise_fwd_t<avx512_common>;
+template struct jit_uni_eltwise_bwd_t<avx512_common>;
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp
new file mode 100644
index 0000000000..45436b9f46
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_eltwise.hpp
@@ -0,0 +1,193 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_UNI_ELTWISE_HPP
+#define CPU_JIT_UNI_ELTWISE_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_eltwise_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <cpu_isa_t isa>
+struct jit_uni_eltwise_injector_f32 {
+ using Vmm = typename utils::conditional3<isa == sse42, Xbyak::Xmm,
+ isa == avx2, Xbyak::Ymm, Xbyak::Zmm>::type;
+
+ jit_uni_eltwise_injector_f32(jit_generator *host, alg_kind_t alg,
+ float alpha, float beta, bool save_state = true,
+ Xbyak::Reg64 p_table = Xbyak::util::rax,
+ Xbyak::Opmask k_mask = Xbyak::Opmask(1))
+ : alg_(alg), alpha_(alpha), beta_(beta), h(host)
+ , save_state_(save_state), p_table(p_table), k_mask(k_mask)
+ {
+ using namespace alg_kind;
+ assert(utils::one_of(isa, sse42, avx2, avx512_common));
+ assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
+ eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+ eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
+ }
+
+ // note that eltwise.scale is ignored
+ jit_uni_eltwise_injector_f32(jit_generator *host,
+ const post_ops_t::entry_t::eltwise_t &eltwise,
+ bool save_state = true, Xbyak::Reg64 p_table = Xbyak::util::rax,
+ Xbyak::Opmask k_mask = Xbyak::Opmask(1))
+ : jit_uni_eltwise_injector_f32(host, eltwise.alg, eltwise.alpha,
+ eltwise.beta, save_state, p_table, k_mask) {}
+
+ void compute_vector_range(size_t start_idx, size_t end_idx);
+ void compute_vector(size_t idx) { compute_vector_range(idx, idx + 1); }
+ void prepare_table(bool gen_table=true);
+ void load_table_addr() { h->mov(p_table, l_table); }
+
+ const alg_kind_t alg_;
+ const float alpha_;
+ const float beta_;
+
+ jit_generator * const h;
+
+ const bool save_state_;
+ const Xbyak::Reg64 p_table;
+ const Xbyak::Opmask k_mask;
+ Xbyak::Label l_table;
+
+private:
+ // if only the injector was inherited from jit_generator...
+ enum {
+ _cmp_le_os = jit_generator::_cmp_le_os,
+ _cmp_nle_us = jit_generator::_cmp_nle_us,
+ _op_floor = jit_generator::_op_floor,
+ };
+
+ size_t vlen = cpu_isa_traits<isa>::vlen;
+
+ const static size_t preserved_vecs_max = 5;
+
+ size_t vecs_to_preserve = 0;
+ size_t vecs_count = isa == avx512_common ? 32 : 16;
+ size_t preserved_vecs_count = 0;
+ size_t preserved_vec_idxs[preserved_vecs_max] = {0};
+ size_t start_idx_tail = 0;
+
+ Vmm vmm_mask, vmm_aux0, vmm_aux1, vmm_aux2, vmm_aux3, vmm_aux4;
+
+ Xbyak::Address table_val(int index)
+ { return h->ptr[p_table + index * vlen]; }
+
+ int aux_vecs_count(alg_kind_t alg);
+
+ void compute_body(size_t start_idx, size_t end_idx);
+ void injector_preamble(size_t start_idx, size_t end_idx);
+ void injector_preamble_tail(size_t start_idx);
+ void injector_postamble();
+ void assign_regs();
+
+ void exp_compute_vector(const Vmm &vmm_src);
+ void relu_compute_vector(const Vmm &vmm_src);
+ void relu_zero_ns_compute_vector(const Vmm &vmm_src);
+ void elu_compute_vector(const Vmm &vmm_src);
+ void tanh_compute_vector(const Vmm &vmm_src);
+ void square_compute_vector(const Vmm &vmm_src);
+ void abs_compute_vector(const Vmm &vmm_src);
+ void sqrt_compute_vector(const Vmm &vmm_src);
+ void linear_compute_vector(const Vmm &vmm_src);
+ void bounded_relu_compute_vector(const Vmm &vmm_src);
+ void soft_relu_compute_vector(const Vmm &vmm_src);
+ void logistic_compute_vector(const Vmm &vmm_src);
+
+ void relu_prepare_table();
+ void elu_prepare_table();
+ void soft_relu_prepare_table();
+ void abs_prepare_table();
+ void sqrt_prepare_table();
+ void linear_prepare_table();
+ void bounded_relu_prepare_table();
+};
+
+struct jit_uni_eltwise_kernel_f32;
+
+template <cpu_isa_t isa>
+struct jit_uni_eltwise_fwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_eltwise_fwd_pd_t {
+ using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+ jit_uni_eltwise_fwd_t<isa>);
+
+ status_t init();
+ };
+
+ jit_uni_eltwise_fwd_t(const pd_t *apd);
+ ~jit_uni_eltwise_fwd_t();
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ jit_uni_eltwise_kernel_f32 *kernel_;
+};
+
+template <cpu_isa_t isa>
+struct jit_uni_eltwise_bwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_eltwise_bwd_pd_t {
+ using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+ jit_uni_eltwise_bwd_t<isa>);
+
+ status_t init();
+ };
+
+ jit_uni_eltwise_bwd_t(const pd_t *apd);
+ ~jit_uni_eltwise_bwd_t();
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ jit_uni_eltwise_kernel_f32 *kernel_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp
new file mode 100644
index 0000000000..a3ca6273a0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.cpp
@@ -0,0 +1,949 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "jit_uni_i8i8_pooling.hpp"
+
+#include <math.h>
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace Xbyak;
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::types;
+using namespace alg_kind;
+
+template <cpu_isa_t isa>
+struct jit_uni_i8i8_pooling_fwd_ker_t: public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_i8i8_pooling_fwd_ker_t)
+
+ struct call_params_t {
+ const char *src_i8;
+ const char *dst_i8;
+ size_t kw_range;
+ size_t kh_range;
+ float idivider;
+ };
+
+ using Vmm = typename cpu_isa_traits<isa>::Vmm;
+ Xmm xreg(int idx) const { return Xmm(idx); }
+ Ymm yreg(int idx) const { return Ymm(xreg(idx).getIdx()); }
+ Vmm vreg(int idx) const { return Vmm(xreg(idx).getIdx()); }
+
+ // In case of avx2 with data type i8 we need to use
+ // maskmovdqu instruction which has its destination hardcoded in rdi.
+    // Windows ABI: abi_param1 is rcx - nothing else to do
+ // Unix ABI: abi_param1 is rdi - copy it to rcx and use it as abi_param1
+ Reg64 reg_param = rcx; // Our "unified abi_param1"
+ Reg64 reg_ptr_src_i8 = r8;
+ Reg64 reg_ptr_dst_i8 = r9;
+ Reg64 reg_ptr_maskmovdqu_dst = rdi; // store destination - must be rdi
+
+ Reg64 ki = r10;
+ Reg64 kj = r11;
+ Reg64 reg_kw = r12;
+ Reg64 reg_kh = r13;
+ Reg64 c_iter = r14;
+
+ Reg64 aux_reg_src_h = rax;
+ Reg64 aux_reg_src_w = rbx;
+
+ Reg64 reg_tmp = rdx;
+
+ Reg64 reg_mask = r15;
+
+ Opmask k_cmp_mask = Opmask(7);
+
+ Opmask mask(int idx) {
+ return Opmask(6 - idx);
+ }
+
+ // ref to any of XYZ-regs via xreg/yreg/vreg functions
+ Xmm xmm_tmp = xreg(0); // temp to init vreg_tmp
+ Vmm vreg_tmp = vreg(0); // max pooling : holds minimum values for data_type
+ Vmm vreg_zeros = vreg(1);
+
+ // only in case of <isa> == avx2
+ Vmm vreg_mask = vreg(2); // full byte-mask
+ Xmm xreg_mask_lo = xreg(2); // low 128-bits part of byte-mask (alias for xmm part of vreg_mask)
+ Xmm xreg_mask_hi = xreg(3); // "max" - high 128-bits part of byte-mask (stored separately)
+ Xmm xreg_mask_q = xreg(3); // "avg" - 1/4 part of the mask for s8/u8 operations
+ Vmm vreg_mask_q = vreg(3); // "avg" - 1/4 part for non-zero tails
+
+ enum:int {vidx_base = isa == avx2 ? 4 : 2};
+ Vmm base_vr(int idx) const { return vreg(vidx_base + idx); }
+
+ size_t sizeof_src_dt() const { return data_type_size(jpp.src_dt); }
+ size_t sizeof_dst_dt() const { return data_type_size(jpp.dst_dt); }
+
+ /* max pooling */
+ Vmm vreg_src(int idx) const { return base_vr(idx); } // [0 .. ur_c-1]
+ Vmm vreg_dst(int idx) const { return base_vr(jpp.ur_c + idx); } // [ur_c .. 2*ur_c-1]
+
+ /* avg pooling */
+ // s32 used for processing of s8/u8 data
+ // thus we need to take into account ratio of sizes s32/i8 = 4
+ static constexpr data_type_t avg_proc_dt = data_type::s32;
+ enum:int {
+ s32_to_i8_ratio = sizeof(typename prec_traits<avg_proc_dt>::type)
+ / sizeof(typename prec_traits<data_type::u8>::type),
+ max_num_ll = s32_to_i8_ratio
+ };
+ Vmm vreg_src_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 0*max_num_ll); } // ll: 0..4 [0..3]
+ Vmm vreg_dst_s32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 1*max_num_ll); } // ll: 0..4 [4..7]
+ Vmm vreg_dst_f32(int jj, int ll) { return base_vr(3*max_num_ll*jj + ll + 2*max_num_ll); } // ll: 0..4 [8..11]
+
+ void (*ker_)(const call_params_t *);
+ jit_pool_conf_t jpp;
+
+ void init_tmp_reg();
+ void init_mask();
+
+ void load_vreg_mask_q(int ll) {};
+
+ void load_src_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+ void load_src_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+ void load_src(int jj, int ll, int c_tail);
+
+ void store_dst_max_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+ void store_dst_avg_op(int jj, int ll, size_t offset, bool masked, uint64_t msk);
+ void store_dst(int jj, int ll, int c_tail);
+
+ void compute_avg_step(int ur_c, int c_tail);
+ void compute_max_op(const int jj);
+ void compute_max_step(int ur_c, int c_tail);
+ void compute_step(int ur_c, int c_tail);
+
+ void compute_c_block();
+ void generate();
+
+ static status_t init_conf(jit_pool_conf_t &jpp, const pooling_pd_t *ppd);
+
+ jit_uni_i8i8_pooling_fwd_ker_t(const jit_pool_conf_t &jpp_)
+ : jpp(jpp_) {
+ generate();
+ ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(
+ getCode()));
+ }
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_vreg_mask_q(int ll) {
+
+ // extract ll-th part of mask (ll-th QWORD)
+ vpblendd(vreg_mask_q, vreg_zeros, vreg_mask, 0x3 << ll); // 0x3 - mask for 2 x DWORD
+
+ // Move mask from ll-th pos to 0-th pos
+ if (ll>0)
+ vpermq(vreg_mask_q, vreg_mask_q, ll);
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_max_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ if (masked) {
+ if (jpp.src_dt == s32) {
+ vpblendd(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], static_cast<uint8_t>(msk));
+ } else {
+ vpblendvb(vreg_src(jj), vreg_tmp, ptr[aux_reg_src_w + offset], vreg_mask);
+ }
+ } else
+ vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_max_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ if (masked) {
+ if (jpp.src_dt == s32)
+ vmovups(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
+ else
+ vmovdqu8(vreg_src(jj) | mask(0), ptr[aux_reg_src_w + offset]);
+ } else
+ vmovups(vreg_src(jj), ptr[aux_reg_src_w + offset]);
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::load_src_avg_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ // Don't generate useless code
+ if (masked && !msk)
+ return;
+
+ auto load_i8 = [&](bool is_signed, const Vmm& vr_src) {
+
+ // Need to use mask of tail?
+ if (masked) {
+
+ // load ll-th part of mask into vreg_mask_q
+ load_vreg_mask_q(ll);
+
+ // Load by mask from mem into register vr_src
+ vpblendvb(vr_src, vreg_zeros, ptr[aux_reg_src_w + offset], vreg_mask_q);
+
+ // Conversion s8/u8 -> s32
+ if (is_signed)
+ vpmovsxbd(vr_src, vr_src);
+ else
+ vpmovzxbd(vr_src, vr_src);
+ } else {
+
+ // Load from mem into vr_src with conversion
+ if (is_signed)
+ vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
+ else
+ vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
+ }
+ };
+
+ switch (jpp.src_dt) {
+ case s32:
+ if (masked)
+ vpblendd(vreg_src_s32(jj, ll), vreg_zeros, ptr[aux_reg_src_w + offset],
+ static_cast<uint8_t>(msk));
+ else
+ vmovups(vreg_src_s32(jj, ll), ptr[aux_reg_src_w + offset]);
+ break;
+ case s8:
+ load_i8(true, vreg_src_s32(jj, ll));
+ break;
+ case u8:
+ load_i8(false, vreg_src_s32(jj, ll));
+ break;
+ default: assert(!"unsupported src data type");
+ }
+};
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::load_src_avg_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ // Don't generate useless code
+ if (masked && !msk)
+ return;
+
+ const Vmm& vr_src = masked ?
+ vreg_src_s32(jj, ll) | mask(ll) :
+ vreg_src_s32(jj, ll);
+
+ switch (jpp.src_dt) {
+ case s32:
+ vmovups(vr_src, ptr[aux_reg_src_w + offset]);
+ break;
+ case s8:
+ vpmovsxbd(vr_src, ptr[aux_reg_src_w + offset]);
+ break;
+ case u8:
+ vpmovzxbd(vr_src, ptr[aux_reg_src_w + offset]);
+ break;
+ default: assert(!"unsupported src data type");
+ }
+};
+
+template <cpu_isa_t isa>
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::load_src(int jj, int ll, int c_tail) {
+ using namespace data_type;
+
+ int c_block = jpp.c_block;
+ int ur_c = jpp.ur_c;
+
+ switch (jpp.alg) {
+ case pooling_max: {
+ auto offset = jj*c_block*sizeof_src_dt();
+ bool masked = jj == ur_c - 1 && c_tail;
+ load_src_max_op(jj, ll, offset, masked, jpp.tail[0]);
+ break;
+ }
+ case pooling_avg_include_padding:
+ case pooling_avg_exclude_padding: {
+ auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_src_dt();
+ bool masked = jj == ur_c - 1 && c_tail;
+ load_src_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
+ break;
+ }
+ default: assert(!"unsupported algorithm");
+ }
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_max_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ int c_block = jpp.c_block;
+
+ if (masked) {
+ switch (jpp.src_dt) {
+ case s32:
+ vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst(jj));
+ break;
+ case s8:
+ case u8: {
+ // Store low half by mask (bytes 0...15)
+ lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
+ maskmovdqu(vreg_dst(jj), xreg_mask_lo);
+
+ // Do we need to store high half (bytes 16...31) ?
+ const uint64_t low_mask = (1ULL << (c_block/2))-1;
+ if (msk & ~low_mask) {
+ vextracti128(Xmm(vreg_dst(jj).getIdx()), vreg_dst(jj), 1);
+ add(reg_ptr_maskmovdqu_dst, c_block / 2);
+ maskmovdqu(vreg_dst(jj), xreg_mask_hi);
+ }
+ } break;
+ default: assert(!"unsupported src data type");
+ }
+ } else
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_max_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ if (masked) {
+ switch (jpp.src_dt) {
+ case s32:
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
+ break;
+ case s8:
+ case u8:
+ vmovdqu8(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj) | mask(0));
+ break;
+ default: assert(!"unsupported src data type");
+ }
+ } else
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst(jj));
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::store_dst_avg_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk){
+ using namespace data_type;
+
+ // Don't generate useless code
+ if (masked && !msk)
+ return;
+
+ auto s32_to_i8 = [&](bool is_signed, const Vmm& vr_dst) {
+
+ // conversion: s32 -> s16/u16 : {8 x s32}{8 x 0} -> {16 x s16/u16}
+ // Result QWORDs (qw0, qw1) permuted: {qw0, 0, qw1, 0}
+ if (is_signed)
+ vpackssdw(vr_dst, vr_dst, vreg_zeros);
+ else
+ vpackusdw(vr_dst, vr_dst, vreg_zeros);
+
+ // Permute qwords to restore original order
+ // {qw0, 0, qw1, 0} -> {qw0, qw1, 0, 0}
+ vpermq(vr_dst, vr_dst, 0x58);
+
+ // conversion: s16/u16 -> s8/u8 : {16 x s16/u16}{16 x 0} -> {32 x s8/u8}
+ // Target QWORD qw = {8 x s8/u8} has proper position: {qw, xx, xx, xx}
+ if (is_signed)
+ vpacksswb(vr_dst, vr_dst, vreg_zeros);
+ else
+ vpackuswb(vr_dst, vr_dst, vreg_zeros);
+
+ };
+
+ auto store_i8 = [&](bool is_signed, bool is_masked, const Vmm& vr_dst) {
+
+ // Conversion s32 -> s8/u8
+ s32_to_i8(is_signed, vr_dst);
+
+ // Need to use mask of tail?
+ if (is_masked) {
+ // load ll-th part of mask into vreg_mask_q
+ load_vreg_mask_q(ll);
+ }
+
+ // store 8 bytes
+ lea(reg_ptr_maskmovdqu_dst, ptr[reg_ptr_dst_i8 + offset]);
+ maskmovdqu(vr_dst, xreg_mask_q);
+ };
+
+ switch (jpp.dst_dt) {
+ case s32:
+ if (masked) {
+ vpmaskmovd(ptr[reg_ptr_dst_i8 + offset], vreg_mask, vreg_dst_s32(jj, ll));
+ } else
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vreg_dst_s32(jj, ll));
+ break;
+ case s8:
+ store_i8(true, masked, vreg_dst_s32(jj, ll));
+ break;
+ case u8:
+ store_i8(false, masked, vreg_dst_s32(jj, ll));
+ break;
+        default: assert(!"unsupported dst data_type");
+ }
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::store_dst_avg_op(int jj, int ll,
+ size_t offset, bool masked, uint64_t msk) {
+ using namespace data_type;
+
+ // Don't generate useless code
+ if (masked && !msk)
+ return;
+
+ const Vmm& vr_dst = masked ?
+ vreg_dst_s32(jj, ll) | mask(ll) :
+ vreg_dst_s32(jj, ll);
+
+ switch (jpp.dst_dt) {
+ case s32:
+ vmovups(ptr[reg_ptr_dst_i8 + offset], vr_dst);
+ break;
+ case s8:
+ vpmovdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
+ break;
+ case u8:
+ vpmovusdb(ptr[reg_ptr_dst_i8 + offset], vr_dst);
+ break;
+ default: assert(!"unsupported dst data_type");
+ }
+}
+
+
+template <cpu_isa_t isa>
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::store_dst(int jj, int ll,
+ int c_tail) {
+ using namespace data_type;
+
+ int c_block = jpp.c_block;
+ int ur_c = jpp.ur_c;
+
+ switch(jpp.alg) {
+ case pooling_max: {
+ auto offset = jj*c_block*sizeof_dst_dt();
+ bool masked = jj == ur_c - 1 && c_tail;
+ store_dst_max_op(jj, ll, offset, masked, jpp.tail[ll]);
+ break;
+ }
+ case pooling_avg_include_padding:
+ case pooling_avg_exclude_padding: {
+ auto offset = (ll*(c_block/max_num_ll) + jj*c_block)*sizeof_dst_dt();
+ bool masked = jj == ur_c - 1 && c_tail;
+ store_dst_avg_op(jj, ll, offset, masked, jpp.tail[ll]);
+ break;
+ }
+ default: assert(!"unsupported pooling algorithm");
+ }
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::compute_max_op(const int jj)
+{
+ using namespace data_type;
+ switch (jpp.src_dt) {
+ case s32:
+ vpmaxsd(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
+ break;
+ case s8:
+ vpmaxsb(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
+ break;
+ case u8:
+ vpmaxub(vreg_dst(jj), vreg_dst(jj), vreg_src(jj));
+ break;
+ default: assert(!"unsupported src data type");
+ }
+}
+
+template <>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::compute_max_op(const int jj)
+{
+ using namespace data_type;
+
+ // Compare
+ switch (jpp.src_dt) {
+ case s32:
+ vpcmpd(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
+ break;
+ case s8:
+ vpcmpb(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
+ break;
+ case u8:
+ vpcmpub(k_cmp_mask, vreg_dst(jj), vreg_src(jj), _cmp_lt_os);
+ break;
+ default: assert(!"unsupported src data type");
+ }
+
+ // move max values into vreg_dst
+ if (jpp.src_dt == s32)
+ vpblendmd(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
+ else
+ vpblendmb(vreg_dst(jj) | k_cmp_mask, vreg_dst(jj), vreg_src(jj));
+}
+
+
+template <cpu_isa_t isa>
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_max_step(int ur_c, int c_tail)
+{
+ Label l_kw, l_kh;
+
+ int iw = jpp.iw;
+ int c = jpp.c;
+
+ for (int jj = 0; jj < ur_c; jj++)
+ vmovups(vreg_dst(jj), vreg_tmp);
+
+ mov(aux_reg_src_h, reg_ptr_src_i8);
+
+ xor_(kj, kj);
+ L(l_kh);
+ {
+ mov(aux_reg_src_w, aux_reg_src_h);
+ xor_(ki, ki);
+ L(l_kw);
+ {
+ for (int jj = 0; jj < ur_c; jj++) {
+ load_src(jj, 0, c_tail);
+ compute_max_op(jj);
+ }
+ add(aux_reg_src_w, c * sizeof_src_dt());
+ inc(ki);
+ cmp(ki, reg_kw);
+ jl(l_kw, T_NEAR);
+ }
+ add(aux_reg_src_h, iw * c * sizeof_src_dt());
+ inc(kj);
+ cmp(kj, reg_kh);
+ jl(l_kh, T_NEAR);
+ }
+
+ for (int jj = 0; jj < ur_c; jj++)
+ store_dst(jj, 0, c_tail);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_avg_step(int ur_c, int c_tail)
+{
+ using namespace data_type;
+
+ Label l_kw, l_kh;
+
+ int iw = jpp.iw;
+ int c = jpp.c;
+
+ const int num_ll = data_type_size(avg_proc_dt)/data_type_size(jpp.src_dt);
+
+ for (int jj = 0; jj < ur_c; jj++) {
+ for (int ll = 0; ll < num_ll; ll++) {
+ bool masked = jj == ur_c - 1 && c_tail;
+ size_t msk = jpp.tail[ll];
+ if (!(masked && !msk)) {
+ uni_vpxor(vreg_src_s32(jj, ll), vreg_src_s32(jj, ll), vreg_src_s32(jj, ll));
+ uni_vpxor(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll));
+ }
+ }
+ }
+
+ mov(aux_reg_src_h, reg_ptr_src_i8);
+
+ xor_(kj, kj);
+ L(l_kh);
+ {
+ mov(aux_reg_src_w, aux_reg_src_h);
+ xor_(ki, ki);
+ L(l_kw);
+ {
+ for (int jj = 0; jj < ur_c; jj++) {
+ for (int ll = 0; ll < num_ll; ll++) {
+ bool masked = jj == ur_c - 1 && c_tail;
+ size_t msk = jpp.tail[ll];
+ if (!(masked && !msk)) {
+ load_src(jj, ll, c_tail);
+ vpaddd(vreg_dst_s32(jj, ll), vreg_dst_s32(jj, ll),
+ vreg_src_s32(jj, ll));
+ }
+ }
+ }
+ add(aux_reg_src_w, c * sizeof_src_dt());
+ inc(ki);
+ cmp(ki, reg_kw);
+ jl(l_kw, T_NEAR);
+ }
+ add(aux_reg_src_h, iw * c * sizeof_src_dt());
+ inc(kj);
+ cmp(kj, reg_kh);
+ jl(l_kh, T_NEAR);
+ }
+
+ for (int jj = 0; jj < ur_c; jj++) {
+ for (int ll = 0; ll < num_ll; ll++) {
+ bool masked = jj == ur_c - 1 && c_tail;
+ size_t msk = jpp.tail[ll];
+ if (!(masked && !msk)) {
+ vcvtdq2ps(vreg_dst_f32(jj, ll), vreg_dst_s32(jj, ll));
+ vfmadd132ps(vreg_dst_f32(jj, ll), vreg_zeros, vreg_tmp);
+ vcvtps2dq(vreg_dst_s32(jj, ll), vreg_dst_f32(jj, ll));
+ store_dst(jj, ll, c_tail);
+ }
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_step(int ur_c, int c_tail) {
+ switch (jpp.alg) {
+ case pooling_max:
+ compute_max_step(ur_c, c_tail); break;
+ case pooling_avg_include_padding:
+ case pooling_avg_exclude_padding:
+ compute_avg_step(ur_c, c_tail); break;
+ default: assert(!"unsupported pooling algorithm");
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::compute_c_block(){
+ Label l_main_loop;
+
+ int nb_c = jpp.nb_c;
+ int c_block = jpp.c_block;
+ int ur_c = jpp.ur_c;
+ int ur_c_tail = jpp.ur_c_tail;
+ int c_steps = nb_c / ur_c;
+ int c_tail = jpp.c_tail;
+
+ xor_(c_iter, c_iter);
+ if (c_steps > 0) {
+ L(l_main_loop); {
+ compute_step(ur_c, 0);
+ add(reg_ptr_src_i8, ur_c*c_block*sizeof_src_dt());
+ add(reg_ptr_dst_i8, ur_c*c_block*sizeof_dst_dt());
+ inc(c_iter);
+ cmp(c_iter, c_steps);
+ jl(l_main_loop, T_NEAR);
+ }
+ }
+
+ if (ur_c_tail != 0) {
+ compute_step(ur_c_tail, c_tail);
+ }
+}
+
+template<>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx2>::init_mask() {
+ using namespace data_type;
+ using cpu_isa = cpu_isa_traits<avx2>;
+
+ // AVX2 mask initialization: mask stored in Ymm-regs
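+ // Each lane to be kept gets its most significant bit set (VMSK below) -
+ // the form consumed by AVX2 masked moves such as vmaskmovps/vpblendvb,
+ // which the load/store helpers presumably use.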
+ auto init = [&](uint64_t bit_mask, bool init_mask_q) {
+ const size_t QW_PER_VREG = cpu_isa::vlen / sizeof(uint64_t);
+
+ uint64_t vmask[QW_PER_VREG];
+ for (size_t i = 0; i < QW_PER_VREG; i++){
+
+ uint64_t qw_vmask=0ULL;
+ const size_t DBITS = 8*sizeof_src_dt();
+ const uint64_t VMSK = 1ULL << (DBITS-1);
+ const size_t D_PER_QW = (8*sizeof(qw_vmask))/DBITS;
+ for (size_t j = 0; j < D_PER_QW; j++) {
+ if (bit_mask & 1)
+ qw_vmask |= VMSK << DBITS * j;
+ bit_mask >>= 1;
+ }
+ vmask[i] = qw_vmask;
+ }
+
+ // Put QWORDS with target mask into xmm regs
+ const int xdst_i[QW_PER_VREG] = {
+ xreg_mask_lo.getIdx(),
+ xreg_mask_lo.getIdx(),
+ xreg_mask_hi.getIdx(),
+ xreg_mask_hi.getIdx()
+ };
+ const int xsrc_i[QW_PER_VREG] = {
+ vreg_zeros.getIdx(), // insert 0-th qword into zeros -> {qw0, 0}
+ xreg_mask_lo.getIdx(), // merge 1-st qword with the 0-th -> {qw0, qw1}
+ vreg_zeros.getIdx(),
+ xreg_mask_hi.getIdx()
+ };
+ const uint8 qw_dst_idx[QW_PER_VREG] = {0, 1, 0, 1}; // qword index in 128-bit xreg
+
+ for (size_t i = 0; i < QW_PER_VREG; i++) {
+ mov(reg_mask, vmask[i]);
+ vpinsrq(Xmm(xdst_i[i]), Xmm(xsrc_i[i]), reg_mask, qw_dst_idx[i]);
+ }
+
+ // Merge Low (xreg_mask_lo alias for vreg_mask.xreg)
+ // and High (xreg_mask_hi) into full vreg_mask
+ // vreg_mask -> {xreg_mask_hi, vreg_mask.xreg}
+ vinserti128(vreg_mask, vreg_mask, xreg_mask_hi, 1);
+
+ // Keep only low qword of mask in xreg_mask_q
+ if (init_mask_q) {
+ mov(reg_mask, vmask[0]);
+ vpinsrq(xreg_mask_q, Xmm(vreg_zeros.getIdx()), reg_mask, 0);
+ }
+ };
+
+ uint64_t tail_mask = (1ULL << jpp.c_tail) - 1;
+ switch (jpp.alg) {
+ case pooling_max:
+ // For "max" we need mask only in case of non-zero tail
+ if (tail_mask)
+ init(tail_mask, false);
+ break;
+ case pooling_avg_include_padding:
+ case pooling_avg_exclude_padding:
+ // For "avg" we need mask:
+ // - s32 - in case of the non-zero tail
+ // - s8/u8 - irrespective of the tail
+ switch (jpp.src_dt) {
+ case s32:
+ if (tail_mask)
+ init(tail_mask, false);
+ break;
+ case s8:
+ case u8:
+ init(tail_mask ? tail_mask : ~0ULL, tail_mask == 0);
+ break;
+ default: assert(!"unsupported src data type");
+ }
+ break;
+ default: assert(!"unsupported pooling algorithm");
+ }
+}
+
+template<>
+void jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>::init_mask() {
+
+ for (int ll = 0; ll < max_num_ll; ll++) {
+ mov(reg_mask, jpp.tail[ll]);
+ kmovq(mask(ll), reg_mask);
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_tmp_reg() {
+ using namespace data_type;
+
+ switch (jpp.alg) {
+ case pooling_avg_include_padding:
+ case pooling_avg_exclude_padding:
+ mov(reg_tmp, ptr[reg_param + offsetof(call_params_t, idivider)]);
+ movq(xmm_tmp, reg_tmp);
+ vpbroadcastd(vreg_tmp, xmm_tmp);
+ break;
+ case pooling_max:
+ switch (jpp.src_dt) {
+ case s32:
+ mov(reg_tmp, nstl::numeric_limits<int32_t>::lowest());
+ break;
+ case s8:
+ mov(reg_tmp, nstl::numeric_limits<int8_t>::lowest());
+ break;
+ case u8:
+ mov(reg_tmp, nstl::numeric_limits<uint8_t>::lowest());
+ break;
+ default: assert(!"unsupported src data_type");
+ }
+
+ movq(xmm_tmp, reg_tmp);
+ if (jpp.src_dt == s32)
+ vpbroadcastd(vreg_tmp, xmm_tmp);
+ else
+ vpbroadcastb(vreg_tmp, xmm_tmp);
+ break;
+ default: assert(!"unsupported pooling algorithm");
+ }
+
+}
+
+template <cpu_isa_t isa>
+void jit_uni_i8i8_pooling_fwd_ker_t<isa>::generate() {
+ preamble();
+
+#if !defined(_WIN32)
+ // Always use rcx as abi_param1 -
+ // see the note about maskmovdqu near reg_param.
+ mov(rcx, rdi);
+#endif
+
+# define READ_PARAM(reg, field) \
+ mov(reg, ptr[reg_param + offsetof(call_params_t, field)])
+ READ_PARAM(reg_ptr_src_i8, src_i8);
+ READ_PARAM(reg_ptr_dst_i8, dst_i8);
+ READ_PARAM(reg_kw, kw_range);
+ READ_PARAM(reg_kh, kh_range);
+
+# undef READ_PARAM
+
+ uni_vpxor(vreg_zeros, vreg_zeros, vreg_zeros);
+
+ init_mask();
+
+ init_tmp_reg();
+
+ compute_c_block();
+
+ postamble();
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jit_pool_conf_t &jpp,
+ const pooling_pd_t *ppd) {
+ if (!mayiuse(isa))
+ return status::unimplemented;
+
+ const auto &pd = *ppd->desc();
+ const memory_desc_wrapper src_d(ppd->src_md());
+ const memory_desc_wrapper dst_d(ppd->dst_md());
+
+ jpp.mb = src_d.dims()[0];
+ jpp.c = src_d.dims()[1];
+ jpp.ih = src_d.dims()[2];
+ jpp.iw = src_d.dims()[3];
+ jpp.oh = dst_d.dims()[2];
+ jpp.ow = dst_d.dims()[3];
+
+ jpp.stride_h = pd.strides[0];
+ jpp.stride_w = pd.strides[1];
+ jpp.kh = pd.kernel[0];
+ jpp.kw = pd.kernel[1];
+
+ jpp.t_pad = pd.padding[0][0];
+ jpp.l_pad = pd.padding[0][1];
+
+ jpp.alg = pd.alg_kind;
+
+ jpp.src_dt = pd.src_desc.data_type;
+ jpp.dst_dt = pd.dst_desc.data_type;
+
+ // Number of data_type items per vreg for the given <isa>:
+ // isa == avx2 : 32 bytes -> 32 for s8/u8, 8 for s32
+ // isa == avx512* : 64 bytes -> 64 for s8/u8, 16 for s32
+ int simd_w = cpu_isa_traits<isa>::vlen / data_type_size(jpp.src_dt);
+
+ jpp.c_block = simd_w;
+ jpp.c_tail = jpp.c % jpp.c_block;
+ jpp.nb_c = jpp.c / jpp.c_block;
+ jpp.ur_c = 1;
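+ // ur_c_tail: leftover unrolled-channel steps plus one extra step for a
+ // partial c_block; with ur_c == 1 this reduces to (c_tail != 0).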
+ jpp.ur_c_tail = jpp.nb_c - (jpp.nb_c / jpp.ur_c)*jpp.ur_c +
+ (jpp.c_tail != 0);
+
+ size_t tail_mask = (1ULL << jpp.c_tail) - 1;
+
+ switch (jpp.alg) {
+ case pooling_max:
+ jpp.tail[0] = tail_mask;
+ jpp.tail[1] = 0;
+ jpp.tail[2] = 0;
+ jpp.tail[3] = 0;
+ break;
+ case pooling_avg_include_padding:
+ case pooling_avg_exclude_padding: {
+ // avg_proc_dt (s32) defines the mask granularity (u8/s8 are processed as s32)
+ // avx2 : 8, avx512 : 16
+ const size_t msk_gran = cpu_isa_traits<isa>::vlen / data_type_size(avg_proc_dt);
+ const size_t msk_msk = (1ULL << msk_gran) - 1;
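+ // Split the channel-tail mask into max_num_ll chunks of msk_gran bits,
+ // one per 'll' sub-vector.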
+ size_t m = tail_mask;
+ for (size_t ll = 0; ll < max_num_ll; ll++) {
+ jpp.tail[ll] = m & msk_msk;
+ m = m >> msk_gran;
+ }
+ break;
+ }
+ default: return status::unimplemented;
+ }
+
+ return status::success;
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_i8i8_pooling_fwd_t<isa>::pd_t::jit_conf() {
+ return jit_uni_i8i8_pooling_fwd_ker_t<isa>::init_conf(jpp_, this);
+}
+
+template <cpu_isa_t isa>
+jit_uni_i8i8_pooling_fwd_t<isa>::
+jit_uni_i8i8_pooling_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd), ker_(nullptr)
+{ ker_ = new jit_uni_i8i8_pooling_fwd_ker_t<isa>(pd()->jpp_); }
+
+template <cpu_isa_t isa>
+jit_uni_i8i8_pooling_fwd_t<isa>::
+~jit_uni_i8i8_pooling_fwd_t() { delete ker_; }
+
+template <cpu_isa_t isa>
+void jit_uni_i8i8_pooling_fwd_t<isa>::execute_forward(
+ const exec_ctx_t &ctx) const {
+ auto src_i8 = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC);
+ auto dst_i8 = CTX_OUT_MEM(char *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+
+ const auto &jpp = pd()->jpp_;
+
+ parallel_nd(jpp.mb, jpp.oh, jpp.ow,
+ [&](int n, int oh, int ow) {
+ const int ih = nstl::max(oh*jpp.stride_h - jpp.t_pad, 0);
+ const int iw = nstl::max(ow*jpp.stride_w - jpp.l_pad, 0);
+
+ const int kh_start = nstl::max(0, jpp.t_pad - oh * jpp.stride_h);
+ const int kh_end = nstl::min(jpp.kh,
+ jpp.ih + jpp.t_pad - oh * jpp.stride_h);
+ const int kw_start = nstl::max(0, jpp.l_pad - ow * jpp.stride_w);
+ const int kw_end = nstl::min(jpp.kw,
+ jpp.iw + jpp.l_pad - ow * jpp.stride_w);
+
+ auto p = typename jit_uni_i8i8_pooling_fwd_ker_t<isa>::call_params_t();
+ p.src_i8 = &src_i8[
+ src_d.blk_off(n, 0, ih, iw) * src_d.data_type_size()];
+ p.dst_i8 = &dst_i8[
+ dst_d.blk_off(n, 0, oh, ow) * dst_d.data_type_size()];
+ p.kw_range = (size_t)(kw_end - kw_start);
+ p.kh_range = (size_t)(kh_end - kh_start);
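+ // Averaging factor: exclude_padding divides by the number of in-bounds
+ // kernel elements, include_padding by the full kernel size.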
+ p.idivider = 1.0f / ((jpp.alg == pooling_avg_exclude_padding) ?
+ p.kh_range*p.kw_range : jpp.kw*jpp.kh);
+
+ ker_->ker_(&p);
+ });
+}
+
+// Explicit instantiation only for supported <isa> values.
+//
+template struct jit_uni_i8i8_pooling_fwd_ker_t<avx512_core>;
+template struct jit_uni_i8i8_pooling_fwd_t<avx512_core>;
+
+template struct jit_uni_i8i8_pooling_fwd_ker_t<avx2>;
+template struct jit_uni_i8i8_pooling_fwd_t<avx2>;
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp
new file mode 100644
index 0000000000..d757679df5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_i8i8_pooling.hpp
@@ -0,0 +1,89 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_UNI_I8I8_POOLING_HPP
+#define CPU_JIT_UNI_I8I8_POOLING_HPP
+
+#include "c_types_map.hpp"
+
+#include "cpu_pooling_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "cpu_isa_traits.hpp"
+#include "jit_primitive_conf.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <cpu_isa_t isa>
+struct jit_uni_i8i8_pooling_fwd_ker_t;
+
+template <cpu_isa_t isa>
+struct jit_uni_i8i8_pooling_fwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_pooling_fwd_pd_t {
+ using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+ jit_uni_i8i8_pooling_fwd_t<isa>);
+
+ status_t init() {
+ bool ok = true
+ && mayiuse(isa)
+ && ndims() == 4
+ && set_default_params() == status::success
+ && desc()->prop_kind == prop_kind::forward_inference
+ && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
+ alg_kind::pooling_avg_include_padding,
+ alg_kind::pooling_avg_exclude_padding)
+ && utils::one_of(src_md()->data_type, data_type::s32,
+ data_type::s8, data_type::u8)
+ && src_md()->data_type == dst_md()->data_type
+ && attr()->has_default_values()
+ && memory_desc_matches_tag(*src_md(), format_tag::nhwc)
+ && memory_desc_matches_tag(*dst_md(), format_tag::nhwc);
+ if (!ok) return status::unimplemented;
+
+ return jit_conf();
+ }
+
+ jit_pool_conf_t jpp_;
+
+ protected:
+ status_t jit_conf();
+ };
+
+ jit_uni_i8i8_pooling_fwd_t(const pd_t *apd);
+ ~jit_uni_i8i8_pooling_fwd_t();
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_uni_i8i8_pooling_fwd_ker_t<isa> *ker_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp
new file mode 100644
index 0000000000..2c5a8e8973
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.cpp
@@ -0,0 +1,305 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_uni_lrn_kernel_f32.hpp"
+#include "jit_uni_lrn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+
+template <cpu_isa_t isa>
+jit_uni_lrn_fwd_t<isa>::jit_uni_lrn_fwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd), ker_(nullptr)
+ , ker_first_(nullptr), ker_last_(nullptr)
+{
+ using namespace alg_kind;
+
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const int ls = pd()->desc()->local_size;
+ float A = pd()->desc()->lrn_alpha / ls;
+ float K = pd()->desc()->lrn_k;
+
+ auto pk = pd()->desc()->prop_kind;
+ auto ak = pd()->desc()->alg_kind;
+ auto dat_tag = pd()->dat_tag_;
+
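+ // For the nChw8c across-channels case an interior kernel plus boundary
+ // variants (version -1/+1) are built for the first and last 8-channel
+ // block, where the neighboring block is absent.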
+ if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) {
+ ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw8c_across(H, W, 0), A, K, pk);
+ ker_first_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw8c_across(H, W, -1), A, K, pk);
+ ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw8c_across(H, W, +1), A, K, pk);
+ } else if (dat_tag == nChw8c && ak == lrn_within_channel) {
+ /* within channel: local_size x local_size window */
+ A /= ls; /* XXX: why? */
+ ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw8c_within(H, W, ls), A, K, pk);
+ } else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) {
+ ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw_across(C, H*W, 0), A, K, pk);
+ int remind = (H*W) % VECTOR_LENGTH;
+ if (remind != 0) {
+ ker_last_ = new jit_uni_lrn_fwd_kernel_f32<isa>(
+ nchw_across(C, H*W, remind), A, K, pk);
+ }
+ } else if (true /* XXX: why */) {
+ ker_ = new jit_uni_lrn_fwd_kernel_f32<isa>(nhwc_across(C), A, K, pk);
+ }
+}
+
+template <cpu_isa_t isa>
+jit_uni_lrn_fwd_t<isa>::~jit_uni_lrn_fwd_t()
+{ delete ker_; delete ker_first_; delete ker_last_; }
+
+template <cpu_isa_t isa>
+void jit_uni_lrn_fwd_t<isa>::execute_forward(const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(data_t *, MKLDNN_ARG_WORKSPACE);
+
+ const int N = pd()->MB();
+ const int C = pd()->C();
+ const int HW = pd()->H() * pd()->W();
+ const int ls = pd()->desc()->local_size;
+
+ auto ak = pd()->desc()->alg_kind;
+ auto dat_tag = pd()->dat_tag_;
+
+ if (dat_tag == nChw8c && ls == 5 && ak == lrn_across_channels) {
+ parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
+ jit_args_fwd_t args;
+ args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ if (c8 == 0)
+ (*ker_first_)(&args);
+ else if (c8 == C / VECTOR_LENGTH - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ });
+ }
+ else if (dat_tag == nChw8c && ak == lrn_within_channel) {
+ parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
+ jit_args_fwd_t args;
+ args.src = &src[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ args.dst = &dst[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ args.scratch = &ws[n*HW*C + c8 * HW * VECTOR_LENGTH];
+ (*ker_)(&args);
+ });
+ }
+ else if (dat_tag == nchw && ls == 5 && ak == lrn_across_channels) {
+ parallel_nd(N, (HW + VECTOR_LENGTH - 1) / VECTOR_LENGTH,
+ [&](int n, int hw8) {
+ jit_args_fwd_t args;
+ args.src = &src[n*HW*C + hw8 * VECTOR_LENGTH];
+ args.dst = &dst[n*HW*C + hw8 * VECTOR_LENGTH];
+ args.scratch = &ws[n*HW*C + hw8 * VECTOR_LENGTH];
+ if ((hw8 + 1)*VECTOR_LENGTH > HW)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ });
+ }
+ else { // nhwc
+ parallel_nd(N, HW, [&](int n, int hw) {
+ jit_args_fwd_t args;
+ args.src = &src[n*HW*C + hw * C];
+ args.dst = &dst[n*HW*C + hw * C];
+ args.scratch = &ws[n*HW*C + hw * C];
+ (*ker_)(&args);
+ });
+ }
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_lrn_fwd_t<isa>::pd_t::init() {
+ using namespace prop_kind;
+ using namespace alg_kind;
+
+ const memory_desc_wrapper data_d(src_md());
+ bool ok = true
+ && mayiuse(isa)
+ && is_fwd()
+ && everyone_is(data_type::f32, data_d.data_type())
+ && !has_zero_dim_memory()
+ && data_d.ndims() == 4
+ && data_d.dims()[1] % VECTOR_LENGTH == 0
+ && data_d.dims()[1] >= 2 * VECTOR_LENGTH
+ && desc()->lrn_beta == 0.75
+ && attr()->has_default_values();
+ if (!ok) return unimplemented;
+
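+ // Training mode keeps a workspace with the intermediate (sum*alpha + k)
+ // values for the backward pass.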
+ if (desc_.prop_kind == forward_training) ws_md_ = *src_md();
+
+ dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c, nchw, nhwc);
+
+ bool args_ok_across = true
+ && desc()->alg_kind == lrn_across_channels
+ && desc()->local_size == 5
+ && one_of(dat_tag_, nChw8c, nchw, nhwc);
+
+ const int jit_max_local_size = 5; // larger sizes generate too much code
+ bool args_ok_within = true
+ && desc()->alg_kind == lrn_within_channel
+ && desc()->local_size <= ( jit_max_local_size <= MAX_LOCAL_SIZE
+ ? jit_max_local_size : MAX_LOCAL_SIZE)
+ && data_d.dims()[2] >= desc()->local_size
+ && data_d.dims()[3] >= desc()->local_size
+ && one_of(dat_tag_, nChw8c);
+
+ return args_ok_across || args_ok_within ? success : unimplemented;
+}
+
+template <cpu_isa_t isa>
+jit_uni_lrn_bwd_t<isa>::jit_uni_lrn_bwd_t(const pd_t *apd)
+ : cpu_primitive_t(apd)
+ , ker_(nullptr), ker_first_(nullptr), ker_last_(nullptr)
+{
+ using namespace alg_kind;
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const int ls = pd()->desc()->local_size;
+ float A = pd()->desc()->lrn_alpha / ls;
+ float B = pd()->desc()->lrn_beta;
+
+ int use_h_parallelizm = 0; // XXX
+ if (C / VECTOR_LENGTH == 1) {
+ ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
+ nchw8c_across(H, W, 3), A, B, use_h_parallelizm);
+ }
+ else {
+ ker_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
+ nchw8c_across(H, W, 0), A, B, use_h_parallelizm);
+ ker_first_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
+ nchw8c_across(H, W, -1), A, B, use_h_parallelizm);
+ ker_last_ = new jit_uni_lrn_bwd_kernel_f32<isa>(
+ nchw8c_across(H, W, +1), A, B, use_h_parallelizm);
+ }
+}
+
+template <cpu_isa_t isa>
+jit_uni_lrn_bwd_t<isa>::~jit_uni_lrn_bwd_t()
+{
+ delete ker_; delete ker_first_; delete ker_last_;
+}
+
+template <cpu_isa_t isa>
+void jit_uni_lrn_bwd_t<isa>::execute_backward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto ws = CTX_IN_MEM(const data_t *, MKLDNN_ARG_WORKSPACE);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const int N = pd()->MB();
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+
+ int use_h_parallelizm = 0; // XXX
+ if (use_h_parallelizm) {
+ parallel_nd(N, C / VECTOR_LENGTH, H, [&](int n, int c8, int h) {
+ auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH
+ + h*W*VECTOR_LENGTH;
+ jit_args_bwd_t args;
+ args.src = &src[offset];
+ args.diff_dst = &diff_dst[offset];
+ args.scratch = &ws[offset];
+ args.diff_src = &diff_src[offset];
+ if (C / VECTOR_LENGTH == 1)
+ (*ker_)(&args);
+ else if (c8 == 0)
+ (*ker_first_)(&args);
+ else if (c8 == C / VECTOR_LENGTH - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ });
+ }
+ else {
+ parallel_nd(N, C / VECTOR_LENGTH, [&](int n, int c8) {
+ auto offset = n*C*H*W + c8*H*W*VECTOR_LENGTH;
+ jit_args_bwd_t args;
+ args.src = &src[offset];
+ args.diff_dst = &diff_dst[offset];
+ args.scratch = &ws[offset];
+ args.diff_src = &diff_src[offset];
+ if (C / VECTOR_LENGTH == 1)
+ (*ker_)(&args);
+ else if (c8 == 0)
+ (*ker_first_)(&args);
+ else if (c8 == C / VECTOR_LENGTH - 1)
+ (*ker_last_)(&args);
+ else
+ (*ker_)(&args);
+ });
+ }
+}
+
+template <cpu_isa_t isa>
+status_t jit_uni_lrn_bwd_t<isa>::pd_t::init() {
+ using namespace prop_kind;
+ using namespace alg_kind;
+
+ const memory_desc_wrapper data_d(src_md());
+ bool ok = true
+ && mayiuse(isa)
+ && !is_fwd()
+ && utils::everyone_is(data_type::f32, data_d.data_type())
+ && !has_zero_dim_memory()
+ && data_d.ndims() == 4
+ && data_d.dims()[1] % VECTOR_LENGTH == 0
+ && desc()->lrn_beta == 0.75
+ && attr()->has_default_values();
+ if (!ok) return unimplemented;
+
+ ws_md_ = *src_md();
+ if (!compare_ws(hint_fwd_pd_)) return unimplemented;
+
+ dat_tag_ = memory_desc_matches_one_of_tag(*src_md(), nChw8c);
+
+ bool args_ok_across = true
+ && desc()->alg_kind == lrn_across_channels
+ && desc()->local_size == 5
+ && utils::one_of(dat_tag_, nChw8c);
+
+ return args_ok_across ? success : unimplemented;
+}
+
+template struct jit_uni_lrn_fwd_t<sse42>;
+template struct jit_uni_lrn_fwd_t<avx2>;
+template struct jit_uni_lrn_bwd_t<avx2>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp
new file mode 100644
index 0000000000..333cd3396d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn.hpp
@@ -0,0 +1,103 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_UNI_LRN_HPP
+#define CPU_JIT_UNI_LRN_HPP
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_isa_traits.hpp"
+#include "cpu_lrn_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <cpu_isa_t isa> struct jit_uni_lrn_fwd_kernel_f32;
+template <cpu_isa_t isa> struct jit_uni_lrn_bwd_kernel_f32;
+
+template <cpu_isa_t isa>
+struct jit_uni_lrn_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_lrn_fwd_pd_t {
+ using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+ jit_uni_lrn_fwd_t<isa>);
+
+ status_t init();
+
+ format_tag_t dat_tag_;
+ };
+
+ jit_uni_lrn_fwd_t(const pd_t *apd);
+ ~jit_uni_lrn_fwd_t();
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_uni_lrn_fwd_kernel_f32<isa> *ker_, *ker_first_, *ker_last_;
+};
+
+template <cpu_isa_t isa>
+struct jit_uni_lrn_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_lrn_bwd_pd_t {
+ using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+ jit_uni_lrn_bwd_t<isa>);
+
+ status_t init();
+
+ format_tag_t dat_tag_;
+ };
+
+ jit_uni_lrn_bwd_t(const pd_t *apd);
+ ~jit_uni_lrn_bwd_t();
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ jit_uni_lrn_bwd_kernel_f32<isa> *ker_, *ker_first_, *ker_last_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp
new file mode 100644
index 0000000000..89af47272c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.cpp
@@ -0,0 +1,1487 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "jit_uni_lrn_kernel_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace Xbyak;
+
+//////////////////////////////////////////////////////////////////////////////
+// forward kernel
+template<cpu_isa_t isa>
+void jit_uni_lrn_fwd_kernel_f32<isa>::within_body(
+ int hoff, int Hoff, int woff, int Woff, int stride,
+ Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2,
+ prop_kind_t pk)
+{
+ vxorps(ysum, ysum, ysum);
+ for (int i = hoff; i <= Hoff; ++i)
+ {
+ for (int j = woff; j <= Woff; ++j)
+ {
+ if (i == 0 && j == 0)
+ {
+ vmovups(ydst, ptr[src]);
+ vfmadd231ps(ysum, ydst, ydst);
+ }
+ else
+ {
+ vmovups(ytmp, ptr[src + (i*stride + j)*VECTOR_LENGTH*4]);
+ vfmadd231ps(ysum, ytmp, ytmp);
+ }
+ }
+ }
+ vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk
+ vmovaps(ytmp, ysum);
+ if (pk != prop_kind::forward_inference)
+ vmovups(ptr[scratch], ytmp);
+ vmulps(ysum2, ysum, ysum);
+ vmulps(ysum, ysum, ysum2); // ysum = (ysum*yalpha+yk)^3;
+ vsqrtps(ysum, ysum);
+ vsqrtps(ysum, ysum); // ysum = (ysum*yalpha+yk)^0.75
+ vdivps(ydst, ydst, ysum); // ydst <- ydst / ysum
+ vmovups(ptr[dst], ydst);
+ add(src, 32);
+ add(dst, 32);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, 32);
+}
+
+template<cpu_isa_t isa>
+void jit_uni_lrn_fwd_kernel_f32<isa>::within_body_sse42(
+ int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk)
+{
+ Xbyak::Xmm xtmp_lo = xmm12;
+ Xbyak::Xmm xtmp_hi = xmm13;
+ Xbyak::Xmm xsum_lo = xmm8;
+ Xbyak::Xmm xsum_hi = xmm9;
+ Xbyak::Xmm xdst_lo = xmm10;
+ Xbyak::Xmm xdst_hi = xmm11;
+ Xbyak::Xmm xsum2_lo = xmm14;
+ Xbyak::Xmm xsum2_hi = xmm15;
+
+ xorps(xsum_lo, xsum_lo);
+ xorps(xsum_hi, xsum_hi);
+ for (int i = hoff; i <= Hoff; ++i)
+ {
+ for (int j = woff; j <= Woff; ++j)
+ {
+ if (i == 0 && j == 0)
+ {
+ movups(xdst_lo, ptr[src]);
+ movups(xdst_hi, ptr[src + 4 * sizeof(float)]);
+ mulps(xdst_lo, xdst_lo);
+ mulps(xdst_hi, xdst_hi);
+ addps(xsum_lo, xdst_lo);
+ addps(xsum_hi, xdst_hi);
+ }
+ else
+ {
+ movups(xtmp_lo, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4]);
+ movups(xtmp_hi, ptr[src + (i*stride + j)*VECTOR_LENGTH * 4 + 4 * sizeof(float)]);
+ mulps(xtmp_lo, xtmp_lo);
+ mulps(xtmp_hi, xtmp_hi);
+ addps(xsum_lo, xtmp_lo);
+ addps(xsum_hi, xtmp_hi);
+ }
+ }
+ }
+ mulps(xsum_lo, xalpha);
+ mulps(xsum_hi, xalpha);
+ addps(xsum_lo, xk);
+ addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk
+ movaps(xtmp_lo, xsum_lo);
+ movaps(xtmp_hi, xsum_hi);
+ if (pk != prop_kind::forward_inference) {
+ movups(ptr[scratch], xtmp_lo);
+ movups(ptr[scratch + 4 * sizeof(float)], xtmp_hi);
+ }
+ movaps(xsum2_lo, xsum_lo);
+ movaps(xsum2_hi, xsum_hi);
+ mulps(xsum2_lo, xsum_lo);
+ mulps(xsum2_hi, xsum_hi);
+ mulps(xsum_lo, xsum2_lo);
+ mulps(xsum_hi, xsum2_hi); // xsum = (xsum*xalpha+xk)^3;
+
+ sqrtps(xsum_lo, xsum_lo);
+ sqrtps(xsum_hi, xsum_hi);
+ sqrtps(xsum_lo, xsum_lo);
+ sqrtps(xsum_hi, xsum_hi); // xsum = (xsum*xalpha+xk)^0.75
+
+ movups(xdst_lo, ptr[src]);
+ movups(xdst_hi, ptr[src + 4 * sizeof(float)]);
+ divps(xdst_lo, xsum_lo);
+ divps(xdst_hi, xsum_hi); // xdst <- xdst / xsum
+
+ movups(ptr[dst], xdst_lo);
+ movups(ptr[dst + 4 * sizeof(float)], xdst_hi);
+ add(src, 32);
+ add(dst, 32);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, 32);
+}
+
+template <cpu_isa_t isa>
+jit_uni_lrn_fwd_kernel_f32<isa>::jit_uni_lrn_fwd_kernel_f32(
+ const struct nchw8c_within &J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void *code_ptr,
+ size_t code_size)
+ : jit_generator(code_ptr, code_size)
+ , alpha(A), k(K)
+{
+ Xbyak::Reg64 h = r9;
+ Xbyak::Reg64 w = r10;
+ Vmm ysum = Vmm(isa == avx2 ? 9 : 9);
+ Vmm ysum2 = Vmm(isa == avx2 ? 10 : 10);
+ Vmm ydst = Vmm(isa == avx2 ? 11 : 11);
+ Vmm ytmp = Vmm(isa == avx2 ? 12 : 12);
+
+ this->preamble();
+
+ mov(src, ptr[this->param1 + 0]);
+ mov(dst, ptr[this->param1 + 8]);
+ if (pk != prop_kind::forward_inference)
+ mov(scratch, ptr[this->param1 + 16]);
+
+ mov(imm_addr64, float2int(this->alpha));
+ movq(xalpha, imm_addr64);
+ if (isa == avx2) {
+ vbroadcastss(yalpha, xalpha);
+ } else {
+ shufps(xalpha, xalpha, 0);
+ }
+
+ mov(imm_addr64, float2int(this->k));
+ movq(xk, imm_addr64);
+ if (isa == avx2) {
+ vbroadcastss(yk, xk);
+ } else {
+ shufps(xk, xk, 0);
+ }
+
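+ // s2/S2: extent of the local window before/after the center element
+ // (equal for odd local_size).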
+ int s2 = (J.size - 1) / 2, S2 = J.size - s2 - 1;
+
+ for (int i = 0; i < s2; ++i)
+ {
+ Label label_t;
+ for (int j = 0; j < s2; ++j) {
+ if (isa == avx2) {
+ within_body(-i, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
+ }
+ else {
+ within_body_sse42(-i, S2, -j, S2, J.W, pk);
+ }
+ }
+ mov(w, J.W - J.size + 1);
+ L(label_t);
+ if (isa == avx2) {
+ within_body(-i, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
+ } else {
+ within_body_sse42(-i, S2, -s2, S2, J.W, pk);
+ }
+ dec(w);
+ cmp(w, 0);
+ jne(label_t, T_NEAR);
+ for (int j = J.W - S2; j < J.W; ++j) {
+ if (isa == avx2) {
+ within_body(-i, S2, -s2, J.W - 1 - j, J.W,
+ ysum, ydst, ytmp, ysum2, pk);
+ } else {
+ within_body_sse42(-i, S2, -s2, J.W - 1 - j, J.W, pk);
+ }
+ }
+ }
+
+ mov(h, J.H - J.size + 1);
+ Label lrn_loop_h;
+ L(lrn_loop_h);
+ for (int j = 0; j < s2; ++j) {
+ if (isa == avx2) {
+ within_body(-s2, S2, -j, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
+ } else {
+ within_body_sse42(-s2, S2, -j, S2, J.W, pk);
+ }
+ }
+ mov(w, J.W - J.size + 1);
+ Label lrn_loop_w;
+ L(lrn_loop_w);
+ if (isa == avx2) {
+ within_body(-s2, S2, -s2, S2, J.W, ysum, ydst, ytmp, ysum2, pk);
+ } else {
+ within_body_sse42(-s2, S2, -s2, S2, J.W, pk);
+ }
+ dec(w);
+ cmp(w, 0);
+ jne(lrn_loop_w, T_NEAR);
+ for (int j = J.W - S2; j < J.W; ++j) {
+ if (isa == avx2) {
+ within_body(-s2, S2, -s2, J.W - 1 - j, J.W,
+ ysum, ydst, ytmp, ysum2, pk);
+ } else {
+ within_body_sse42(-s2, S2, -s2, J.W - 1 - j, J.W, pk);
+ }
+ }
+ dec(h);
+ cmp(h, 0);
+ jne(lrn_loop_h, T_NEAR);
+
+ for (int i = J.H - S2; i < J.H; ++i)
+ {
+ for (int j = 0; j < s2; ++j) {
+ if (isa == avx2) {
+ within_body(-s2, J.H - 1 - i, -j, S2, J.W,
+ ysum, ydst, ytmp, ysum2, pk);
+ } else {
+ within_body_sse42(-s2, J.H - 1 - i, -j, S2, J.W, pk);
+ }
+ }
+
+ mov(w, J.W - J.size + 1);
+ Label label_b;
+ L(label_b);
+ if (isa == avx2) {
+ within_body(-s2, J.H - 1 - i, -s2, S2, J.W,
+ ysum, ydst, ytmp, ysum2, pk);
+ } else {
+ within_body_sse42(-s2, J.H - 1 - i, -s2, S2, J.W, pk);
+ }
+ dec(w);
+ cmp(w, 0);
+ jne(label_b, T_NEAR);
+
+ for (int j = J.W - S2; j < J.W; ++j) {
+ if (isa == avx2) {
+ within_body(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W,
+ ysum, ydst, ytmp, ysum2, pk);
+ } else {
+ within_body_sse42(-s2, J.H - 1 - i, -s2, J.W - 1 - j, J.W, pk);
+ }
+ }
+ }
+
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+}
+
+template<>
+jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
+ const struct nchw8c_across &J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void *code_ptr,
+ size_t code_size)
+ : jit_generator(code_ptr, code_size)
+ , alpha(A), k(K)
+{
+ Xbyak::Reg64 t = rsp;
+ Xbyak::Reg64 hw = r9;
+ Xbyak::Xmm xsrc_prev = xmm2;
+ Xbyak::Ymm ysrc = ymm3;
+ Xbyak::Ymm yc = ymm3;
+ Xbyak::Xmm xsrc_next = xmm4;
+ Xbyak::Ymm ya = ymm5;
+ Xbyak::Ymm yb = ymm6;
+ Xbyak::Ymm yd = ymm7;
+ Xbyak::Ymm ye = ymm8;
+ Xbyak::Ymm ysum = ymm9;
+ Xbyak::Ymm ysum2 = ymm10;
+ Xbyak::Ymm ydst = ymm11;
+ Xbyak::Ymm ybase = ymm12;
+
+ this->preamble();
+
+ mov(src, ptr[this->param1 + 0]);
+ mov(dst, ptr[this->param1 + 8]);
+ if (pk != prop_kind::forward_inference)
+ mov(scratch, ptr[this->param1 + 16]);
+ sub(t, 64);
+ mov(imm_addr64, float2int(this->alpha));
+ movq(xalpha, imm_addr64);
+ vbroadcastss(yalpha, xalpha);
+
+ mov(imm_addr64, float2int(this->k));
+ movq(xk, imm_addr64);
+ vbroadcastss(yk, xk);
+
+ if (J.version == -1)
+ {
+ vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
+ vmovups(ptr[t + 0], xsrc_prev);
+ }
+ if (J.version == +1)
+ {
+ vxorps(xsrc_next, xsrc_next, xsrc_next);
+ vmovups(ptr[t + 48], xsrc_next);
+ }
+
+ mov(hw, J.H*J.W);
+
+ Label lrn_loop;
+ L(lrn_loop);
+
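+ // Stage data on the stack: previous block's last 4 floats at t+0, the
+ // current 8-channel block at t+16, the next block's first 4 floats at
+ // t+48, so ya..ye below read the -2/-1/+1/+2 channel neighbors as
+ // shifted views of the current block.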
+ if (J.version != -1) vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
+ vmovups(ysrc, ptr[src]);
+ if (J.version != +1) vmovups(xsrc_next, ptr[src + J.H*J.W * 32]);
+
+ if (J.version != -1) vmovups(ptr[t + 0], xsrc_prev);
+ vmovups(ptr[t + 16], ysrc);
+ if (J.version != +1) vmovups(ptr[t + 48], xsrc_next);
+
+ vmovups(ya, ptr[t + 16 - 8]);
+ vmovups(yb, ptr[t + 16 - 4]);
+ vmovups(yd, ptr[t + 16 + 4]);
+ vmovups(ye, ptr[t + 16 + 8]);
+ vmulps(ysum, yc, yc);
+ vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya*ya
+ vfmadd231ps(ysum, yb, yb);
+ vfmadd231ps(ysum, yd, yd);
+ vfmadd231ps(ysum, ye, ye);
+ vfmadd132ps(ysum, yk, yalpha); // ysum <- ysum*yalpha+yk
+
+ vmovaps(ybase, ysum);
+ if (pk != prop_kind::forward_inference)
+ vmovups(ptr[scratch], ybase);
+ vmulps(ysum2, ysum, ysum);
+ vmulps(ysum, ysum, ysum2); // ysum = ybase^3;
+ vsqrtps(ysum, ysum);
+ vsqrtps(ysum, ysum); // ysum = ybase^0.75
+ vdivps(ydst, ysrc, ysum); // ydst = ysrc / ysum
+ vmovups(ptr[dst], ydst);
+
+ add(src, 32);
+ add(dst, 32);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, 32);
+ dec(hw);
+ cmp(hw, 0);
+ jne(lrn_loop, T_NEAR);
+
+ add(t, 64);
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+}
+
+template<>
+jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
+ const struct nchw8c_across &J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void *code_ptr,
+ size_t code_size)
+ : jit_generator(code_ptr, code_size)
+ , alpha(A), k(K)
+{
+ Xbyak::Reg64 t = rsp;
+ Xbyak::Reg64 hw = r9;
+
+ Xbyak::Xmm xsrc_lo = xmm2;
+ Xbyak::Xmm xsrc_hi = xmm3;
+ Xbyak::Xmm xc_lo = xmm4;
+ Xbyak::Xmm xc_hi = xmm5;
+ Xbyak::Xmm xsum_lo = xc_lo;
+ Xbyak::Xmm xsum_hi = xc_hi;
+ Xbyak::Xmm xsrc_prev = xmm6;
+ Xbyak::Xmm xsrc_next = xmm7;
+ Xbyak::Xmm xa_lo = xmm8;
+ Xbyak::Xmm xa_hi = xmm9;
+ Xbyak::Xmm xb_lo = xmm10;
+ Xbyak::Xmm xb_hi = xmm11;
+ Xbyak::Xmm xd_lo = xmm12;
+ Xbyak::Xmm xd_hi = xmm13;
+ Xbyak::Xmm xe_lo = xmm14;
+ Xbyak::Xmm xe_hi = xmm15;
+ Xbyak::Xmm xbase_lo = xmm14;
+ Xbyak::Xmm xbase_hi = xmm15;
+
+ this->preamble();
+
+ mov(src, ptr[this->param1 + 0]);
+ mov(dst, ptr[this->param1 + 8]);
+ if (pk != prop_kind::forward_inference)
+ mov(scratch, ptr[this->param1 + 16]);
+ sub(t, 64);
+ mov(imm_addr64, float2int(this->alpha));
+ movq(xalpha, imm_addr64);
+ shufps(xalpha, xalpha, 0);
+
+ mov(imm_addr64, float2int(this->k));
+ movq(xk, imm_addr64);
+ shufps(xk, xk, 0);
+
+ if (J.version == -1)
+ {
+ xorps(xsrc_prev, xsrc_prev);
+ movups(ptr[t + 0], xsrc_prev);
+ }
+ if (J.version == +1)
+ {
+ xorps(xsrc_next, xsrc_next);
+ movups(ptr[t + 48], xsrc_next);
+ }
+
+ mov(hw, J.H*J.W);
+ Label lrn_loop;
+ L(lrn_loop);
+
+ if (J.version != -1) movups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
+ movups(xsrc_lo, ptr[src]);
+ movups(xsrc_hi, ptr[src + 4 * sizeof(float)]);
+ if (J.version != +1) movups(xsrc_next, ptr[src + J.H*J.W * 32]);
+
+ if (J.version != -1) movups(ptr[t + 0], xsrc_prev);
+ movups(ptr[t + 16], xsrc_lo);
+ movups(ptr[t + 16 + 4 * sizeof(float)], xsrc_hi);
+ if (J.version != +1) movups(ptr[t + 48], xsrc_next);
+
+ movups(xa_lo, ptr[t + 16 - 8]);
+ movups(xa_hi, ptr[t + 16 - 8 + 4 * sizeof(float)]);
+ movups(xb_lo, ptr[t + 16 - 4]);
+ movups(xb_hi, ptr[t + 16 - 4 + 4 * sizeof(float)]);
+ movups(xd_lo, ptr[t + 16 + 4]);
+ movups(xd_hi, ptr[t + 16 + 4 + 4 * sizeof(float)]);
+ movups(xe_lo, ptr[t + 16 + 8]);
+ movups(xe_hi, ptr[t + 16 + 8 + 4 * sizeof(float)]);
+ movaps(xc_lo, xsrc_lo);
+ movaps(xc_hi, xsrc_hi);
+ mulps(xsum_lo, xc_lo);
+ mulps(xsum_hi, xc_hi);
+ mulps(xa_lo, xa_lo);
+ mulps(xa_hi, xa_hi);
+ addps(xsum_lo, xa_lo);
+ addps(xsum_hi, xa_hi); // xsum <- xsum + xa*xa
+ mulps(xb_lo, xb_lo);
+ mulps(xb_hi, xb_hi);
+ addps(xsum_lo, xb_lo);
+ addps(xsum_hi, xb_hi);
+ mulps(xd_lo, xd_lo);
+ mulps(xd_hi, xd_hi);
+ addps(xsum_lo, xd_lo);
+ addps(xsum_hi, xd_hi);
+ mulps(xe_lo, xe_lo);
+ mulps(xe_hi, xe_hi);
+ addps(xsum_lo, xe_lo);
+ addps(xsum_hi, xe_hi);
+
+ mulps(xsum_lo, xalpha);
+ mulps(xsum_hi, xalpha);
+ addps(xsum_lo, xk);
+ addps(xsum_hi, xk); // xsum <- xsum*xalpha+xk
+
+ movaps(xbase_lo, xsum_lo);
+ movaps(xbase_hi, xsum_hi);
+ if (pk != prop_kind::forward_inference) {
+ movups(ptr[scratch], xbase_lo);
+ movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
+ }
+ mulps(xsum_lo, xsum_lo);
+ mulps(xsum_hi, xsum_hi);
+ mulps(xsum_lo, xbase_lo);
+ mulps(xsum_hi, xbase_hi); // xsum = xbase^3;
+ sqrtps(xsum_lo, xsum_lo);
+ sqrtps(xsum_hi, xsum_hi);
+ sqrtps(xsum_lo, xsum_lo);
+ sqrtps(xsum_hi, xsum_hi); // xsum = xbase^0.75
+ divps(xsrc_lo, xsum_lo);
+ divps(xsrc_hi, xsum_hi); // xdst = xsrc / xsum
+ movups(ptr[dst], xsrc_lo);
+ movups(ptr[dst + 4 * sizeof(float)], xsrc_hi);
+
+ add(src, 32);
+ add(dst, 32);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, 32);
+ dec(hw);
+ cmp(hw, 0);
+ jne(lrn_loop, T_NEAR);
+
+ add(t, 64);
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+}
+
+template<>
+jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
+ const struct nhwc_across &J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void *code_ptr,
+ size_t code_size)
+ : jit_generator(code_ptr, code_size)
+ , alpha(A), k(K)
+{
+ static const uint32_t mask[] = {
+ 0, 0, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
+ 0x80000000, 0x80000000, 0x80000000, 0, 0
+ };
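+ // The staggered pointers into mask[] used below produce vmaskmovps masks
+ // that zero the out-of-range lanes when reading two channels before the
+ // start (&mask[0], &mask[1]) and past the end (&mask[2], &mask[3]).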
+
+ Xbyak::Reg64 c = r9;
+ Xbyak::Ymm ya = ymm2;
+ Xbyak::Ymm yb = ymm3;
+ Xbyak::Ymm yc = ymm4;
+ Xbyak::Ymm yd = ymm5;
+ Xbyak::Ymm ye = ymm6;
+ Xbyak::Ymm ysum = ymm7;
+ Xbyak::Ymm ydst = ymm8;
+ Xbyak::Ymm ybase = ymm9;
+ Xbyak::Ymm ymask = ymm10;
+
+ this->preamble();
+
+ mov(src, ptr[this->param1 + 0]);
+ mov(dst, ptr[this->param1 + 8]);
+ if (pk != prop_kind::forward_inference)
+ mov(scratch, ptr[this->param1 + 16]);
+ mov(imm_addr64, float2int(this->alpha));
+ movq(xalpha, imm_addr64);
+ vbroadcastss(yalpha, xalpha);
+
+ mov(imm_addr64, float2int(this->k));
+ movq(xk, imm_addr64);
+ vbroadcastss(yk, xk);
+
+ vxorps(ysum, ysum, ysum);
+
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[0]));
+ vmovups(ymask, ptr[imm_addr64]);
+ vmaskmovps(ya, ymask, ptr[src - 8]);
+ vfmadd231ps(ysum, ya, ya); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
+
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[1]));
+ vmovups(ymask, ptr[imm_addr64]);
+ vmaskmovps(yb, ymask, ptr[src - 4]);
+ vfmadd231ps(ysum, yb, yb);
+
+ mov(c, J.C / 8 - 1);
+ Label lrn_loop;
+ L(lrn_loop);
+
+ vmovups(yc, ptr[src]);
+ vmovups(yd, ptr[src + 4]);
+ vmovups(ye, ptr[src + 8]);
+ vfmadd231ps(ysum, yc, yc);
+ vfmadd231ps(ysum, yd, yd);
+ vfmadd231ps(ysum, ye, ye);
+
+ vmovups(ydst, ysum);
+ vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
+
+ vmovaps(ybase, ydst);
+ if (pk != prop_kind::forward_inference)
+ vmovups(ptr[scratch], ybase);
+ vmulps(ydst, ydst, ydst);
+ vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
+ vsqrtps(ydst, ydst);
+ vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
+
+ vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
+ vmovups(ptr[dst], ydst);
+
+ vxorps(ysum, ysum, ysum);
+
+ add(src, 32);
+ add(dst, 32);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, 32);
+
+ vmovups(ya, ptr[src - 8]);
+ vfmadd231ps(ysum, ya, ya);
+ vmovups(yb, ptr[src - 4]);
+ vfmadd231ps(ysum, yb, yb);
+
+ dec(c);
+ cmp(c, 0);
+ jne(lrn_loop, T_NEAR);
+
+ vmovups(yc, ptr[src]);
+ vfmadd231ps(ysum, yc, yc);
+
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[2]));
+ vmovups(ymask, ptr[imm_addr64]);
+ vmaskmovps(yd, ymask, ptr[src + 4]);
+ vfmadd231ps(ysum, yd, yd); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
+
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[3]));
+ vmovups(ymask, ptr[imm_addr64]);
+ vmaskmovps(ye, ymask, ptr[src + 8]);
+ vfmadd231ps(ysum, ye, ye);
+
+ vmovups(ydst, ysum);
+ vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
+
+ vmovaps(ybase, ydst);
+ if (pk != prop_kind::forward_inference)
+ vmovups(ptr[scratch], ybase);
+ vmulps(ydst, ydst, ydst);
+ vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
+ vsqrtps(ydst, ydst);
+ vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
+ vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
+
+ vmovups(ptr[dst], ydst);
+
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+}
+
+template<>
+jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
+ const struct nhwc_across &J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void *code_ptr,
+ size_t code_size)
+ : jit_generator(code_ptr, code_size)
+ , alpha(A), k(K)
+{
+ static const uint32_t mask[] = {
+ 0, 0, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
+ 0xffffffff, 0xffffffff, 0xffffffff, 0, 0
+ };
+
+ static uint32_t store[] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+ };
+ Xbyak::Reg64 c = r9;
+
+ Xbyak::Xmm xdst_lo = xmm0;
+ Xbyak::Xmm xdst_hi = xmm1;
+ Xbyak::Xmm xa_lo = xmm2;
+ Xbyak::Xmm xa_hi = xmm3;
+ Xbyak::Xmm xb_lo = xmm2;
+ Xbyak::Xmm xb_hi = xmm3;
+ Xbyak::Xmm xc_lo = xmm4;
+ Xbyak::Xmm xc_hi = xmm5;
+ Xbyak::Xmm xd_lo = xmm6;
+ Xbyak::Xmm xd_hi = xmm7;
+ Xbyak::Xmm xe_lo = xmm8;
+ Xbyak::Xmm xe_hi = xmm9;
+ Xbyak::Xmm xsum_lo = xmm10;
+ Xbyak::Xmm xsum_hi = xmm11;
+ Xbyak::Xmm xmask_lo = xmm12;
+ Xbyak::Xmm xmask_hi = xmm13;
+ Xbyak::Xmm xbase_lo = xmm14;
+ Xbyak::Xmm xbase_hi = xmm15;
+
+ this->preamble();
+
+ mov(src, ptr[this->param1 + 0]);
+ mov(dst, ptr[this->param1 + 8]);
+ if (pk != prop_kind::forward_inference)
+ mov(scratch, ptr[this->param1 + 16]);
+ mov(imm_addr64, float2int(this->alpha));
+ movq(xalpha, imm_addr64);
+ shufps(xalpha, xalpha, 0);
+
+ mov(imm_addr64, float2int(this->k));
+ movq(xk, imm_addr64);
+ shufps(xk, xk, 0);
+
+ mov(store_addr, reinterpret_cast<size_t>(&store[0]));
+ and_(store_addr, -15);
+ movups(ptr[store_addr], xalpha);
+ movups(ptr[store_addr + 4 * sizeof(float)], xk);
+
+ xorps(xsum_lo, xsum_lo);
+ xorps(xsum_hi, xsum_hi);
+
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[0]));
+ movups(xmask_lo, ptr[imm_addr64]);
+ movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
+ movups(xa_lo, ptr[src - 8]);
+ movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]);
+ andps(xa_lo, xmask_lo);
+ andps(xa_hi, xmask_hi);
+ mulps(xa_lo, xa_lo);
+ mulps(xa_hi, xa_hi);
+ addps(xsum_lo, xa_lo);
+ addps(xsum_hi, xa_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
+
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[1]));
+ movups(xmask_lo, ptr[imm_addr64]);
+ movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
+ movups(xb_lo, ptr[src - 4]);
+ movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]);
+ andps(xb_lo, xmask_lo);
+ andps(xb_hi, xmask_hi);
+ mulps(xb_lo, xb_lo);
+ mulps(xb_hi, xb_hi);
+ addps(xsum_lo, xb_lo);
+ addps(xsum_hi, xb_hi);
+
+ mov(c, J.C / 8 - 1);
+ Label lrn_loop;
+ L(lrn_loop);
+
+ movups(xc_lo, ptr[src]);
+ movups(xc_hi, ptr[src + 4 * sizeof(float)]);
+ movups(xd_lo, ptr[src + 4]);
+ movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]);
+ movups(xe_lo, ptr[src + 8]);
+ movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]);
+ mulps(xc_lo, xc_lo);
+ mulps(xc_hi, xc_hi);
+ addps(xsum_lo, xc_lo);
+ addps(xsum_hi, xc_hi);
+ mulps(xd_lo, xd_lo);
+ mulps(xd_hi, xd_hi);
+ addps(xsum_lo, xd_lo);
+ addps(xsum_hi, xd_hi);
+ mulps(xe_lo, xe_lo);
+ mulps(xe_hi, xe_hi);
+ addps(xsum_lo, xe_lo);
+ addps(xsum_hi, xe_hi);
+
+ movaps(xdst_lo, xsum_lo);
+ movaps(xdst_hi, xsum_hi);
+ // xdst <- xsum*xalpha+xk
+ mulps(xdst_lo, ptr[store_addr]);
+ mulps(xdst_hi, ptr[store_addr]);
+ addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]);
+ addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]);
+
+ movaps(xbase_lo, xdst_lo);
+ movaps(xbase_hi, xdst_hi);
+ if (pk != prop_kind::forward_inference) {
+ movups(ptr[scratch], xbase_lo);
+ movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
+ }
+ mulps(xdst_lo, xdst_lo);
+ mulps(xdst_hi, xdst_hi);
+ mulps(xdst_lo, xbase_lo);
+ mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
+ sqrtps(xdst_lo, xdst_lo);
+ sqrtps(xdst_hi, xdst_hi);
+ sqrtps(xdst_lo, xdst_lo);
+ sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
+
+ movups(xc_lo, ptr[src]);
+ movups(xc_hi, ptr[src + 4 * sizeof(float)]);
+ divps(xc_lo, xdst_lo);
+ divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
+ movups(ptr[dst], xc_lo);
+ movups(ptr[dst + 4 * sizeof(float)], xc_hi);
+
+ xorps(xsum_lo, xsum_lo);
+ xorps(xsum_hi, xsum_hi);
+
+ add(src, 32);
+ add(dst, 32);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, 32);
+
+ movups(xa_lo, ptr[src - 8]);
+ movups(xa_hi, ptr[src - 8 + 4 * sizeof(float)]);
+ mulps(xa_lo, xa_lo);
+ mulps(xa_hi, xa_hi);
+ addps(xsum_lo, xa_lo);
+ addps(xsum_hi, xa_hi);
+ movups(xb_lo, ptr[src - 4]);
+ movups(xb_hi, ptr[src - 4 + 4 * sizeof(float)]);
+ mulps(xb_lo, xb_lo);
+ mulps(xb_hi, xb_hi);
+ addps(xsum_lo, xb_lo);
+ addps(xsum_hi, xb_hi);
+
+ dec(c);
+ cmp(c, 0);
+ jne(lrn_loop, T_NEAR);
+
+ movups(xc_lo, ptr[src]);
+ movups(xc_hi, ptr[src + 4 * sizeof(float)]);
+ mulps(xc_lo, xc_lo);
+ mulps(xc_hi, xc_hi);
+ addps(xsum_lo, xc_lo);
+ addps(xsum_hi, xc_hi);
+
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[2]));
+ movups(xmask_lo, ptr[imm_addr64]);
+ movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
+ movups(xd_lo, ptr[src + 4]);
+ movups(xd_hi, ptr[src + 4 + 4 * sizeof(float)]);
+ andps(xd_lo, xmask_lo);
+ andps(xd_hi, xmask_hi);
+ mulps(xd_lo, xd_lo);
+ mulps(xd_hi, xd_hi);
+ addps(xsum_lo, xd_lo);
+ addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
+
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[3]));
+ movups(xmask_lo, ptr[imm_addr64]);
+ movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
+ movups(xe_lo, ptr[src + 8]);
+ movups(xe_hi, ptr[src + 8 + 4 * sizeof(float)]);
+ andps(xe_lo, xmask_lo);
+ andps(xe_hi, xmask_hi);
+ mulps(xe_lo, xe_lo);
+ mulps(xe_hi, xe_hi);
+ addps(xsum_lo, xe_lo);
+ addps(xsum_hi, xe_hi);
+
+ movups(xdst_lo, xsum_lo);
+ movups(xdst_hi, xsum_hi);
+ // xdst <- xsum*xalpha+xk
+ mulps(xdst_lo, ptr[store_addr]);
+ mulps(xdst_hi, ptr[store_addr]);
+ addps(xdst_lo, ptr[store_addr + 4 * sizeof(float)]);
+ addps(xdst_hi, ptr[store_addr + 4 * sizeof(float)]);
+
+ movaps(xbase_lo, xdst_lo);
+ movaps(xbase_hi, xdst_hi);
+ if (pk != prop_kind::forward_inference) {
+ movups(ptr[scratch], xbase_lo);
+ movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
+ }
+ mulps(xdst_lo, xdst_lo);
+ mulps(xdst_hi, xdst_hi);
+ mulps(xdst_lo, xbase_lo);
+ mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
+ sqrtps(xdst_lo, xdst_lo);
+ sqrtps(xdst_hi, xdst_hi);
+ sqrtps(xdst_lo, xdst_lo);
+ sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
+ movups(xc_lo, ptr[src]);
+ movups(xc_hi, ptr[src + 4 * sizeof(float)]);
+ divps(xc_lo, xdst_lo);
+ divps(xc_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
+
+ movups(ptr[dst], xc_lo);
+ movups(ptr[dst + 4 * sizeof(float)], xc_hi);
+
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+}
+
+template<>
+void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body(
+ int tail, int HW, prop_kind_t pk,
+ Xbyak::Ymm ymask,
+ Xbyak::Ymm ya,
+ Xbyak::Ymm yb,
+ Xbyak::Ymm yc,
+ Xbyak::Ymm yd,
+ Xbyak::Ymm ye,
+ Xbyak::Ymm ysum) {}
+
+template<>
+void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body(
+ int tail, int HW, prop_kind_t pk,
+ Xbyak::Ymm ymask,
+ Xbyak::Ymm ya,
+ Xbyak::Ymm yb,
+ Xbyak::Ymm yc,
+ Xbyak::Ymm yd,
+ Xbyak::Ymm ye,
+ Xbyak::Ymm ysum)
+{
+ Xbyak::Ymm ydst = ymm14;
+ Xbyak::Ymm ybase = ymm15;
+
+ vfmadd231ps(ysum, ye, ye);
+
+ vmovups(ydst, ysum);
+ vfmadd132ps(ydst, yk, yalpha); // ydst <- ysum*yalpha+yk
+
+ vmovaps(ybase, ydst);
+ if (pk != prop_kind::forward_inference)
+ {
+ if (tail != 0)
+ vmaskmovps(ptr[scratch], ymask, ybase);
+ else
+ vmovups(ptr[scratch], ybase);
+ }
+ vmulps(ydst, ydst, ydst);
+ vmulps(ydst, ydst, ybase); // ydst = (ysum*yalpha+yk)^3;
+ vsqrtps(ydst, ydst);
+ vsqrtps(ydst, ydst); // ydst = (ysum*yalpha+yk)^0.75
+ vdivps(ydst, yc, ydst); // ydst = ysrc / (ysum*yalpha+yk)^0.75
+
+ if (tail != 0)
+ vmaskmovps(ptr[dst], ymask, ydst);
+ else
+ vmovups(ptr[dst], ydst);
+
+
+ vfnmadd231ps(ysum, ya, ya);
+ vmovups(ya, yb);
+ vmovups(yb, yc);
+ vmovups(yc, yd);
+ vmovups(yd, ye);
+}
+
+template<>
+void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_tail_sse42(
+ int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi)
+{}
+
+template<>
+void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_tail_sse42(
+ int tail, Xbyak::Reg64 reg_dst, Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi)
+{
+ Xbyak::Xmm xmm_tmp = xmm10;
+ movaps(xmm_tmp, xtail_lo);
+ size_t offset = 0;
+
+ if (tail > 4) {
+ movups(ptr[reg_dst], xtail_lo);
+ movaps(xmm_tmp, xtail_hi);
+ offset += 4 * sizeof(float);
+ tail -= 4;
+ }
+ movss(ptr[reg_dst + offset], xmm_tmp);
+ for (int i = 1; i < tail; i++)
+ {
+ psrldq(xmm_tmp, 4);
+ movss(ptr[reg_dst + offset + i * sizeof(float)], xmm_tmp);
+ }
+}
+
+template<>
+void jit_uni_lrn_fwd_kernel_f32<sse42>::nchw_body_sse42(
+ int tail, int HW, prop_kind_t pk,
+ Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
+ Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
+ Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi)
+{
+ Xbyak::Xmm xdst_lo = xmm0;
+ Xbyak::Xmm xdst_hi = xmm1;
+ Xbyak::Xmm xbase_lo = xmm6;
+ Xbyak::Xmm xbase_hi = xmm7;
+ Xbyak::Xmm xtmp_lo = xmm8;
+ Xbyak::Xmm xtmp_hi = xmm9;
+ Xbyak::Xmm xa_lo = xmm6;
+ Xbyak::Xmm xa_hi = xmm7;
+ Xbyak::Xmm xb_lo = xmm8;
+ Xbyak::Xmm xb_hi = xmm9;
+ Xbyak::Xmm xc_lo = xmm10;
+ Xbyak::Xmm xc_hi = xmm11;
+ Xbyak::Xmm xd_lo = xmm12;
+ Xbyak::Xmm xd_hi = xmm13;
+
+ // store xe
+ movaps(ptr[store_addr + 10 * 4 * sizeof(float)], xe_lo);
+ movaps(ptr[store_addr + 11 * 4 * sizeof(float)], xe_hi);
+
+ mulps(xe_lo, xe_lo);
+ mulps(xe_hi, xe_hi);
+ addps(xsum_lo, xe_lo);
+ addps(xsum_hi, xe_hi);
+
+ // xdst <- xsum*xalpha+xk
+ movaps(xdst_lo, xsum_lo);
+ movaps(xdst_hi, xsum_hi);
+ mulps(xdst_lo, ptr[store_addr + 0 * 4 * sizeof(float)]);
+ mulps(xdst_hi, ptr[store_addr + 0 * 4 * sizeof(float)]);
+ addps(xdst_lo, ptr[store_addr + 1 * 4 * sizeof(float)]);
+ addps(xdst_hi, ptr[store_addr + 1 * 4 * sizeof(float)]);
+
+ movaps(xbase_lo, xdst_lo);
+ movaps(xbase_hi, xdst_hi);
+ if (pk != prop_kind::forward_inference)
+ {
+ if (tail != 0) {
+ nchw_tail_sse42(tail, scratch, xbase_lo, xbase_hi);
+ }
+ else {
+ movups(ptr[scratch], xbase_lo);
+ movups(ptr[scratch + 4 * sizeof(float)], xbase_hi);
+ }
+ }
+ mulps(xdst_lo, xdst_lo);
+ mulps(xdst_hi, xdst_hi);
+ mulps(xdst_lo, xbase_lo);
+ mulps(xdst_hi, xbase_hi); // xdst = (xsum*xalpha+xk)^3;
+ sqrtps(xdst_lo, xdst_lo);
+ sqrtps(xdst_hi, xdst_hi);
+ sqrtps(xdst_lo, xdst_lo);
+ sqrtps(xdst_hi, xdst_hi); // xdst = (xsum*xalpha+xk)^0.75
+ movaps(xtmp_lo, ptr[store_addr + 6 * 4 * sizeof(float)]);
+ movaps(xtmp_hi, ptr[store_addr + 7 * 4 * sizeof(float)]);
+ divps(xtmp_lo, xdst_lo);
+ divps(xtmp_hi, xdst_hi); // xdst = xsrc / (xsum*xalpha+xk)^0.75
+ movaps(xdst_lo, xtmp_lo);
+ movaps(xdst_hi, xtmp_hi);
+
+ if (tail != 0) {
+ nchw_tail_sse42(tail, dst, xdst_lo, xdst_hi);
+ }
+ else {
+ movups(ptr[dst], xdst_lo);
+ movups(ptr[dst + 4 * sizeof(float)], xdst_hi);
+ }
+
+ movaps(xa_lo, ptr[store_addr + 2 * 4 * sizeof(float)]);
+ movaps(xa_hi, ptr[store_addr + 3 * 4 * sizeof(float)]);
+ mulps(xa_lo, xa_lo);
+ mulps(xa_hi, xa_hi);
+ subps(xsum_lo, xa_lo);
+ subps(xsum_hi, xa_hi);
+
+ // xa <- xb
+ movaps(xb_lo, ptr[store_addr + 4 * 4 * sizeof(float)]);
+ movaps(xb_hi, ptr[store_addr + 5 * 4 * sizeof(float)]);
+ movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xb_lo);
+ movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xb_hi);
+
+ // xb <- xc
+ movaps(xc_lo, ptr[store_addr + 6 * 4 * sizeof(float)]);
+ movaps(xc_hi, ptr[store_addr + 7 * 4 * sizeof(float)]);
+ movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xc_lo);
+ movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xc_hi);
+
+ // xc <- xd
+ movaps(xd_lo, ptr[store_addr + 8 * 4 * sizeof(float)]);
+ movaps(xd_hi, ptr[store_addr + 9 * 4 * sizeof(float)]);
+ movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xd_lo);
+ movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xd_hi);
+
+ // xd <- xe
+ movaps(xe_lo, ptr[store_addr + 10 * 4 * sizeof(float)]);
+ movaps(xe_hi, ptr[store_addr + 11 * 4 * sizeof(float)]);
+ movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xe_lo);
+ movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xe_hi);
+}
+
+template<>
+void jit_uni_lrn_fwd_kernel_f32<avx2>::nchw_body_sse42(
+ int tail, int HW, prop_kind_t pk,
+ Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
+ Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
+ Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi) {}
+
+template<>
+jit_uni_lrn_fwd_kernel_f32<avx2>::jit_uni_lrn_fwd_kernel_f32(
+ struct nchw_across J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void* code_ptr,
+ size_t code_size)
+ : jit_generator(code_ptr, code_size)
+ , alpha(A), k(K)
+{
+ static const uint32_t mask[] = {
+ 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000,
+ 0x80000000, 0x80000000, 0, 0, 0, 0, 0, 0, 0
+ };
+ Xbyak::Reg64 c = r10;
+ Xbyak::Ymm ymask = ymm2;
+ Xbyak::Ymm ye = ymm3;
+ Xbyak::Ymm ya = ymm4;
+ Xbyak::Ymm yb = ymm5;
+ Xbyak::Ymm yc = ymm6;
+ Xbyak::Ymm yd = ymm7;
+ Xbyak::Ymm ysum = ymm8;
+
+ this->preamble();
+
+ if (J.tail != 0)
+ {
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail]));
+ vmovups(ymask, ptr[imm_addr64]);
+ }
+ mov(imm_addr64, float2int(this->alpha));
+ movq(xalpha, imm_addr64);
+ vbroadcastss(yalpha, xalpha);
+
+ mov(imm_addr64, float2int(this->k));
+ movq(xk, imm_addr64);
+ vbroadcastss(yk, xk);
+
+ mov(src, ptr[this->param1 + 0]);
+ mov(dst, ptr[this->param1 + 8]);
+ if (pk != prop_kind::forward_inference)
+ mov(scratch, ptr[this->param1 + 16]);
+
+ vxorps(ya, ya, ya);
+ vxorps(yb, yb, yb);
+ if (J.tail != 0)
+ vmaskmovps(yc, ymask, ptr[src + J.HW * 0]);
+ else
+ vmovups(yc, ptr[src + J.HW * 0]);
+ if (J.tail != 0)
+ vmaskmovps(yd, ymask, ptr[src + J.HW * 4]);
+ else
+ vmovups(yd, ptr[src + J.HW * 4]);
+
+ vxorps(ysum, ysum, ysum);
+ vfmadd231ps(ysum, yc, yc); // ysum <- ysum + ya^2+yb^2+yc^2+yd^2+ye^2
+ vfmadd231ps(ysum, yd, yd);
+
+ mov(c, J.C - 2);
+ Label lrn_loop;
+ L(lrn_loop);
+
+ if (J.tail != 0)
+ vmaskmovps(ye, ymask, ptr[src + J.HW * 8]);
+ else
+ vmovups(ye, ptr[src + J.HW * 8]);
+
+ nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
+
+ add(src, J.HW * 4);
+ add(dst, J.HW * 4);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, J.HW * 4);
+ dec(c);
+ cmp(c, 0);
+ jne(lrn_loop, T_NEAR);
+
+ vxorps(ye, ye, ye);
+
+ nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
+ add(src, J.HW * 4);
+ add(dst, J.HW * 4);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, J.HW * 4);
+
+ nchw_body(J.tail, J.HW, pk, ymask, ya, yb, yc, yd, ye, ysum);
+
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+}
+
+template<>
+jit_uni_lrn_fwd_kernel_f32<sse42>::jit_uni_lrn_fwd_kernel_f32(
+ struct nchw_across J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void* code_ptr,
+ size_t code_size)
+ : jit_generator(code_ptr, code_size)
+ , alpha(A), k(K)
+{
+ static const uint32_t mask[] = {
+ 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff,
+ 0xffffffff, 0xffffffff, 0, 0, 0, 0, 0, 0, 0
+ };
+
+ Xbyak::Reg64 c = r10;
+
+ Xbyak::Xmm xmask_lo = xmm2;
+ Xbyak::Xmm xmask_hi = xmm3;
+ Xbyak::Xmm xsum_lo = xmm4;
+ Xbyak::Xmm xsum_hi = xmm5;
+ Xbyak::Xmm xa_lo = xmm6;
+ Xbyak::Xmm xa_hi = xmm7;
+ Xbyak::Xmm xb_lo = xmm8;
+ Xbyak::Xmm xb_hi = xmm9;
+ Xbyak::Xmm xc_lo = xmm10;
+ Xbyak::Xmm xc_hi = xmm11;
+ Xbyak::Xmm xd_lo = xmm12;
+ Xbyak::Xmm xd_hi = xmm13;
+ Xbyak::Xmm xe_lo = xmm14;
+ Xbyak::Xmm xe_hi = xmm15;
+
+ this->preamble();
+
+ mov(src, ptr[this->param1 + 0]);
+ mov(dst, ptr[this->param1 + 8]);
+ if (pk != prop_kind::forward_inference)
+ mov(scratch, ptr[this->param1 + 16]);
+
+    sub(rsp, stack_space_needed);
+    mov(store_addr, rsp);
+    and_(store_addr, -15); // 16-byte align the scratch area (rsp is at least 8-byte aligned here, so bit 0 is already clear)
+
+ mov(imm_addr64, float2int(this->alpha));
+ movq(xalpha, imm_addr64);
+ shufps(xalpha, xalpha, 0);
+
+ mov(imm_addr64, float2int(this->k));
+ movq(xk, imm_addr64);
+ shufps(xk, xk, 0);
+
+ // put alpha and k into store (free up regs)
+ movaps(ptr[store_addr + 0 * 4 * sizeof(float)], xalpha);
+ movaps(ptr[store_addr + 1 * 4 * sizeof(float)], xk);
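+    // scratch layout (16-byte slots): 0 alpha, 1 k, 2-3 xa, 4-5 xb, 6-7 xc, 8-9 xd;
+    // slots 10-11 hold xe inside nchw_body_sse42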
+
+ if (J.tail != 0)
+ {
+ mov(imm_addr64, reinterpret_cast<size_t>(&mask[7 - J.tail]));
+ movups(xmask_lo, ptr[imm_addr64]);
+ movups(xmask_hi, ptr[imm_addr64 + 4 * sizeof(float)]);
+ }
+ // init xa, xb
+ xorps(xa_lo, xa_lo);
+ xorps(xa_hi, xa_hi);
+ xorps(xb_lo, xb_lo);
+ xorps(xb_hi, xb_hi);
+
+ // read xc, xd
+ if (J.tail != 0) {
+ movups(xc_lo, ptr[src + J.HW * 0]);
+ movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]);
+ andps(xc_lo, xmask_lo);
+ andps(xc_hi, xmask_hi);
+ }
+ else {
+ movups(xc_lo, ptr[src + J.HW * 0]);
+ movups(xc_hi, ptr[src + J.HW * 0 + 4 * sizeof(float)]);
+ }
+ if (J.tail != 0) {
+ movups(xd_lo, ptr[src + J.HW * 4]);
+ movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]);
+ andps(xd_lo, xmask_lo);
+ andps(xd_hi, xmask_hi);
+ }
+ else {
+ movups(xd_lo, ptr[src + J.HW * 4]);
+ movups(xd_hi, ptr[src + J.HW * 4 + 4 * sizeof(float)]);
+ }
+
+    // put xa, xb, xc, xd into store to free up regs
+ movaps(ptr[store_addr + 2 * 4 * sizeof(float)], xa_lo);
+ movaps(ptr[store_addr + 3 * 4 * sizeof(float)], xa_hi);
+ movaps(ptr[store_addr + 4 * 4 * sizeof(float)], xb_lo);
+ movaps(ptr[store_addr + 5 * 4 * sizeof(float)], xb_hi);
+ movaps(ptr[store_addr + 6 * 4 * sizeof(float)], xc_lo);
+ movaps(ptr[store_addr + 7 * 4 * sizeof(float)], xc_hi);
+ movaps(ptr[store_addr + 8 * 4 * sizeof(float)], xd_lo);
+ movaps(ptr[store_addr + 9 * 4 * sizeof(float)], xd_hi);
+
+ xorps(xsum_lo, xsum_lo);
+ xorps(xsum_hi, xsum_hi);
+ mulps(xc_lo, xc_lo);
+ mulps(xc_hi, xc_hi);
+ addps(xsum_lo, xc_lo);
+ addps(xsum_hi, xc_hi);
+ mulps(xd_lo, xd_lo);
+ mulps(xd_hi, xd_hi);
+ addps(xsum_lo, xd_lo);
+ addps(xsum_hi, xd_hi); // xsum <- xsum + xa^2+xb^2+xc^2+xd^2+xe^2
+
+ mov(c, J.C - 2);
+ Label lrn_loop;
+ L(lrn_loop);
+
+ if (J.tail != 0) {
+ movups(xe_lo, ptr[src + J.HW * 8]);
+ movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]);
+ andps(xe_lo, xmask_lo);
+ andps(xe_hi, xmask_hi);
+ }
+ else {
+ movups(xe_lo, ptr[src + J.HW * 8]);
+ movups(xe_hi, ptr[src + J.HW * 8 + 4 * sizeof(float)]);
+ }
+
+ nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
+ xe_lo, xe_hi,
+ xsum_lo, xsum_hi);
+
+ add(src, J.HW * 4);
+ add(dst, J.HW * 4);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, J.HW * 4);
+ dec(c);
+ cmp(c, 0);
+ jne(lrn_loop, T_NEAR);
+
+ xorps(xe_lo, xe_lo);
+ xorps(xe_hi, xe_hi);
+
+ nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
+ xe_lo, xe_hi,
+ xsum_lo, xsum_hi);
+ add(src, J.HW * 4);
+ add(dst, J.HW * 4);
+ if (pk != prop_kind::forward_inference)
+ add(scratch, J.HW * 4);
+
+ nchw_body_sse42(J.tail, J.HW, pk, xmask_lo, xmask_hi,
+ xe_lo, xe_hi,
+ xsum_lo, xsum_hi);
+
+ add(rsp, stack_space_needed);
+
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+}
+
+//////////////////////////////////////////////////////////////////////////////
+// backward kernel
+template <cpu_isa_t isa>
+jit_uni_lrn_bwd_kernel_f32<isa>::jit_uni_lrn_bwd_kernel_f32(
+ const struct nchw8c_across &J,
+ float A,
+ float B,
+ int use_h_parallel,
+ void *code_ptr,
+ size_t code_size)
+ : jit_generator(code_ptr, code_size)
+ , nalphabeta(-2 * A*B)
+ , use_h_parallelizm(use_h_parallel)
+{
+ Xbyak::Reg64 t = rsp;
+ Xbyak::Reg64 hw = r10;
+
+ Xbyak::Xmm xsrc_prev = xmm1;
+ Xbyak::Xmm xws_prev = xmm2;
+ Xbyak::Xmm xdiffdst_prev = xmm3;
+ Xbyak::Ymm ysrc = ymm4;
+ Xbyak::Ymm yws = ymm5;
+ Xbyak::Ymm ydiffdst = ymm6;
+ Xbyak::Xmm xsrc_next = xmm7;
+ Xbyak::Xmm xws_next = xmm8;
+ Xbyak::Xmm xdiffdst_next = xmm9;
+ Xbyak::Ymm ya = ymm10;
+ Xbyak::Xmm xa = xmm10;
+ Xbyak::Ymm yb = ymm11;
+ Xbyak::Ymm yd = ymm12;
+ Xbyak::Ymm ye = ymm13;
+ Xbyak::Ymm ysum = ymm14;
+ Xbyak::Ymm ydiffsrc = ymm15;
+
+ this->preamble();
+
+ mov(src, ptr[this->param1 + 0]);
+ mov(diffdst, ptr[this->param1 + 8]);
+ mov(workspace, ptr[this->param1 + 16]);
+ mov(diffsrc, ptr[this->param1 + 24]);
+
+    sub(t, 64); // 64-byte stack scratch: [0..15] prev block, [16..47] current ysum, [48..63] next block
+ mov(imm_addr64, float2int(this->nalphabeta));
+ movq(xnalphabeta, imm_addr64);
+ vbroadcastss(ynalphabeta, xnalphabeta);
+
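+    // J.version encodes where this 8-channel block sits along C:
+    // -1 first, +1 last, -2 both first and last, 3 standalone, 0 interior.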
+ bool is_single = J.version == 3;
+ bool is_first = J.version == -1 || J.version == -2;
+ bool is_last = J.version == +1 || J.version == -2;
+
+ if (is_first || is_single) {
+ vxorps(xsrc_prev, xsrc_prev, xsrc_prev);
+ vmovups(ptr[t + 0], xsrc_prev);
+ }
+ if (is_last || is_single) {
+ vxorps(xsrc_next, xsrc_next, xsrc_next);
+ vmovups(ptr[t + 48], xsrc_next);
+ }
+ mov(hw, this->use_h_parallelizm ? J.W : J.H*J.W);
+ Label lrn_loop;
+ L(lrn_loop);
+ {
+ if (!is_first && !is_single) {
+ vmovups(xws_prev, ptr[workspace - J.H*J.W * 32 + 16]);
+ vmovups(xsrc_prev, ptr[src - J.H*J.W * 32 + 16]);
+ vmovups(xdiffdst_prev, ptr[diffdst - J.H*J.W * 32 + 16]);
+ vmulps(xa, xws_prev, xws_prev);
+ vmulps(xa, xa, xws_prev);
+ vsqrtps(xa, xa);
+ vsqrtps(xa, xa);
+ vmulps(xa, xa, xws_prev);
+ vdivps(xsrc_prev, xsrc_prev, xa);
+ vmulps(xdiffdst_prev, xdiffdst_prev, xsrc_prev);
+ }
+
+ vmovups(ysrc, ptr[src]);
+ vmovups(yws, ptr[workspace]);
+ vmovups(ydiffdst, ptr[diffdst]);
+        vmulps(ya, yws, yws);           // ya <- ws^2
+        vmulps(ya, ya, yws);            // ya <- ws^3
+        vsqrtps(ya, ya);                // ya <- ws^1.5
+        vsqrtps(ya, ya);                // ya <- ws^0.75
+        vdivps(ydiffsrc, ydiffdst, ya); // ydiffsrc <- diff_dst / ws^0.75
+        vdivps(ysum, ydiffsrc, yws);    // ysum <- diff_dst / ws^1.75
+        vmulps(ysum, ysum, ysrc);       // ysum <- src * diff_dst / ws^1.75
+
+ if (!is_last && !is_single) {
+ vmovups(xws_next, ptr[workspace + J.H*J.W * 32]);
+ vmovups(xsrc_next, ptr[src + J.H*J.W * 32]);
+ vmovups(xdiffdst_next, ptr[diffdst + J.H*J.W * 32]);
+ vmulps(xa, xws_next, xws_next);
+ vmulps(xa, xa, xws_next);
+ vsqrtps(xa, xa);
+ vsqrtps(xa, xa);
+ vmulps(xa, xa, xws_next);
+ vdivps(xsrc_next, xsrc_next, xa);
+ vdivps(xsrc_next, xsrc_next, xws_next);
+ vmulps(xdiffdst_next, xdiffdst_next, xsrc_next);
+ }
+
+ if (!is_first && !is_single) vmovups(ptr[t + 0], xdiffdst_prev);
+ vmovups(ptr[t + 16], ysum);
+ if (!is_last && !is_single) vmovups(ptr[t + 48], xdiffdst_next);
+
+        vmovups(ya, ptr[t + 16 - 8]);   // neighbor terms, channel offset -2
+        vmovups(yb, ptr[t + 16 - 4]);   // neighbor terms, channel offset -1
+        vaddps(ysum, ysum, ya);
+        vmulps(ysrc, ysrc, ynalphabeta);
+        vaddps(ysum, ysum, yb);
+
+        vmovups(yd, ptr[t + 16 + 4]);   // neighbor terms, channel offset +1
+        vmovups(ye, ptr[t + 16 + 8]);   // neighbor terms, channel offset +2
+        vaddps(ysum, ysum, yd);
+        vaddps(ysum, ysum, ye);
+
+ vfmadd231ps(ydiffsrc, ysum, ysrc);
+
+ vmovups(ptr[diffsrc], ydiffsrc);
+
+ add(src, 32);
+ add(diffsrc, 32);
+ add(diffdst, 32);
+ add(workspace, 32);
+
+ dec(hw);
+ cmp(hw, 0);
+ jne(lrn_loop, T_NEAR);
+ }
+
+ add(t, 64);
+ this->postamble();
+
+ ker = reinterpret_cast<decltype(ker)>(const_cast<uint8_t*>(
+ this->getCode()));
+}
+
+template struct jit_uni_lrn_fwd_kernel_f32<sse42>;
+template struct jit_uni_lrn_fwd_kernel_f32<avx2>;
+template struct jit_uni_lrn_bwd_kernel_f32<avx2>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp
new file mode 100644
index 0000000000..2b3ed43cd4
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_lrn_kernel_f32.hpp
@@ -0,0 +1,183 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_UNI_LRN_KERNEL_F32_HPP
+#define CPU_JIT_UNI_LRN_KERNEL_F32_HPP
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+
+#include "jit_generator.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace Xbyak;
+
+enum params { VECTOR_LENGTH = 8, MAX_LOCAL_SIZE = 32 };
+
+typedef struct {
+ const float *src;
+ float *dst, *scratch;
+} jit_args_fwd_t;
+
+typedef struct {
+ const float *src, *diff_dst, *scratch;
+ float *diff_src;
+} jit_args_bwd_t;
+
+struct nchw8c_across {
+    /* version:
+    * -1: channels 0..7,
+    *  1: channels C-8 .. C-1,
+    *  0: other channels,
+    * -2: a block that is both the first and the last along C,
+    *  3: channels handled only by this kernel (no previous/next block)
+    */
+ int H, W, version;
+ nchw8c_across(int h, int w, int v) : H(h), W(w), version(v) {}
+};
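+// For illustration: with C = 32 processed in 8-channel blocks, the four blocks map to versions -1, 0, 0, +1.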
+
+struct nchw8c_within {
+ int H, W, size;
+ nchw8c_within(int h, int w, int s) : H(h), W(w), size(s) {}
+};
+
+struct nchw_across {
+ int C, HW, tail;
+ nchw_across(int c, int hw, int t) : C(c), HW(hw), tail(t) {}
+};
+
+struct nhwc_across {
+ int C;
+ nhwc_across(int c) : C(c) {}
+};
+
+template <cpu_isa_t isa>
+struct jit_uni_lrn_fwd_kernel_f32 : public jit_generator {
+ Xbyak::Reg64 src = rax;
+ Xbyak::Reg64 dst = r8;
+ Xbyak::Reg64 scratch = rdx;
+ Xbyak::Reg64 imm_addr64 = rbx;
+ Xbyak::Reg64 store_addr = rbp;
+
+ Xbyak::Xmm xalpha = xmm0;
+ Xbyak::Ymm yalpha = ymm0;
+ Xbyak::Xmm xk = xmm1;
+ Xbyak::Ymm yk = ymm1;
+
+ float alpha;
+ float k;
+
+ int stack_space_needed = 11 * 4 * sizeof(float) + 16;
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_fwd_kernel_f32)
+
+ /* cpu specific part */
+ using Vmm = typename utils::conditional<isa == avx2, Ymm, Zmm>::type;
+
+ jit_uni_lrn_fwd_kernel_f32(
+ const struct nchw8c_within &J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void *code_ptr = nullptr,
+ size_t code_size = 4 * Xbyak::DEFAULT_MAX_CODE_SIZE);
+ jit_uni_lrn_fwd_kernel_f32(
+ const struct nchw8c_across &J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void *code_ptr = nullptr,
+ size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
+ jit_uni_lrn_fwd_kernel_f32(
+ const struct nhwc_across &J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void *code_ptr = nullptr,
+ size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
+ jit_uni_lrn_fwd_kernel_f32(
+ struct nchw_across J,
+ float A,
+ float K,
+ prop_kind_t pk,
+ void* code_ptr = nullptr,
+ size_t code_size = 2 * Xbyak::DEFAULT_MAX_CODE_SIZE);
+
+ void within_body(
+ int hoff, int Hoff, int woff, int Woff, int stride,
+ Xbyak::Ymm ysum, Xbyak::Ymm ydst, Xbyak::Ymm ytmp, Xbyak::Ymm ysum2,
+ prop_kind_t pk);
+ void within_body_sse42(
+ int hoff, int Hoff, int woff, int Woff, int stride, prop_kind_t pk);
+
+
+ void nchw_body(int tail, int HW, prop_kind_t pk,
+ Xbyak::Ymm ymask,
+ Xbyak::Ymm ya,
+ Xbyak::Ymm yb,
+ Xbyak::Ymm yc,
+ Xbyak::Ymm yd,
+ Xbyak::Ymm ye,
+ Xbyak::Ymm ysum);
+ void nchw_body_sse42(int tail, int HW, prop_kind_t pk,
+ Xbyak::Xmm xmask_lo, Xbyak::Xmm xmask_hi,
+ Xbyak::Xmm xe_lo, Xbyak::Xmm xe_hi,
+ Xbyak::Xmm xsum_lo, Xbyak::Xmm xsum_hi);
+ void nchw_tail_sse42(int tail, Xbyak::Reg64 reg_dst,
+ Xbyak::Xmm xtail_lo, Xbyak::Xmm xtail_hi);
+
+ void operator()(jit_args_fwd_t *arg) { ker(arg); }
+ void(*ker)(jit_args_fwd_t *);
+};
+
+template <cpu_isa_t isa>
+struct jit_uni_lrn_bwd_kernel_f32 : public jit_generator {
+ Xbyak::Reg64 src = rax;
+ Xbyak::Reg64 diffsrc = r8;
+ Xbyak::Reg64 diffdst = r9;
+ Xbyak::Reg64 workspace = rdx;
+ Xbyak::Reg64 imm_addr64 = rsi;
+
+ Xbyak::Xmm xnalphabeta = xmm0;
+ Xbyak::Ymm ynalphabeta = ymm0;
+
+ float nalphabeta;
+
+ int use_h_parallelizm;
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lrn_bwd_kernel_f32)
+
+ jit_uni_lrn_bwd_kernel_f32(
+ const struct nchw8c_across &J,
+ float A,
+ float B,
+ int use_h_parallel,
+ void *code_ptr = nullptr,
+ size_t code_size = 1 * Xbyak::DEFAULT_MAX_CODE_SIZE);
+
+ void operator()(jit_args_bwd_t *arg) { ker(arg); }
+ void(*ker)(jit_args_bwd_t *);
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp
new file mode 100644
index 0000000000..bf8e609d23
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp
@@ -0,0 +1,699 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+* Copyright 2018 YANDEX LLC
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+#include "cpu_pooling_pd.hpp"
+
+#include "jit_uni_pool_kernel_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace Xbyak;
+using namespace alg_kind;
+
+#define GET_OFF(field) offsetof(jit_pool_call_s, field)
+
+template <cpu_isa_t isa>
+status_t jit_uni_pool_kernel_f32<isa>::init_conf(jit_pool_conf_t &jpp,
+ const pooling_pd_t *ppd) {
+ const auto &pd = *ppd->desc();
+ const memory_desc_wrapper src_d(
+ ppd->is_fwd() ? ppd->src_md() : ppd->diff_src_md());
+ const memory_desc_wrapper dst_d(
+ ppd->is_fwd() ? ppd->dst_md() : ppd->diff_dst_md());
+
+ bool args_ok = true
+ && mayiuse(isa)
+ && utils::one_of(pd.alg_kind, pooling_max,
+ pooling_avg_include_padding,
+ pooling_avg_exclude_padding);
+ if (!args_ok) return status::unimplemented;
+
+ const int simd_w = isa == avx512_common ? 16 : 8;
+ const int ndims = src_d.ndims();
+
+ jpp.ndims = ndims;
+ jpp.mb = src_d.dims()[0];
+
+ jpp.c = utils::rnd_up(src_d.dims()[1], simd_w);
+ if (jpp.c > src_d.padded_dims()[1])
+ return status::unimplemented;
+
+ jpp.id = (ndims == 5) ? src_d.dims()[2] : 1;
+ jpp.ih = src_d.dims()[ndims-2];
+ jpp.iw = src_d.dims()[ndims-1];
+ jpp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
+ jpp.oh = dst_d.dims()[ndims-2];
+ jpp.ow = dst_d.dims()[ndims-1];
+
+ jpp.stride_d = (ndims == 5 ) ? pd.strides[0] : 1;
+ jpp.stride_h = pd.strides[ndims-4];
+ jpp.stride_w = pd.strides[ndims-3];
+ jpp.kd = (ndims == 5) ? pd.kernel[0] : 1;
+ jpp.kh = pd.kernel[ndims-4];
+ jpp.kw = pd.kernel[ndims-3];
+
+ jpp.f_pad = (ndims == 5 ) ? pd.padding[0][0] : 0;
+ jpp.t_pad = pd.padding[0][ndims-4];
+ jpp.l_pad = pd.padding[0][ndims-3];
+
+ jpp.alg = pd.alg_kind;
+
+ jpp.is_training = pd.prop_kind == prop_kind::forward_training;
+ jpp.is_backward = pd.prop_kind == prop_kind::backward_data;
+ jpp.ind_dt = ppd->workspace_md()
+ ? ppd->workspace_md()->data_type : data_type::undef;
+
+ jpp.simple_alg = jpp.is_training
+ || IMPLICATION(jpp.is_backward, jpp.kd <= jpp.stride_d);
+
+ jpp.c_block = simd_w;
+
+ jpp.nb_c = jpp.c / jpp.c_block;
+ if (jpp.alg == pooling_max) {
+ jpp.ur_w = isa == avx512_common ? 16 : 4;
+ if (jpp.is_training)
+ jpp.ur_w = isa == avx512_common ? 9 : 3;
+ else if (jpp.is_backward)
+ jpp.ur_w = isa == avx512_common ? 6 : 3;
+ } else {
+ if (jpp.is_backward)
+ jpp.ur_w = isa == avx512_common ? 12 : 6;
+ else
+ jpp.ur_w = isa == avx512_common ? 24 : 12;
+ }
+ if (jpp.ow < jpp.ur_w) jpp.ur_w = jpp.ow;
+ if (jpp.l_pad > jpp.ur_w) return status::unimplemented;
+
+ jpp.ur_w_tail = jpp.ow % jpp.ur_w;
+
+ return status::success;
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_pool_kernel_f32<isa>::maybe_recalculate_divisor(int jj,
+ int ur_w, int pad_l, int pad_r) {
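+    // For avg_exclude_padding the divisor must be the number of input elements
+    // the window really covers: (non-padded kw for this output column) *
+    // ker_area_h (non-padded kh, already broadcast in vmm_ker_area_h).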
+ if (jpp.alg == pooling_avg_exclude_padding) {
+ int kw = jpp.kw;
+ int stride_w = jpp.stride_w;
+
+ int non_zero_kw = kw;
+ non_zero_kw -= nstl::max(0, pad_l - jj*stride_w);
+ non_zero_kw -= nstl::max(0, pad_r - (ur_w - 1 - jj)*stride_w);
+
+ if (non_zero_kw != prev_kw) {
+ mov(tmp_gpr, float2int((float)non_zero_kw));
+ movq(xmm_tmp, tmp_gpr);
+ uni_vbroadcastss(vmm_tmp, xmm_tmp);
+ uni_vmulps(vmm_tmp, vmm_tmp, vmm_ker_area_h);
+ prev_kw = non_zero_kw;
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_pool_kernel_f32<isa>::avg_step(int ur_w, int pad_l,
+ int pad_r) {
+
+ int iw = jpp.iw;
+ int kw = jpp.kw;
+ int stride_w = jpp.stride_w;
+ int c_block = jpp.c_block;
+ Label kd_label, kh_label;
+
+ for (int jj = 0; jj < ur_w; jj++) {
+ if (jpp.is_backward) {
+ uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]);
+ maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r);
+ uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
+ } else {
+ uni_vpxor(vreg(jj), vreg(jj), vreg(jj));
+ }
+ }
+
+ if (jpp.simple_alg && jpp.ndims == 5) {
+ push(reg_input);
+ push(reg_output);
+ mov(aux_reg_input_d, reg_input);
+ mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
+ L(kd_label);
+ mov(aux_reg_input, aux_reg_input_d);
+ } else {
+ mov(aux_reg_input, reg_input);
+ }
+
+ xor_(kj, kj);
+ L(kh_label);
+ {
+ for (int ki = 0; ki < kw; ki++) {
+ int jj_start = nstl::max(0, pad_l - ki);
+ int jj_end = ur_w
+ - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
+ if (aux_input_offset > iw * c_block)
+ continue;
+ int input_offset = sizeof(float)*aux_input_offset;
+ if (jpp.is_backward) {
+ uni_vmovups(vreg(ur_w+jj),
+ ptr[aux_reg_input + input_offset]);
+ uni_vaddps(vreg(ur_w+jj), vreg(ur_w+jj), vreg(jj));
+ uni_vmovups(vmmword[aux_reg_input + input_offset],
+ vreg(ur_w+jj));
+ } else {
+ uni_vaddps(vreg(jj), vreg(jj),
+ ptr[aux_reg_input + input_offset]);
+ }
+ }
+ }
+ add(aux_reg_input, sizeof(float) * iw * c_block);
+ inc(kj);
+ cmp(kj, reg_kh);
+ jl(kh_label, T_NEAR);
+ }
+
+ if (jpp.simple_alg && jpp.ndims == 5)
+ {
+ add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block);
+ dec(ki);
+ cmp(ki, 0);
+ jg(kd_label, T_NEAR);
+ pop(reg_output);
+ pop(reg_input);
+ }
+
+ if (!jpp.is_backward) {
+ for (int jj = 0; jj < ur_w; jj++) {
+ maybe_recalculate_divisor(jj, ur_w, pad_l, pad_r);
+ uni_vdivps(vreg(jj), vreg(jj), vmm_tmp);
+ uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block],
+ vreg(jj));
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_pool_kernel_f32<isa>::max_step_fwd(int ur_w, int pad_l,
+ int pad_r) {
+ int iw = jpp.iw;
+ int kw = jpp.kw;
+ int stride_w = jpp.stride_w;
+ int c_block = jpp.c_block;
+ Label kd_label, kh_label;
+
+ mov(tmp_gpr, float2int(nstl::numeric_limits<float>::lowest()));
+ movq(xmm_tmp, tmp_gpr);
+ uni_vbroadcastss(vmm_tmp, xmm_tmp);
+
+ for (int jj = 0; jj < ur_w; jj++) {
+ uni_vmovups(vreg(jj), vmm_tmp);
+ if (jpp.is_training)
+ uni_vpxor(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(2*ur_w+jj));
+ }
+ if (jpp.is_training)
+ {
+ movq(xmm_tmp, reg_k_shift);
+ uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
+ }
+
+ if (jpp.ndims == 5) {
+ push(reg_input);
+ push(reg_output);
+ mov(aux_reg_input_d, reg_input);
+ mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
+ L(kd_label);
+ mov(aux_reg_input, aux_reg_input_d);
+ } else {
+ mov(aux_reg_input, reg_input);
+ }
+ xor_(kj, kj);
+ L(kh_label);
+ {
+ for (int ki = 0; ki < kw; ki++) {
+ int jj_start = nstl::max(0, pad_l - ki);
+ int jj_end = ur_w
+ - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
+ if (aux_input_offset > iw * c_block)
+ continue;
+ int input_offset = sizeof(float)*aux_input_offset;
+ uni_vmovups(vreg(ur_w+jj), ptr[aux_reg_input + input_offset]);
+ if (isa == sse42) {
+ movups(vmm_mask, vreg(jj));
+ cmpps(vmm_mask, vreg(ur_w+jj), _cmp_lt_os);
+ blendvps(vreg(jj), vreg(ur_w+jj));
+ if (jpp.is_training)
+ blendvps(vreg(2*ur_w+jj), vmm_k_offset);
+ } else if (isa == avx) {
+ vcmpps(vreg(3*ur_w+jj), vreg(jj), vreg(ur_w+jj),
+ _cmp_lt_os);
+ vblendvps(vreg(jj), vreg(jj), vreg(ur_w+jj),
+ vreg(3*ur_w+jj));
+ if (jpp.is_training)
+ vblendvps(vreg(2*ur_w+jj), vreg(2*ur_w+jj),
+ vmm_k_offset, vreg(3*ur_w+jj));
+ } else {
+ vcmpps(k_store_mask, vreg(jj), vreg(ur_w+jj), _cmp_lt_os);
+ vblendmps(vreg(jj) | k_store_mask, vreg(jj), vreg(ur_w+jj));
+ if (jpp.is_training)
+ vblendmps(vreg(2*ur_w+jj) | k_store_mask,
+ vreg(2*ur_w+jj), vmm_k_offset);
+ }
+ }
+ if (jpp.is_training) {
+ if (isa == avx && !mayiuse(avx2)) {
+ avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
+ } else {
+ uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
+ }
+ }
+ }
+ add(aux_reg_input, sizeof(float) * iw * c_block);
+ inc(kj);
+ cmp(kj, reg_kh);
+ jl(kh_label, T_NEAR);
+ }
+
+ if (jpp.ndims == 5)
+ {
+ add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block);
+ if (jpp.is_training) {
+ mov(tmp_gpr, ptr[reg_param + GET_OFF(kd_padding_shift)]);
+ movq(xmm_tmp, tmp_gpr);
+ uni_vpbroadcastd(vmm_tmp, xmm_tmp);
+ if (isa == avx && !mayiuse(avx2)) {
+ Xmm t(vmm_mask.getIdx());
+ avx_vpadd1(vmm_k_offset, xmm_tmp, t);
+ } else {
+ uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
+ }
+ }
+
+ dec(ki);
+ cmp(ki, 0);
+ jg(kd_label, T_NEAR);
+ pop(reg_output);
+ pop(reg_input);
+ }
+
+ for (int jj = 0; jj < ur_w; jj++) {
+ uni_vmovups(vmmword[reg_output + sizeof(float)*jj*c_block], vreg(jj));
+ if (jpp.is_training) {
+ const size_t step_index
+ = jj * c_block * types::data_type_size(jpp.ind_dt);
+
+ auto x = xreg(2 * ur_w + jj);
+ if (jpp.ind_dt == data_type::u8) {
+ if (isa == sse42) {
+ for (int i = 0; i < 4; ++i)
+ pextrb(ptr[reg_index + step_index + i], x, 4*i);
+ } else if (isa == avx) {
+ auto y = yreg(2 * ur_w + jj);
+ if (jj == 0) {
+ movd(xmm_tmp, reg_shuf_mask);
+ uni_vpbroadcastd(vmm_tmp, xmm_tmp);
+ }
+ if (mayiuse(avx2)) {
+ vpshufb(y, y, vmm_tmp);
+ movd(ptr[reg_index + step_index], x);
+ vperm2i128(y, y, y, 0x1u);
+ movd(ptr[reg_index + step_index + 4], x);
+ } else {
+ Xmm t(vmm_mask.getIdx());
+ vextractf128(t, y, 0);
+ vpshufb(t, t, xmm_tmp);
+ movd(ptr[reg_index + step_index], t);
+ vextractf128(t, y, 1);
+ vpshufb(t, t, xmm_tmp); // ymm_tmp[:128]==ymm_tmp[127:0]
+ movd(ptr[reg_index + step_index + 4], t);
+ }
+ } else {
+ auto v = vreg(2 * ur_w + jj);
+ vpmovusdb(x, v);
+ vmovups(ptr[reg_index + step_index], v | k_index_mask);
+ }
+ } else {
+ uni_vmovups(ptr[reg_index + step_index], vreg(2*ur_w+jj));
+ }
+ }
+ }
+}
+
+template <cpu_isa_t isa>
+inline void jit_uni_pool_kernel_f32<isa>::max_step_bwd(int ur_w, int pad_l,
+ int pad_r) {
+
+ int iw = jpp.iw;
+ int kw = jpp.kw;
+ int stride_w = jpp.stride_w;
+ int c_block = jpp.c_block;
+ Label kd_label, kh_label;
+
+ for (int jj = 0; jj < ur_w; jj++) {
+ uni_vmovups(vreg(jj), ptr[reg_output + sizeof(float)*jj*c_block]);
+
+ const size_t step_index
+ = jj * c_block * types::data_type_size(jpp.ind_dt);
+ if (jpp.ind_dt == data_type::u8) {
+ if (isa == sse42) {
+ movd(xreg(ur_w+jj), ptr[reg_index + step_index]);
+ pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
+ } else if (isa == avx) {
+ movq(xreg(ur_w+jj), ptr[reg_index + step_index]);
+ if (!mayiuse(avx2)) {
+ avx_pmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj), xmm_tmp);
+ } else {
+ vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
+ }
+ } else {
+ vmovups(vreg(ur_w+jj) | k_index_mask,
+ ptr[reg_index + step_index]);
+ vpmovzxbd(vreg(ur_w+jj), xreg(ur_w+jj));
+ }
+ } else {
+ uni_vmovups(vreg(ur_w+jj), ptr[reg_index + step_index]);
+ }
+ }
+ movq(xmm_tmp, reg_k_shift);
+ uni_vpbroadcastd(vmm_k_offset, xmm_tmp);
+
+ if (jpp.simple_alg && jpp.ndims == 5) {
+ push(reg_input);
+ push(reg_output);
+ if (isa == sse42) {
+ // Save rdi since it is used in maskmovdqu
+ assert(dst_ptr == rdi);
+ push(dst_ptr);
+ }
+ mov(aux_reg_input_d, reg_input);
+ mov(ki, ptr[reg_param + GET_OFF(kd_padding)]);
+ mov(reg_kd_pad_shift, ptr[reg_param + GET_OFF(kd_padding_shift)]);
+ L(kd_label);
+ mov(aux_reg_input, aux_reg_input_d);
+ } else {
+ mov(aux_reg_input, reg_input);
+ }
+
+ xor_(kj, kj);
+ L(kh_label);
+ {
+ for (int ki = 0; ki < kw; ki++) {
+ int jj_start = nstl::max(0, pad_l - ki);
+ int jj_end = ur_w
+ - utils::div_up(nstl::max(0, ki + pad_r - (kw-1)), stride_w);
+ for (int jj = jj_start; jj < jj_end; jj++) {
+ int aux_input_offset = (ki+jj*stride_w-pad_l)* c_block;
+ if (aux_input_offset > iw * c_block)
+ continue;
+ int input_offset = sizeof(float)*aux_input_offset;
+ uni_vmovups(vreg(2*ur_w+jj), ptr[aux_reg_input + input_offset]);
+ if (isa == sse42) {
+ mov(dst_ptr, aux_reg_input);
+ add(dst_ptr, input_offset);
+
+ movups(vreg(3*ur_w+jj), vreg(ur_w+jj));
+ pcmpeqd(vreg(3*ur_w+jj), vmm_k_offset);
+ addps(vreg(2*ur_w+jj), vreg(jj));
+ maskmovdqu(vreg(2*ur_w+jj), vreg(3*ur_w+jj));
+ } else if (isa == avx) {
+ if (mayiuse(avx2)) {
+ vpcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset);
+ } else {
+ avx_pcmpeqd(vreg(3*ur_w+jj), vreg(ur_w+jj), vmm_k_offset, xmm_tmp);
+ }
+ vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vreg(jj));
+ vmaskmovps(vmmword[aux_reg_input + input_offset],
+ vreg(3*ur_w+jj), vreg(2*ur_w+jj));
+ } else {
+ vpcmpeqd(k_store_mask, vreg(ur_w+jj), vmm_k_offset);
+ vblendmps(vmm_tmp | k_store_mask | T_z, vreg(jj), vreg(jj));
+ vaddps(vreg(2*ur_w+jj), vreg(2*ur_w+jj), vmm_tmp);
+ vmovups(vmmword[aux_reg_input +
+ sizeof(float)*aux_input_offset], vreg(2*ur_w+jj));
+ }
+ }
+ if (isa == avx && !mayiuse(avx2)) {
+ avx_vpadd1(vmm_k_offset, vmm_one, xmm_tmp);
+ } else {
+ uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_one);
+ }
+ }
+ add(aux_reg_input, sizeof(float) * iw * c_block);
+ inc(kj);
+ cmp(kj, reg_kh);
+ jl(kh_label, T_NEAR);
+ }
+ if (jpp.simple_alg && jpp.ndims == 5)
+ {
+ add(aux_reg_input_d, sizeof(float) * jpp.ih * iw * c_block);
+
+ mov(tmp_gpr, reg_kd_pad_shift);
+ movq(xmm_tmp, tmp_gpr);
+ uni_vpbroadcastd(vmm_tmp, xmm_tmp);
+ if (isa == avx && !mayiuse(avx2)) {
+ Xmm t(vmm_mask.getIdx());
+ avx_vpadd1(vmm_k_offset, vmm_tmp, t);
+ } else {
+ uni_vpaddd(vmm_k_offset, vmm_k_offset, vmm_tmp);
+ }
+
+ dec(ki);
+ cmp(ki, 0);
+ jg(kd_label, T_NEAR);
+ if (isa == sse42) {
+ // Save rdi since it is used in maskmovdqu
+ assert(dst_ptr == rdi);
+ pop(dst_ptr);
+ }
+ pop(reg_output);
+ pop(reg_input);
+ }
+}
+
+template <cpu_isa_t isa>
+void jit_uni_pool_kernel_f32<isa>::maybe_zero_diff_src() {
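+    // diff_src is accumulated into, so it has to be zeroed first; arg.oh is
+    // reused as a "needs zeroing" flag (2D) or as the number of input depth
+    // slices to clear (3D).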
+ assert(jpp.c_block * sizeof(float) % cpu_isa_traits<isa>::vlen == 0);
+ Label l_skip, l_zero;
+
+ auto reg_oh = tmp_gpr;
+ mov(reg_oh, ptr[reg_param + GET_OFF(oh)]);
+ cmp(reg_oh, 0);
+ jz(l_skip, T_NEAR);
+
+ if (jpp.ndims == 5) {
+ mov(zero_size, ptr[reg_param + GET_OFF(oh)]);
+ mov(tmp_gpr, jpp.ih * jpp.iw * jpp.c_block * sizeof(float));
+ imul(zero_size, tmp_gpr);
+ }
+
+ auto vzero = vmm_tmp;
+ uni_vpxor(vzero, vzero, vzero);
+
+ auto reg_off = tmp_gpr;
+ xor_(reg_off, reg_off);
+
+ L(l_zero);
+ {
+ const int dim = jpp.iw * jpp.c_block * sizeof(float);
+ for (int i = 0; i < dim; i += cpu_isa_traits<isa>::vlen)
+ uni_vmovups(ptr[reg_input + reg_off + i], vzero);
+ add(reg_off, dim);
+ if (jpp.ndims == 5) cmp(reg_off, zero_size);
+ else cmp(reg_off, jpp.ih * dim);
+ jl(l_zero, T_NEAR);
+ }
+
+ L(l_skip);
+}
+
+template <cpu_isa_t isa>
+void jit_uni_pool_kernel_f32<isa>::generate() {
+
+ this->preamble();
+
+ int ow = jpp.ow;
+ int iw = jpp.iw;
+ int kw = jpp.kw;
+ int kh = jpp.kh;
+ int ur_w = jpp.ur_w;
+ int c_block = jpp.c_block;
+ int stride_w = jpp.stride_w;
+ int l_pad = jpp.l_pad;
+ int ur_w_tail = jpp.ur_w_tail;
+
+ int n_oi = ow / ur_w;
+
+ prev_kw = 0;
+
+ int vlen = cpu_isa_traits<isa>::vlen;
+
+#if defined(_WIN32)
+    // Always mimic the Unix ABI (see the note about maskmovdqu in the header
+    // file). The three xors below swap rdi and rcx in place.
+ xor_(rdi, rcx);
+ xor_(rcx, rdi);
+ xor_(rdi, rcx);
+#endif
+
+ mov(reg_input, ptr[reg_param + GET_OFF(src)]);
+ mov(reg_output, ptr[reg_param + GET_OFF(dst)]);
+ if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
+ mov(reg_index, ptr[reg_param + GET_OFF(indices)]);
+ mov(reg_kh, ptr[reg_param + GET_OFF(kh_padding)]);
+ mov(reg_k_shift, ptr[reg_param + GET_OFF(kh_padding_shift)]);
+ mov(reg_ker_area_h, ptr[reg_param + GET_OFF(ker_area_h)]);
+
+ if (jpp.is_backward)
+ maybe_zero_diff_src();
+
+ if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) {
+ mov(tmp_gpr, 1);
+ movq(xmm_one, tmp_gpr);
+ uni_vpbroadcastd(vmm_one, xmm_one);
+
+ if (isa == avx) {
+ mov(reg_shuf_mask, 0x0c080400);
+ } else if (isa >= avx512_common) {
+ mov(tmp_gpr.cvt32(), 0x000f);
+ kmovw(k_index_mask, tmp_gpr.cvt32());
+ }
+ }
+
+ int r_pad = nstl::max(0, ((ow-1)*stride_w) + kw - 1 - (iw + l_pad - 1));
+ int r_pad1 = (ur_w*n_oi - 1)*stride_w + kw - 1 - (iw + l_pad - 1);
+ if (r_pad1 > 0) n_oi--;
+
+ if (jpp.alg == pooling_avg_exclude_padding) {
+ movq(xmm_ker_area_h, reg_ker_area_h);
+ uni_vpbroadcastd(vmm_ker_area_h, xmm_ker_area_h);
+ }
+
+ if (jpp.alg == pooling_avg_include_padding) {
+ mov(tmp_gpr, float2int((float)(kw * kh * jpp.kd)));
+ movq(xmm_tmp, tmp_gpr);
+ uni_vpbroadcastd(vmm_tmp, xmm_tmp);
+ }
+ if (l_pad > 0) {
+ n_oi--;
+ if (n_oi < 0 && r_pad1 > 0) {
+ step(ur_w, l_pad, r_pad1);
+ } else {
+ step(ur_w, l_pad, 0);
+ }
+
+ if (isa == sse42) {
+ if (n_oi < 0 && r_pad1 > 0) {
+ step_high_half(ur_w, l_pad, r_pad1);
+ } else {
+ step_high_half(ur_w, l_pad, 0);
+ }
+ }
+
+ if (isa == sse42) {
+ add(reg_input, sizeof(float)*(ur_w*stride_w-l_pad)*c_block - vlen);
+ add(reg_output, sizeof(float)*ur_w*c_block - vlen);
+ if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
+ add(reg_index, (2 * ur_w - 1) * c_block / 2
+ * types::data_type_size(jpp.ind_dt));
+ } else {
+ add(reg_input, sizeof(float)*(ur_w*stride_w - l_pad)*c_block);
+ add(reg_output, sizeof(float)*ur_w*c_block);
+ if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
+ add(reg_index, ur_w * c_block
+ * types::data_type_size(jpp.ind_dt));
+ }
+ }
+
+ xor_(oi_iter, oi_iter);
+ if (n_oi > 0) {
+ Label ow_loop;
+ L(ow_loop); {
+ step(ur_w, 0, 0);
+
+ if (isa == sse42) {
+ step_high_half(ur_w, 0, 0);
+ }
+
+ if (isa == sse42) {
+ add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen);
+ add(reg_output, sizeof(float)*ur_w*c_block - vlen);
+ if (jpp.alg == pooling_max &&
+ (jpp.is_training || jpp.is_backward))
+ add(reg_index, (2 * ur_w - 1) * c_block / 2
+ * types::data_type_size(jpp.ind_dt));
+ } else {
+ add(reg_input, sizeof(float)*ur_w*stride_w*c_block);
+ add(reg_output, sizeof(float)*ur_w*c_block);
+ if (jpp.alg == pooling_max &&
+ (jpp.is_training || jpp.is_backward))
+ add(reg_index, ur_w * c_block
+ * types::data_type_size(jpp.ind_dt));
+ }
+
+ inc(oi_iter);
+ cmp(oi_iter, n_oi);
+ jl(ow_loop, T_NEAR);
+ }
+ }
+
+ if (r_pad1 > 0 && n_oi >= 0) {
+ step(ur_w, 0, r_pad1);
+
+ if (isa == sse42) {
+ step_high_half(ur_w, 0, r_pad1);
+ }
+
+ if (isa == sse42) {
+ add(reg_input, sizeof(float)*ur_w*stride_w*c_block - vlen);
+ add(reg_output, sizeof(float)*ur_w*c_block - vlen);
+ if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
+ add(reg_index, (2 * ur_w - 1) * c_block / 2
+ * types::data_type_size(jpp.ind_dt));
+ } else {
+ add(reg_input, sizeof(float)*ur_w*stride_w*c_block);
+ add(reg_output, sizeof(float)*ur_w*c_block);
+ if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward))
+ add(reg_index, ur_w * c_block
+ * types::data_type_size(jpp.ind_dt));
+ }
+ }
+
+ if (ur_w_tail != 0) {
+ step(ur_w_tail, 0, r_pad);
+
+ if (isa == sse42) {
+ step_high_half(ur_w_tail, 0, r_pad);
+ }
+ }
+
+ this->postamble();
+}
+
+template struct jit_uni_pool_kernel_f32<sse42>;
+template struct jit_uni_pool_kernel_f32<avx>; // implements both <avx> and <avx2>
+template struct jit_uni_pool_kernel_f32<avx512_common>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp
new file mode 100644
index 0000000000..992b526587
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.hpp
@@ -0,0 +1,192 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+* Copyright 2018 YANDEX LLC
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_UNI_POOL_KERNEL_F32_HPP
+#define JIT_UNI_POOL_KERNEL_F32_HPP
+
+#include <cfloat>
+
+#include "c_types_map.hpp"
+#include "pooling_pd.hpp"
+#include "type_helpers.hpp"
+
+#include "jit_generator.hpp"
+#include "jit_primitive_conf.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace Xbyak;
+
+template <cpu_isa_t isa>
+struct jit_uni_pool_kernel_f32: public jit_generator {
+ jit_uni_pool_kernel_f32(jit_pool_conf_t ajpp): jpp(ajpp)
+ {
+ this->generate();
+ jit_ker = (decltype(jit_ker))this->getCode();
+ }
+
+ jit_pool_conf_t jpp;
+
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel_f32)
+
+ void operator()(jit_pool_call_s *arg) { jit_ker(arg); }
+ static status_t init_conf(jit_pool_conf_t &jbp, const pooling_pd_t *ppd);
+
+private:
+ using Vmm = typename utils::conditional3<isa == sse42, Xmm, isa == avx,
+ Ymm, Zmm>::type;
+ Xmm xreg(int idx) { return Xmm((isa == avx512_common ? 31 : 15) - idx); }
+ Ymm yreg(int idx) { return Ymm(xreg(idx).getIdx()); }
+ Vmm vreg(int idx) { return Vmm(xreg(idx).getIdx()); }
+
+ const AddressFrame &vmmword = (isa == sse42) ? xword :
+ (isa == avx) ? yword : zword;
+
+ Xmm vmm_mask = Xmm(0);
+ Xmm xmm_ker_area_h = Xmm(2);
+ Xmm xmm_one = Xmm(2);
+ Xmm xmm_tmp = Xmm(3);
+
+ Vmm vmm_ker_area_h = Vmm(2);
+ Vmm vmm_one = Vmm(2);
+ Vmm vmm_tmp = Vmm(3);
+
+ Vmm vmm_k_offset = Vmm(1);
+
+ Opmask k_index_mask = Opmask(6);
+ Opmask k_store_mask = Opmask(7);
+
+ // Here be some (tame) dragons. This kernel does not follow the regular
+ // OS-agnostic ABI pattern because when isa is sse42 it uses maskmovdqu
+ // instruction which has its destination hardcoded in rdi. Therefore:
+ // - all registers are hardcoded
+ // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI
+ //
+ // While this is only required by the backward pass, the quirk above
+ // is applied to the forward pass as well to keep things simpler.
+
+ using reg64_t = const Xbyak::Reg64;
+ reg64_t reg_param = rdi; // Always mimic the Unix ABI
+ reg64_t reg_input = r8;
+ reg64_t aux_reg_input = r9;
+ reg64_t reg_index = r10;
+ reg64_t reg_output = r12;
+ reg64_t reg_kd_pad_shift = r13;
+ reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu
+
+ reg64_t kj = r14;
+ reg64_t oi_iter = r15;
+ reg64_t reg_kh = rax;
+ reg64_t reg_k_shift = rbx;
+ reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above
+ reg64_t reg_ker_area_h = rdx;
+
+ reg64_t zero_size = r15;
+ reg64_t ki = r12;
+ reg64_t aux_reg_input_d = r8;
+
+ Xbyak::Reg32 reg_shuf_mask = esi;
+
+ int prev_kw;
+ void (*jit_ker)(jit_pool_call_s *);
+
+ void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r);
+ void avg_step(int ur_w, int pad_l, int pad_r);
+ void max_step_fwd(int ur_w, int pad_l, int pad_r);
+ void max_step_bwd(int ur_w, int pad_l, int pad_r);
+
+ void maybe_zero_diff_src();
+
+ void step(int ur_w, int pad_l, int pad_r) {
+ if (jpp.alg == alg_kind::pooling_max) {
+ if(jpp.is_backward)
+ max_step_bwd(ur_w, pad_l, pad_r);
+ else
+ max_step_fwd(ur_w, pad_l, pad_r);
+ }
+ else
+ avg_step(ur_w, pad_l, pad_r);
+ }
+
+ void step_high_half(int ur_w, int pad_l, int pad_r) {
+ add(reg_input, sizeof(float) * 4);
+ add(reg_output, sizeof(float) * 4);
+ if (jpp.alg == alg_kind::pooling_max &&
+ (jpp.is_training || jpp.is_backward))
+ add(reg_index, types::data_type_size(jpp.ind_dt) * 4);
+
+ step(ur_w, pad_l, pad_r);
+ }
+
+ void generate();
+
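+    // AVX1 has no 256-bit integer add: emulate vpaddd on a Ymm by adding x1 to each 128-bit half.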
+ void avx_vpadd1(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) {
+ assert(y0.getIdx() != x1.getIdx());
+ vextractf128(xtmp, y0, 0);
+ vpaddd(xtmp, xtmp, x1);
+ vinsertf128(y0, y0, xtmp, 0);
+ vextractf128(xtmp, y0, 1);
+ vpaddd(xtmp, xtmp, x1);
+ vinsertf128(y0, y0, xtmp, 1);
+ }
+
+ void avx_vpadd1(const Xmm& x0, const Xmm& x1, const Xmm&) {
+ assert(false /*function should not be used*/);
+ paddd(x0, x1);
+ }
+
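+    // Emulate vpmovzxbd with a Ymm destination on AVX1: bytes 0..3 go to the low half, bytes 4..7 to the high half.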
+ void avx_pmovzxbd(const Ymm& y0, const Xmm& x1, const Xmm& xtmp) {
+ Xmm x0(y0.getIdx());
+ pshufd(xmm_tmp, x1, 1);
+ pmovzxbd(x0, x1);
+ pmovzxbd(xmm_tmp, xmm_tmp);
+ vinsertf128(y0, y0, xmm_tmp, 1);
+ }
+
+ void avx_pmovzxbd(const Xmm& x0, const Xmm& x1, const Xmm&) {
+ assert(false /*function should not be used*/);
+ pmovzxbd(x0, x1);
+ }
+
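+    // Emulate vpcmpeqd on Ymm registers for AVX1 by comparing the 128-bit halves with SSE pcmpeqd.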
+ void avx_pcmpeqd(const Ymm& y0, const Ymm& y1, const Ymm& y2, const Xmm& xtmp) {
+ assert(y0.getIdx() != y1.getIdx());
+ assert(y0.getIdx() != y2.getIdx());
+ Xmm x0(y0.getIdx());
+ Xmm x2(y2.getIdx());
+ vextractf128(x0, y1, 1);
+ vextractf128(xtmp, y2, 1);
+ pcmpeqd(xtmp, x0);
+ vextractf128(x0, y1, 0);
+ pcmpeqd(x0, x2);
+ vinsertf128(y0, y0, xtmp, 1);
+ }
+
+ void avx_pcmpeqd(const Xmm& x0, const Xmm& x1, const Xmm&, const Xmm&) {
+ assert(false /*function should not be used*/);
+ pcmpeqd(x0, x1);
+ }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp
new file mode 100644
index 0000000000..afbcf996d8
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.cpp
@@ -0,0 +1,264 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "nstl.hpp"
+
+#include "jit_uni_pooling.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <cpu_isa_t isa>
+void jit_uni_pooling_fwd_t<isa>::execute_forward(const data_t *src,
+ data_t *dst, char *indices) const {
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper indices_d(pd()->workspace_md());
+ const size_t ind_dt_size = indices
+ ? types::data_type_size(indices_d.data_type()) : 0;
+
+ const auto &jpp = pd()->jpp_;
+
+ auto ker = [&](int n, int b_c, int oh) {
+ auto arg = jit_pool_call_s();
+
+ const int ij = oh * jpp.stride_h;
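+        // i_t_overflow / i_b_overflow: kernel rows that fall into the top /
+        // bottom padding for this output row; ih is the first valid input row.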
+ const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
+ const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
+ const int ih = nstl::max(ij - jpp.t_pad, 0);
+
+ arg.src = &src[src_d.blk_off(n, b_c, ih)];
+ arg.dst = &dst[dst_d.blk_off(n, b_c, oh)];
+ if (indices) {
+ const size_t ind_off = indices_d.blk_off(n, b_c, oh);
+ arg.indices = &indices[ind_off * ind_dt_size];
+ }
+ arg.oh = oh == 0;
+ arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
+ arg.kh_padding_shift = i_t_overflow*jpp.kw;
+ arg.kw_padding = 0;
+ arg.ker_area_h = (float)(jpp.kh -
+ nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
+ nstl::max(0, jpp.t_pad - oh*jpp.stride_h));
+ (*kernel_)(&arg);
+ };
+
+ parallel_nd(jpp.mb, jpp.nb_c, jpp.oh,
+ [&](int n, int b_c, int oh) {
+ ker(n, b_c, oh);
+ });
+}
+
+template <cpu_isa_t isa>
+void jit_uni_pooling_fwd_t<isa>::execute_forward_3d(const data_t *src,
+ data_t *dst, char *indices) const {
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper indices_d(pd()->workspace_md());
+ const size_t ind_dt_size = indices
+ ? types::data_type_size(indices_d.data_type()) : 0;
+
+ const auto &jpp = pd()->jpp_;
+
+ auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
+ int d_b_overflow) {
+ auto arg = jit_pool_call_s();
+
+ const int ij = oh * jpp.stride_h;
+ const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
+ const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
+ const int ih = nstl::max(ij - jpp.t_pad, 0);
+
+ arg.src = &src[src_d.blk_off(n, b_c, id, ih)];
+ arg.dst = &dst[dst_d.blk_off(n, b_c, od, oh)];
+ if (indices) {
+ const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
+ arg.indices = &indices[ind_off * ind_dt_size];
+ }
+ arg.oh = (oh + od == 0);
+ arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
+ arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
+ arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh;
+ arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
+ arg.kw_padding = 0;
+ arg.ker_area_h = (float)(jpp.kh -
+ nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
+ nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
+ nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
+ nstl::max(0, jpp.f_pad - od*jpp.stride_d));
+
+
+ (*kernel_)(&arg);
+ };
+
+ parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
+ [&](int n, int b_c, int od) {
+ const int ik = od * jpp.stride_d;
+ const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
+ const int d_b_overflow = nstl::max(jpp.id, ik+jpp.kd-jpp.f_pad)
+ -jpp.id;
+ const int id = nstl::max(ik - jpp.f_pad, 0);
+ for (int oh = 0; oh < jpp.oh; ++oh) {
+ ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow);
+ }
+ });
+}
+
+template <cpu_isa_t isa>
+void jit_uni_pooling_bwd_t<isa>::execute_backward(const data_t *diff_dst,
+ const char *indices, data_t *diff_src) const {
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper indices_d(pd()->workspace_md());
+ const size_t ind_dt_size = indices
+ ? types::data_type_size(indices_d.data_type()) : 0;
+
+ const auto &jpp = pd()->jpp_;
+
+ auto ker = [&](int n, int b_c, int oh) {
+ auto arg = jit_pool_call_s();
+
+ const int ij = oh * jpp.stride_h;
+ const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
+ const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
+ const int ih = nstl::max(ij - jpp.t_pad, 0);
+
+ arg.src = &diff_src[diff_src_d.blk_off(n, b_c, ih)];
+ arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, oh)];
+ if (indices) {
+ const size_t ind_off = indices_d.blk_off(n, b_c, oh);
+ arg.indices = &indices[ind_off * ind_dt_size];
+ }
+ arg.oh = (oh == 0);
+ arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
+ arg.kh_padding_shift = i_t_overflow*jpp.kw;
+ arg.kw_padding = 0;
+ arg.ker_area_h = (float)(jpp.kh -
+ nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
+ nstl::max(0, jpp.t_pad - oh*jpp.stride_h));
+
+ (*kernel_)(&arg);
+ };
+
+ parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) {
+ for (int oh = 0; oh < jpp.oh; ++oh) {
+ ker(n, b_c, oh);
+ }
+ });
+}
+
+template <cpu_isa_t isa>
+void jit_uni_pooling_bwd_t<isa>::execute_backward_3d(const data_t *diff_dst,
+ const char *indices, data_t *diff_src) const {
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper indices_d(pd()->workspace_md());
+ const size_t ind_dt_size = indices
+ ? types::data_type_size(indices_d.data_type()) : 0;
+
+ const auto &jpp = pd()->jpp_;
+
+ auto ker = [&](int n, int b_c, int od, int oh, int id, int d_t_overflow,
+ int d_b_overflow, int zero_size, int kd) {
+ auto arg = jit_pool_call_s();
+
+ const int ij = oh * jpp.stride_h;
+ const int i_t_overflow = nstl::max(0, jpp.t_pad-ij);
+ const int i_b_overflow = nstl::max(jpp.ih, ij+jpp.kh-jpp.t_pad)-jpp.ih;
+ const int ih = nstl::max(ij - jpp.t_pad, 0);
+
+ arg.src = &diff_src[diff_src_d.blk_off(n, b_c, id + kd, ih)];
+ arg.dst = &diff_dst[diff_dst_d.blk_off(n, b_c, od, oh)];
+ if (indices) {
+ const size_t ind_off = indices_d.blk_off(n, b_c, od, oh);
+ arg.indices = &indices[ind_off * ind_dt_size];
+ }
+ arg.oh = zero_size;
+ arg.kd_padding = jpp.kd - d_t_overflow - d_b_overflow;
+ arg.kh_padding = jpp.kh - i_t_overflow - i_b_overflow;
+ arg.kh_padding_shift = i_t_overflow*jpp.kw + d_t_overflow*jpp.kw*jpp.kh
+ + kd * jpp.kw * jpp.kh;
+ arg.kd_padding_shift = (i_t_overflow + i_b_overflow)*jpp.kw;
+ arg.kw_padding = 0;
+ arg.ker_area_h = (float)(jpp.kh -
+ nstl::max(0, oh*jpp.stride_h - jpp.t_pad + jpp.kh - jpp.ih) -
+ nstl::max(0, jpp.t_pad - oh*jpp.stride_h)) * (jpp.kd -
+ nstl::max(0, od*jpp.stride_d - jpp.f_pad + jpp.kd - jpp.id) -
+ nstl::max(0, jpp.f_pad - od*jpp.stride_d));
+
+ (*kernel_)(&arg);
+ };
+
+ if (jpp.simple_alg) {
+
+ parallel_nd(jpp.mb, jpp.nb_c, jpp.od,
+ [&](int n, int b_c, int od) {
+ const int ik = od * jpp.stride_d;
+ const int d_t_overflow = nstl::max(0, jpp.f_pad - ik);
+ const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
+ - jpp.f_pad) - jpp.id;
+ const int id = nstl::max(ik - jpp.f_pad, 0);
+ int zero_s = jpp.stride_d - d_t_overflow - (nstl::max(
+ jpp.id, ik + jpp.stride_d - jpp.f_pad) - jpp.id);
+ for (int oh = 0; oh < jpp.oh; ++oh) {
+ ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
+ (oh == 0) ? zero_s : 0, 0);
+ }
+ });
+ } else {
+ ptrdiff_t nelems = (ptrdiff_t)jpp.mb * (ptrdiff_t)jpp.c
+ * (ptrdiff_t)jpp.id * (ptrdiff_t)jpp.ih * (ptrdiff_t)jpp.iw;
+
+ parallel_nd(nelems, [&](ptrdiff_t i) { diff_src[i] = 0.f; });
+
+ for (int kd = 0; kd < jpp.kd; ++kd) {
+ parallel_nd(jpp.mb, jpp.nb_c, [&](int n, int b_c) {
+ for (int od = 0; od < jpp.od; ++od) {
+ const int ik = od * jpp.stride_d;
+ const int d_t_overflow = nstl::max(0, jpp.f_pad-ik);
+ const int d_b_overflow = nstl::max(jpp.id, ik + jpp.kd
+ - jpp.f_pad) - jpp.id;
+ if (kd >= jpp.kd - d_t_overflow - d_b_overflow)
+ continue;
+ const int id = nstl::max(ik - jpp.f_pad, 0);
+ for (int oh = 0; oh < jpp.oh; ++oh) {
+ ker(n, b_c, od, oh, id, d_t_overflow, d_b_overflow,
+ 0, kd);
+ }
+ }
+ });
+ }
+ }
+}
+
+
+template struct jit_uni_pooling_fwd_t<sse42>;
+template struct jit_uni_pooling_bwd_t<sse42>;
+template struct jit_uni_pooling_fwd_t<avx>;
+template struct jit_uni_pooling_bwd_t<avx>;
+template struct jit_uni_pooling_fwd_t<avx512_common>;
+template struct jit_uni_pooling_bwd_t<avx512_common>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp
new file mode 100644
index 0000000000..57bebacdee
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_pooling.hpp
@@ -0,0 +1,182 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_JIT_UNI_POOLING_HPP
+#define CPU_JIT_UNI_POOLING_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_pooling_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "jit_uni_pool_kernel_f32.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <cpu_isa_t isa>
+struct jit_uni_pooling_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_pooling_fwd_pd_t {
+ using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+ jit_uni_pooling_fwd_t<isa>);
+
+ status_t init() {
+ using namespace utils;
+
+ bool ok = true
+ && set_default_params() == status::success
+ && is_fwd()
+ && !has_zero_dim_memory()
+ && everyone_is(data_type::f32,
+ src_md()->data_type,
+ dst_md()->data_type)
+ && attr()->has_default_values()
+ && memory_desc_matches_tag(*src_md(), desired_fmt_tag())
+ && memory_desc_matches_tag(*dst_md(), desired_fmt_tag());
+ if (!ok) return status::unimplemented;
+
+ bool is_training = desc_.prop_kind == prop_kind::forward_training;
+ if (desc()->alg_kind == alg_kind::pooling_max && is_training)
+ init_default_ws();
+
+ return jit_uni_pool_kernel_f32<isa>::init_conf(jpp_, this);
+ }
+
+ format_tag_t desired_fmt_tag() {
+ using namespace format_tag;
+ return ndims() == 4
+ ? isa == avx512_common ? nChw16c : nChw8c
+ : isa == avx512_common ? nCdhw16c : nCdhw8c;
+ }
+
+ jit_pool_conf_t jpp_;
+ };
+
+ jit_uni_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
+ { kernel_ = new jit_uni_pool_kernel_f32<isa>(pd()->jpp_); }
+
+ ~jit_uni_pooling_fwd_t() { delete kernel_; }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE);
+
+ if (pd()->ndims() == 5)
+ execute_forward_3d(src, dst, ws);
+ else
+ execute_forward(src, dst, ws);
+
+ return status::success;
+ }
+
+private:
+ void execute_forward(const data_t *src, data_t *dst, char *indices) const;
+ void execute_forward_3d(const data_t *src, data_t *dst,
+ char *indices) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ jit_uni_pool_kernel_f32<isa> *kernel_;
+};
+
+template <cpu_isa_t isa>
+struct jit_uni_pooling_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_pooling_bwd_pd_t {
+ using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T(
+ JIT_IMPL_NAME_HELPER("jit:", isa, ""),
+ jit_uni_pooling_bwd_t<isa>);
+
+ status_t init() {
+ using namespace utils;
+
+ bool ok = true
+ && set_default_params() == status::success
+ && !is_fwd()
+ && !has_zero_dim_memory()
+ && everyone_is(data_type::f32,
+ diff_src_md()->data_type,
+ diff_dst_md()->data_type)
+ && attr()->has_default_values()
+ && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag())
+ && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag());
+ if (!ok) return status::unimplemented;
+
+ if (desc()->alg_kind == alg_kind::pooling_max) {
+ init_default_ws();
+ if (!compare_ws(hint_fwd_pd_))
+ return status::unimplemented;
+ }
+
+ return jit_uni_pool_kernel_f32<isa>::init_conf(jpp_, this);
+ }
+
+ format_tag_t desired_fmt_tag() {
+ using namespace format_tag;
+        return ndims() == 4
+ ? isa == avx512_common ? nChw16c : nChw8c
+ : isa == avx512_common ? nCdhw16c : nCdhw8c;
+ }
+
+ jit_pool_conf_t jpp_;
+ };
+
+ jit_uni_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd)
+ { kernel_ = new jit_uni_pool_kernel_f32<isa>(pd()->jpp_); }
+
+ ~jit_uni_pooling_bwd_t() { delete kernel_; }
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto ws = CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ if (pd()->ndims() == 5)
+ execute_backward_3d(diff_dst, ws, diff_src);
+ else
+ execute_backward(diff_dst, ws, diff_src);
+
+ return status::success;
+ }
+
+private:
+ void execute_backward(const data_t *diff_dst, const char *indices,
+ data_t *diff_src) const;
+ void execute_backward_3d(const data_t *diff_dst, const char *indices,
+ data_t *diff_src) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ jit_uni_pool_kernel_f32<isa> *kernel_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp
new file mode 100644
index 0000000000..98796503b7
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.cpp
@@ -0,0 +1,1006 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_desc_wrapper.hpp"
+#include "mkldnn_debug.h"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+
+#include "cpu_primitive.hpp"
+#include "cpu_reorder_pd.hpp"
+#include "jit_uni_reorder.hpp"
+
+#include "jit_generator.hpp"
+
+// #define TR_DEBUG
+#if defined(TR_DEBUG)
+#define DEBUg(...) do { __VA_ARGS__ } while (0)
+#else
+#define DEBUg(...)
+#endif
+#define DEBUG(...) DEBUg(__VA_ARGS__)
+
+#ifdef _WIN32
+/* seems like s_addr is a reserved macro on Windows */
+#undef s_addr
+#endif
+
+using namespace Xbyak;
+using namespace mkldnn::impl::types;
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace tr {
+
+/** Minimal reasonable/desirable kernel size.
+ * The constant may be used to decide how a problem is split between the
+ * kernel and the threading driver. */
+const size_t ker_prb_size_min = 64;
+
+/* kernel */
+struct jit_uni_reorder_kernel_f32: public kernel_t, public jit_generator {
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_reorder_kernel_f32)
+
+ enum {
+ len_unroll_max = 256,
+ ndims_jit_loop_max = 3,
+ };
+
+ struct simple_impl_desc_t {
+ int ndims_full_unroll;
+ int len_last_dim_unroll;
+ int len_unroll;
+ };
+
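+    // Greedily folds the leading (innermost) dims into one fully unrolled loop
+    // of at most len_unroll_max iterations; whatever is left is handled by
+    // runtime loops, of which no more than ndims_jit_loop_max are allowed.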
+ static bool simple_impl_desc_init(const prb_t &prb,
+ simple_impl_desc_t *desc) {
+ const int ndims = prb.ndims;
+
+ int ndims_full_unroll = 0;
+ int len_last_dim_unroll = 1;
+ int len_unroll = 1;
+
+ for (int d = 0; d < ndims; ++d) {
+ auto &node = prb.nodes[d];
+ if (len_unroll * node.n <= len_unroll_max) {
+ ndims_full_unroll++;
+ len_unroll *= node.n;
+ } else {
+ len_last_dim_unroll = len_unroll_max / len_unroll;
+ while (node.n % len_last_dim_unroll)
+ --len_last_dim_unroll;
+ len_unroll *= len_last_dim_unroll;
+ break;
+ }
+ }
+
+ if (prb.ndims - ndims_full_unroll > ndims_jit_loop_max)
+ return false;
+
+ if (desc) {
+ desc->ndims_full_unroll = ndims_full_unroll;
+ desc->len_last_dim_unroll = len_last_dim_unroll;
+ desc->len_unroll = len_unroll;
+ }
+
+ return true;
+ }
+
+ static bool applicable(const prb_t &p) {
+ using namespace data_type;
+
+ bool ok = true
+ && p.ndims > 0
+ && utils::one_of(p.itype, f32, s32, s8, u8)
+ && utils::one_of(p.otype, f32, s32, s8, u8)
+ && utils::everyone_is(0, p.ioff, p.ooff) /* do we need this? */
+ && utils::one_of(p.beta, 0.f, 1.f) /* anything else? */
+ && simple_impl_desc_init(p, nullptr)
+ && mayiuse(sse42)
+ && IMPLICATION(!utils::everyone_is(f32, p.itype, p.otype),
+ mayiuse(avx));
+ if (!ok) return false;
+
+ const ptrdiff_t max_stride = (1LL<<31) - 1;
+ for (int d = 0; d < p.ndims; ++d) {
+ const ptrdiff_t cms = max_stride / p.nodes[d].n;
+ bool strides_ok = true
+ && p.nodes[d].is < cms / (int)data_type_size(p.itype)
+ && p.nodes[d].os < cms / (int)data_type_size(p.otype);
+ if (!strides_ok) return false;
+ }
+
+ return true;
+ }
+
+ int n(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].n; }
+ int is(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].is; }
+ int os(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].os; }
+ int ss(int d) { assert(d < prb_.ndims); return (int)prb_.nodes[d].ss; }
+
+ Address i_addr(int i_off)
+ { return ptr[reg_ptr_in + reg_off_in + i_off * itype_sz]; }
+
+ Address o_addr(int o_off)
+ { return ptr[reg_ptr_out + reg_off_out + o_off * otype_sz]; }
+
+ Address s_addr(int s_off)
+ { return ptr[reg_ptr_scale + reg_off_scale + s_off * stype_sz]; }
+
+ void step(int off, int prev_i_off, int prev_o_off, int prev_s_off,
+ int &i_off, int &o_off, int &s_off, int step_size = 1) {
+ i_off = prev_i_off;
+ o_off = prev_o_off;
+ s_off = prev_s_off;
+
+ if (off == 0) return;
+
+ int start_dim = 0, dims_prod = 1;
+ for (; start_dim < prb_.ndims && dims_prod != step_size; ++start_dim)
+ dims_prod *= n(start_dim);
+ assert(start_dim < prb_.ndims);
+ off /= step_size;
+
+ for (int d = start_dim; d < prb_.ndims; ++d) {
+ i_off += is(d);
+ o_off += os(d);
+ s_off += ss(d);
+
+ if (off % n(d)) break;
+
+ i_off += - n(d) * is(d);
+ o_off += - n(d) * os(d);
+ s_off += - n(d) * ss(d);
+ off /= n(d);
+
+ if (off == 0) break; /* FIXME: is it really required? */
+ }
+ }
+
+ void step(int off, int prev_i_off, int prev_o_off, int &i_off, int &o_off,
+ int step_size = 1) {
+ int dummy = 0;
+ step(off, prev_i_off, prev_o_off, dummy, i_off, o_off, dummy,
+ step_size);
+ }
+
+ void tr8x8_avx2(int i_off, int o_off) {
+ for (int i = 0; i < 8; i++)
+ vmovups(Ymm(i), i_addr(i_off + i * 8));
+
+ for (int i = 0; i < 8 / 2; i++) {
+ vunpcklps(Ymm(8 + i), Ymm(2 * i), Ymm(2 * i + 1));
+ vunpckhps(Ymm(i), Ymm(2 * i), Ymm(2 * i + 1));
+ }
+
+ const unsigned int lfloat = 0x44;
+ const unsigned int ufloat = 0xee;
+ for (int i = 0; i < 8 / 2; i++) {
+ int j = i % 2 == 0 ? 8 + i : i - 1;
+ vshufps(Ymm(8 / 2 + 2 * i), Ymm(j), Ymm(j + 1), lfloat);
+ vshufps(Ymm(8 / 2 + 2 * i + 1), Ymm(j), Ymm(j + 1), ufloat);
+ }
+
+ const unsigned int lquad = 0x20;
+ for (int i = 0; i < 8 / 2; i++)
+ vperm2f128(Ymm(i), Ymm(8 / 2 + i), Ymm(8 + i), lquad);
+
+ const unsigned int uquad = 0x31;
+ for (int i = 8 / 2; i < 8; i++)
+ vperm2f128(Ymm(i), Ymm(i), Ymm(8 / 2 + i), uquad);
+
+ for (int i = 0; i < 8; i++)
+ vmovups(o_addr(o_off + i * 8), Ymm(i));
+ }
+
+ bool process_unroll_tr8x8(int len) {
+ bool can_do = true
+ && mayiuse(avx2)
+ && prb_.ndims >= 2
+ && utils::everyone_is(4, itype_sz, otype_sz)
+ && utils::everyone_is(8, n(0), n(1))
+ && utils::everyone_is(1, os(0), is(1))
+ && utils::everyone_is(8, os(1), is(0))
+ && prb_.scale_type == scale_type_t::NONE
+ && prb_.beta == 0.f;
+ if (!can_do) return false;
+
+ const int step_size = n(0) * n(1);
+ int i_off = 0, o_off = 0;
+ for (int off = 0; off < len; off += step_size) {
+ step(off, i_off, o_off, i_off, o_off, step_size);
+ tr8x8_avx2(i_off, o_off);
+ }
+
+ return true;
+ }
+
+ template <cpu_isa_t isa>
+ bool process_direct_copy(int len) {
+ using namespace data_type;
+
+ using Vmm = typename cpu_isa_traits<isa>::Vmm;
+ const int simd_w = cpu_isa_traits<isa>::vlen / itype_sz;
+
+ bool can_do = true
+ && mayiuse(isa)
+ && utils::everyone_is(1, os(0), is(0))
+ && (false
+ || prb_.itype == prb_.otype
+ || (prb_.itype == s32 && prb_.otype == f32)
+ || (prb_.itype == f32 && prb_.otype == s32)
+ )
+ && len % simd_w == 0
+ && n(0) % len == 0
+ && prb_.scale_type == scale_type_t::NONE
+ && prb_.beta == 0.f;
+ if (!can_do) return false;
+
+ for (int off = 0; off < len;) {
+ const int unroll = nstl::min(16, (len - off) / simd_w);
+
+ for (int ur = 0; ur < unroll; ++ur)
+ uni_vmovups(Vmm(ur), i_addr(off + ur * simd_w));
+
+ if (prb_.itype != prb_.otype) {
+ for (int ur = 0; ur < unroll; ++ur) {
+ if (prb_.itype == s32 && prb_.otype == f32)
+ uni_vcvtdq2ps(Vmm(ur), Vmm(ur));
+ else if (prb_.itype == f32 && prb_.otype == s32)
+ uni_vcvtps2dq(Vmm(ur), Vmm(ur));
+ else assert(!"unreachable");
+ }
+ }
+
+ for (int ur = 0; ur < unroll; ++ur)
+ uni_vmovups(o_addr(off + ur * simd_w), Vmm(ur));
+
+ off += unroll * simd_w;
+ }
+
+ return true;
+ }
+
+ void process_unroll_generic_step(int reg_unroll, const int *i_off,
+ const int *o_off, const int *s_off) {
+ using namespace data_type;
+
+ auto cvt2ps = [=](const Xmm &dst, const Operand &src, data_type_t idt) {
+ Xmm dst_pure = Xmm(dst.getIdx());
+ switch (idt) {
+ case f32:
+ if (src.isMEM() || src.getIdx() != dst.getIdx())
+ vmovups(dst, src);
+ break;
+ case s32: vcvtdq2ps(dst, src); break;
+ case s8: vpmovsxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
+ case u8: vpmovzxbd(dst, src); vcvtdq2ps(dst_pure, dst); break;
+ default: assert(!"unreachable");
+ }
+ };
+
+ auto cvt2int = [=](const Xmm &xmm, data_type_t odt, data_type_t idt) {
+ switch (odt) {
+ case s32:
+ if (idt == f32) vcvtps2dq(xmm, xmm);
+ else if (idt == s8) vpmovsxbd(xmm, xmm);
+ else if (idt == u8) vpmovzxbd(xmm, xmm);
+ break;
+ case s8:
+ if (idt == f32) vcvtps2dq(xmm, xmm);
+ if (idt == f32 || idt == s32) {
+ if (mayiuse(avx512_core)) {
+ vpmovsdb(xmm, xmm);
+ } else {
+ vpackssdw(xmm, xmm, xmm_zero);
+ vpacksswb(xmm, xmm, xmm_zero);
+ }
+ }
+ if (idt == u8) vpminub(xmm, xmm, xmm_4x127b);
+ break;
+ case u8:
+ if (idt == f32) vcvtps2dq(xmm, xmm);
+ if (idt == f32 || idt == s32) {
+ if (mayiuse(avx512_core)) {
+ vpmaxsd(xmm, xmm, xmm_zero);
+ vpmovusdb(xmm, xmm);
+ } else {
+ vpackssdw(xmm, xmm, xmm_zero);
+ vpackuswb(xmm, xmm, xmm_zero);
+ }
+ }
+ if (idt == s8) vpmaxsb(xmm, xmm, xmm_zero);
+ break;
+ default: assert(!"unreachable");
+ }
+ };
+
+ auto load = [=](const Xmm &xmm, const Address &addr, int size) {
+ switch (size) {
+ case 16: movups(xmm, addr); break;
+ case 4: movss(xmm, addr); break;
+ case 1: pinsrb(xmm, addr, 0x0); break;
+ default: assert(!"unreachable");
+ }
+ };
+
+ auto store = [=](const Address &addr, const Xmm &xmm, int size) {
+ switch (size) {
+ case 16: movups(addr, xmm); break;
+ case 4: movss(addr, xmm); break;
+ case 1: pextrb(addr, xmm, 0x0); break;
+ default: assert(!"unreachable");
+ }
+ };
+
+ /* check whether loading 4 values at once is possible */
+ bool can_load_xmm = mayiuse(avx) && reg_unroll % 4 == 0;
+ for (int ur = 1; ur < reg_unroll; ++ur)
+ if (i_off[ur] != i_off[ur - 1] + 1)
+ can_load_xmm = false;
+ const int load_step = can_load_xmm ? 4 : 1;
+
+ /* check whether storing 4 values at once is possible */
+ bool can_store_xmm = reg_unroll % 4 == 0;
+ for (int ur = 1; ur < reg_unroll; ++ur)
+ if (o_off[ur] != o_off[ur - 1] + 1)
+ can_store_xmm = false;
+ const int ur_step = can_store_xmm ? 4 : 1;
+
+ const bool interim_f32 = false
+ || utils::one_of(f32, prb_.itype, prb_.otype)
+ || prb_.scale_type != scale_type_t::NONE
+ || prb_.beta != 0.f;
+
+ if (!can_load_xmm && can_store_xmm) {
+ assert(ur_step == 4);
+ /* load with stride */
+ for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+ for (int r = 0; r < ur_step; ++r) {
+ if (itype_sz == 4)
+ pinsrd(Xmm(ur), i_addr(i_off[ur + r]), r);
+ else
+ pinsrb(Xmm(ur), i_addr(i_off[ur + r]), r);
+ }
+ }
+ } else {
+ for (int ur = 0; ur < reg_unroll; ur += load_step)
+ load(Xmm(ur), i_addr(i_off[ur]), load_step * itype_sz);
+ }
+
+ /* xmm[:] <-- (f32)xmm[:] */
+ if (interim_f32) {
+ const int cvt_step = nstl::max(load_step, ur_step);
+ for (int ur = 0; ur < reg_unroll; ur += cvt_step)
+ cvt2ps(Xmm(ur), Xmm(ur), prb_.itype);
+ }
+
+ if (can_load_xmm && !can_store_xmm) {
+ const bool fast_return = true // transposition on the fly
+ && prb_.scale_type != scale_type_t::MANY
+ && prb_.beta == 0.f;
+ if (fast_return) {
+ for (int ur = 0; ur < reg_unroll; ur += load_step) {
+ if (prb_.scale_type == scale_type_t::COMMON)
+ mulps(Xmm(ur), xmm_scale);
+ if (prb_.otype != f32)
+ cvt2int(Xmm(ur), prb_.otype,
+ interim_f32 ? f32 : prb_.itype);
+ for (int r = 0; r < load_step; ++r) {
+ if (otype_sz == 4)
+ pextrd(o_addr(o_off[ur + r]), Xmm(ur), r);
+ else
+ pextrb(o_addr(o_off[ur + r]), Xmm(ur), r);
+ }
+ }
+ return;
+ }
+
+ /* scatter elements of xmm into 4 xmms */
+ if (itype_sz == 4 || interim_f32) {
+ for (int ur = 0; ur < reg_unroll; ur += load_step)
+ for (int r = 1; r < load_step; ++r)
+ vshufps(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
+ } else {
+ for (int ur = 0; ur < reg_unroll; ur += load_step)
+ for (int r = 1; r < load_step; ++r)
+ vpalignr(Xmm(ur + r), Xmm(ur), Xmm(ur), r);
+ }
+ }
+
+ /* scale and beta processing */
+ if (can_store_xmm) {
+ /* xmm <-- scale * xmm[:] */
+ if (prb_.scale_type == scale_type_t::COMMON) {
+ for (int ur = 0; ur < reg_unroll; ur += ur_step)
+ mulps(Xmm(ur), xmm_scale);
+ } else if (prb_.scale_type == scale_type_t::MANY) {
+ enum class scale_load_type_t { bcast, load, gather };
+
+ for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+ scale_load_type_t scale_load_type =
+ scale_load_type_t::bcast; // the best case
+
+ for (int r = ur + 1; r < ur + ur_step; ++r)
+ if (s_off[r] != s_off[r - 1] + 0)
+ scale_load_type = scale_load_type_t::load;
+
+ if (scale_load_type == scale_load_type_t::bcast) {
+ movss(xmm_scale, s_addr(s_off[ur]));
+ shufps(xmm_scale, xmm_scale, 0x0);
+ mulps(Xmm(ur), xmm_scale);
+ continue;
+ }
+
+ // bcast doesn't work, the next try -- load
+ for (int r = ur + 1; r < ur + ur_step; ++r)
+ if (s_off[r] != s_off[r - 1] + 1)
+ scale_load_type = scale_load_type_t::gather;
+
+ if (scale_load_type == scale_load_type_t::load) {
+ movups(xmm_scale, s_addr(s_off[ur]));
+ mulps(Xmm(ur), xmm_scale);
+ continue;
+ }
+
+ // load doesn't work as well
+ // so gather the scale factors one by one
+ for (int r = ur; r < ur + ur_step; ++r)
+ pinsrd(xmm_scale, s_addr(s_off[r]), r - ur);
+ mulps(Xmm(ur), xmm_scale);
+ }
+ }
+
+ /* dst <-- beta * dst + xmm[:] */
+ assert(prb_.beta == 0.f || prb_.beta == 1.f);
+ if (prb_.beta == 1.f) {
+ for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+ if (prb_.otype == f32) {
+ /* non VEX instructions do not support unaligned
+ * memory for instructions other than movups. */
+ if (mayiuse(avx)) {
+ vaddps(Xmm(ur), o_addr(o_off[ur]));
+ } else {
+ /* register xmm(1) is unused */
+ movups(Xmm(1), o_addr(o_off[ur]));
+ addps(Xmm(ur), Xmm(1));
+ }
+ } else {
+ cvt2ps(Xmm(1), o_addr(o_off[ur]), prb_.otype);
+ vaddps(Xmm(ur), Xmm(1));
+ }
+ }
+ }
+ } else {
+ /* xmm[0] <-- scale * xmm[0] */
+ if (prb_.scale_type == scale_type_t::COMMON) {
+ for (int ur = 0; ur < reg_unroll; ur += ur_step)
+ mulss(Xmm(ur), xmm_scale);
+ } else if (prb_.scale_type == scale_type_t::MANY) {
+ for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+ mulss(Xmm(ur), s_addr(s_off[ur]));
+ }
+ }
+
+ /* dst <-- beta * dst + xmm[0] */
+ assert(prb_.beta == 0.f || prb_.beta == 1.f);
+ if (prb_.beta == 1.f) {
+ for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+ if (prb_.otype == f32) {
+ addss(Xmm(ur), o_addr(o_off[ur]));
+ } else {
+ if (prb_.otype == s32) {
+ vmovss(xmm_tmp, o_addr(o_off[ur]));
+ } else if (utils::one_of(prb_.otype, s8, u8)) {
+ pinsrb(xmm_tmp, o_addr(o_off[ur]), 0x0);
+ } else {
+ assert(!"unsupported o_type");
+ }
+ cvt2ps(xmm_tmp, xmm_tmp, prb_.otype);
+ addps(Xmm(ur), xmm_tmp);
+ }
+ }
+ }
+ }
+
+ for (int ur = 0; ur < reg_unroll; ur += ur_step) {
+ if (prb_.otype != f32)
+ cvt2int(Xmm(ur), prb_.otype, interim_f32 ? f32 : prb_.itype);
+ store(o_addr(o_off[ur]), Xmm(ur), ur_step * otype_sz);
+ }
+ }
+
+ void process_unroll_generic(int len) {
+ const int blk = 8;
+
+ int i_off[2 * blk] = {0};
+ int o_off[2 * blk] = {0};
+ int s_off[2 * blk] = {0};
+
+ int curr = 0; // will switch between 0 and 1
+
+ for (int off = 0; off < len; off += blk) {
+ const int reg_unroll = nstl::min(off + blk, len) - off;
+
+ /* compute offsets */
+ for (int ur = off != 0 ? 0 : 1; ur < reg_unroll; ++ur) {
+ const int ur_c = curr * blk + ur;
+ const int ur_p = (ur_c - 1 + 2 * blk) % (2 * blk); // prev ur
+ step(off + ur,
+ i_off[ur_p], o_off[ur_p], s_off[ur_p],
+ i_off[ur_c], o_off[ur_c], s_off[ur_c]);
+ }
+
+ process_unroll_generic_step(reg_unroll, i_off + curr * blk,
+ o_off + curr * blk, s_off + curr * blk);
+
+ curr = 1 - curr;
+ }
+ }
+
+ void loop_begin(Label &l, Reg64 reg_cnt, int len) {
+ mov(reg_cnt, len);
+ L(l);
+ }
+
+ void loop_end(Label &l, Reg64 reg_cnt, int len,
+ int i_step, int o_step, int s_step) {
+ add(reg_off_in, i_step * itype_sz);
+ add(reg_off_out, o_step * otype_sz);
+ if (prb_.scale_type == scale_type_t::MANY)
+ add(reg_off_scale, s_step * stype_sz);
+ dec(reg_cnt);
+ jnz(l);
+
+ sub(reg_off_in, len * i_step * itype_sz);
+ sub(reg_off_out, len * o_step * otype_sz);
+ if (prb_.scale_type == scale_type_t::MANY)
+ sub(reg_off_scale, len * s_step * stype_sz);
+ }
+
+ bool simple_impl() {
+ simple_impl_desc_t d;
+ if (!simple_impl_desc_init(prb_, &d)) return false;
+
+ const int nfu = d.ndims_full_unroll;
+ const int ldu = d.len_last_dim_unroll;
+ const int n_jit_loops = prb_.ndims - d.ndims_full_unroll;
+ assert(n_jit_loops <= ndims_jit_loop_max);
+
+ xor_(reg_off_in, reg_off_in);
+ xor_(reg_off_out, reg_off_out);
+ if (prb_.scale_type == scale_type_t::MANY)
+ xor_(reg_off_scale, reg_off_scale);
+
+ Label l_loop[3];
+ Reg64 reg_cnt[3] = {r15, r14, r13};
+
+ if (n_jit_loops > 2)
+ loop_begin(l_loop[2], reg_cnt[2], n(nfu + 2));
+
+ if (n_jit_loops > 1)
+ loop_begin(l_loop[1], reg_cnt[1], n(nfu + 1));
+
+ if (n_jit_loops > 0)
+ loop_begin(l_loop[0], reg_cnt[0], n(nfu + 0) / ldu);
+
+ const bool optimized = false
+ || process_direct_copy<avx>(d.len_unroll)
+ || process_direct_copy<sse42>(d.len_unroll)
+ || process_unroll_tr8x8(d.len_unroll);
+ if (!optimized)
+ process_unroll_generic(d.len_unroll);
+
+ if (n_jit_loops > 0)
+ loop_end(l_loop[0], reg_cnt[0],
+ n(nfu + 0) / ldu, is(nfu + 0) * ldu, os(nfu + 0) * ldu,
+ ss(nfu + 0) * ldu);
+
+ if (n_jit_loops > 1)
+ loop_end(l_loop[1], reg_cnt[1],
+ n(nfu + 1), is(nfu + 1), os(nfu + 1), ss(nfu + 1));
+
+ if (n_jit_loops > 2)
+ loop_end(l_loop[2], reg_cnt[2],
+ n(nfu + 2), is(nfu + 2), os(nfu + 2), ss(nfu + 2));
+
+ return true;
+ }
+
+ void impl() {
+ if (simple_impl()) return;
+ assert(!"no implementation available");
+ }
+
+ jit_uni_reorder_kernel_f32(const desc_t &desc)
+ : kernel_t(desc), jit_generator() {
+ itype_sz = data_type_size(prb_.itype);
+ otype_sz = data_type_size(prb_.otype);
+ stype_sz = sizeof(float);
+
+ preamble();
+# define PARAM(x) ptr[abi_param1 + offsetof(call_param_t, x)]
+ if (prb_.scale_type == scale_type_t::COMMON) {
+ auto reg_ptr_scale_tmp = reg_ptr_in;
+ mov(reg_ptr_scale_tmp, PARAM(scale));
+ movups(xmm_scale, ptr[reg_ptr_scale_tmp]);
+ } else if (prb_.scale_type == scale_type_t::MANY) {
+ mov(reg_ptr_scale, PARAM(scale));
+ }
+ mov(reg_ptr_in, PARAM(in));
+ mov(reg_ptr_out, PARAM(out));
+# undef PARAM
+
+ if (mayiuse(avx)) {
+ vxorps(xmm_zero, xmm_zero, xmm_zero);
+
+ if (prb_.itype == data_type::u8 && prb_.otype == data_type::s8) {
+ mov(reg_tmp.cvt32(), 0x7f7f7f7f);
+ movd(xmm_4x127b, reg_tmp.cvt32());
+ }
+ }
+
+ impl();
+ postamble();
+ ker_ = (void (*)(const call_param_t *))getCode();
+ }
+
+private:
+ int itype_sz;
+ int otype_sz;
+ int stype_sz;
+
+ Reg64 reg_ptr_in = rsi;
+ Reg64 reg_ptr_out = rdx;
+ Reg64 reg_ptr_scale = abi_not_param1;
+
+ Reg64 reg_off_in = r8;
+ Reg64 reg_off_out = r9;
+ Reg64 reg_off_scale = r10;
+
+ Reg64 reg_tmp = rax;
+
+ Xmm xmm_scale = xmm15;
+ Xmm xmm_zero = xmm14;
+ Xmm xmm_4x127b = xmm13; // TODO: unite with xmm_zero
+ Xmm xmm_tmp = xmm12;
+};
+
+status_t kernel_t::desc_init(kernel_t::desc_t &desc, const prb_t &prb,
+ int ndims_ker_max) {
+ desc.prb = prb;
+ desc.prb.ioff = desc.prb.ooff = 0;
+
+ if (ndims_ker_max > prb.ndims)
+ return status::invalid_arguments;
+
+ auto ndims_ker_max_f = [&]() {
+ size_t cur_size = 1;
+ for (int d = 0; d < prb.ndims; cur_size *= prb.nodes[d++].n)
+ if (cur_size >= ker_prb_size_min) return d;
+ return prb.ndims;
+ };
+
+ if (ndims_ker_max <= 0)
+ ndims_ker_max = ndims_ker_max_f();
+
+ /* traverse through kernel implementations */
+ /* TODO: find a better way to do that... */
+ desc.id = 0;
+ for (int ndims_ker = ndims_ker_max; ndims_ker > 0; --ndims_ker) {
+ desc.prb.ndims = ndims_ker;
+ if (jit_uni_reorder_kernel_f32::applicable(desc.prb))
+ return status::success;
+ }
+
+ return status::unimplemented;
+}
+
+kernel_t *kernel_t::create(const kernel_t::desc_t &desc) {
+ switch (desc.id) {
+ case 0: return new jit_uni_reorder_kernel_f32(desc);
+ default: assert(!"unknown kernel id"); return nullptr;
+ }
+
+ return nullptr;
+}
+
+}
+
+static void prb_block_for_cache(tr::prb_t &prb) {
+ if (prb.nodes[0].is % 64 == 0 && prb.nodes[0].n > 16) {
+        /** an attempt to use caches more efficiently and to
+         * address the 4K-aliasing issue */
+ /* TODO: improve the logic around here */
+ int j = 1;
+ for (; j < prb.ndims && prb.nodes[j].is != 1; ++j);
+ if (j == prb.ndims) return;
+
+        /* it makes sense to re-prioritize a sequential read over a
+         * sequential write if the former would not thrash the
+         * cache, i.e. is == 1 and os % 2^smth != 0, where smth is
+         * currently set to 2 */
+ const int move_to = prb.nodes[j].os % 4 != 0 ? 0 : 1;
+ if (j == move_to) return;
+
+ if (prb.nodes[j].n > 16 && prb.nodes[j].n % 16 == 0)
+ prb_node_split(prb, j, 16);
+
+ prb_node_move(prb, j, move_to);
+ DEBUG({ printf("cache: "); prb_dump(prb); });
+ }
+}
+
+/** finds the maximum number of dimensions the kernel should process and
+ * optionally splits one of the dimensions to achieve a better balance between
+ * the parallel driver and the kernel. */
+static void prb_thread_kernel_balance(tr::prb_t &prb, int &ndims_ker_max) {
+ size_t sz_total = 1;
+ for (int d = 0; d < prb.ndims; ++d)
+ sz_total *= prb.nodes[d].n;
+
+ /* sz_drv_min is the minimal size for the parallel
+ * driver required for good parallelization */
+ const size_t sz_drv_min = nstl::min<size_t>(
+ 16 * mkldnn_get_max_threads(),
+ utils::div_up(sz_total, 1024));
+
+    /* kdims -- # of dimensions processed by the kernel
+     * sz_ker_cur -- product of the dimensions processed by the kernel
+     * sz_drv_cur -- product of the dimensions processed by the driver */
+
+ int kdims = prb.ndims;
+ size_t sz_drv_cur = 1;
+ for (; kdims > 1 && sz_drv_cur < sz_drv_min; --kdims)
+ sz_drv_cur *= prb.nodes[kdims - 1].n;
+
+ size_t sz_ker_cur = 1;
+ for (int d = 0; d < kdims; ++d)
+ sz_ker_cur *= prb.nodes[d].n;
+
+ /* Initially kdims is chosen so that sz_drv_cur >= sz_drv_min.
+ *
+     * It might happen that for the chosen kdims the sz_ker_cur is too small
+ * (less than tr::ker_prb_size_min). In that case try to split the
+ * innermost driver dimension into two, to increase sz_ker_cur. */
+ bool want_borrow_ker_from_drv = true
+ && kdims < prb.ndims
+ && sz_ker_cur < tr::ker_prb_size_min
+ && sz_drv_cur > sz_drv_min;
+ if (want_borrow_ker_from_drv) {
+ /* sz_want_borrow is the minimal sz, so that:
+ * o) sz_ker_cur * sz_want_borrow >= tr::ker_prb_size_min
+ * o) current innermost driver dimension is divisible by
+ * sz_want_borrow (so that we can evenly split that
+ * dimension into two)
+ *
+ * In the worst case the minimal sz_want_borrow is equal
+ * to the innermost driver dimension itself. In that case
+         * we will sacrifice it in favor of the kernel (is it fine?). */
+ size_t sz_want_borrow
+ = utils::div_up(tr::ker_prb_size_min, sz_ker_cur);
+ for (; prb.nodes[kdims].n % sz_want_borrow; ++sz_want_borrow);
+ if (sz_want_borrow != prb.nodes[kdims].n)
+ prb_node_split(prb, kdims, sz_want_borrow);
+ kdims += 1;
+ }
+
+    /* On the other hand it might happen that for the chosen kdims
+ * the sz_drv_cur is too small (less than sz_drv_min). In that case
+ * try to split the outermost kernel dimension into two, to increase
+ * sz_drv_cur. */
+ bool want_borrow_drv_from_ker = true
+ && sz_ker_cur > tr::ker_prb_size_min
+ && sz_drv_cur < sz_drv_min;
+ if (want_borrow_drv_from_ker) {
+ size_t sz_want_borrow = utils::div_up(sz_drv_min, sz_drv_cur);
+ for (; prb.nodes[kdims - 1].n % sz_want_borrow; ++sz_want_borrow);
+ if (sz_want_borrow != prb.nodes[kdims - 1].n)
+ prb_node_split(prb, kdims - 1,
+ prb.nodes[kdims - 1].n / sz_want_borrow);
+ }
+
+ ndims_ker_max = kdims;
+
+ if (want_borrow_ker_from_drv || want_borrow_drv_from_ker) {
+ DEBUG({ printf("split: "); prb_dump(prb);
+ printf("ndims_ker_max = %d\n", ndims_ker_max); });
+ }
+}
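
A hedged numeric walk-through of the balancing heuristic above; the node sizes and the thread count (mkldnn_get_max_threads() == 16 is assumed) are hypothetical:

    // Hypothetical input: nodes of sizes {2, 4, 4096}.
    //   sz_total   = 2 * 4 * 4096 = 32768
    //   sz_drv_min = min(16 * 16, div_up(32768, 1024)) = min(256, 32) = 32
    //   kdims loop: peeling the outermost node gives sz_drv_cur = 4096, kdims = 2
    //   sz_ker_cur = 2 * 4 = 8 < tr::ker_prb_size_min (64), sz_drv_cur (4096) > 32,
    //   so the kernel borrows from the driver: sz_want_borrow = div_up(64, 8) = 8,
    //   the {4096} node is split into {8, 512}, and kdims becomes 3.
    // Result: the kernel processes {2, 4, 8} (64 elements) and the driver
    // parallelizes over the remaining node of size 512; ndims_ker_max = 3.
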
+
+struct jit_uni_reorder_t : public cpu_primitive_t {
+ struct pd_t : public cpu_reorder_pd_t {
+ using cpu_reorder_pd_t::cpu_reorder_pd_t;
+
+ DECLARE_COMMON_PD_T("jit:uni", jit_uni_reorder_t);
+
+ static status_t create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+ auto prb = tr::prb_t();
+
+ status_t prb_init_status = prb_init(prb, *src_md, *dst_md, attr);
+ if (prb_init_status != status::success) return prb_init_status;
+
+ DEBUG({ printf("init : "); prb_dump(prb); });
+ prb_normalize(prb);
+ DEBUG({ printf("norm : "); prb_dump(prb); });
+ prb_simplify(prb);
+ DEBUG({ printf("smpl : "); prb_dump(prb); });
+
+ prb_block_for_cache(prb);
+
+ int ndims_ker_max;
+ prb_thread_kernel_balance(prb, ndims_ker_max);
+
+ tr::kernel_t::desc_t ker_desc;
+ status_t ker_init_status
+ = tr::kernel_t::desc_init(ker_desc, prb, ndims_ker_max);
+ if (ker_init_status != status::success) return ker_init_status;
+
+ const int ndims_driver = prb.ndims - ker_desc.prb.ndims;
+ if (ndims_driver > jit_uni_reorder_t::ndims_driver_max)
+ return status::unimplemented;
+
+ DEBUG({ printf("ker : "); prb_dump(ker_desc.prb); });
+
+ auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
+ dst_md);
+ if (_pd == nullptr) return status::out_of_memory;
+ if (_pd->init() != status::success) {
+ delete _pd;
+ return status::unimplemented;
+ }
+ _pd->prb_ = prb;
+ _pd->ker_desc_ = ker_desc;
+ return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
+ }
+
+ tr::prb_t prb_;
+ tr::kernel_t::desc_t ker_desc_;
+ };
+
+ jit_uni_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {
+ kernel_ = tr::kernel_t::create(pd()->ker_desc_);
+ assert(kernel_);
+ }
+ ~jit_uni_reorder_t() { delete kernel_; }
+
+ void omp_driver_0d(int off, const char *in, char *out,
+ const float *scale) const {
+ tr::call_param_t c{in, out, scale};
+ (*kernel_)(&c);
+ }
+
+ void omp_driver_1d(int ithr, int nthr, int off, const char *in, char *out,
+ const float *scale) const {
+ const tr::node_t *ns = pd()->prb_.nodes + off;
+ for_nd(ithr, nthr, (ptrdiff_t)ns[0].n, [&](ptrdiff_t d0) {
+ auto c = tr::call_param_t();
+ c.in = in + d0 * ns[0].is * data_type_size(pd()->prb_.itype);
+ c.out = out + d0 * ns[0].os * data_type_size(pd()->prb_.otype);
+ c.scale = scale + d0 * ns[0].ss;
+ (*kernel_)(&c);
+ });
+ }
+
+ void omp_driver_2d(int ithr, int nthr, int off, const char *in, char *out,
+ const float *scale) const {
+ const tr::node_t *ns = pd()->prb_.nodes + off;
+ for_nd(ithr, nthr, (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
+ [&](ptrdiff_t d1, ptrdiff_t d0) {
+ auto c = tr::call_param_t();
+ c.in = in + (d0 * ns[0].is + d1 * ns[1].is)
+ * data_type_size(pd()->prb_.itype);
+ c.out = out + (d0 * ns[0].os + d1 * ns[1].os)
+ * data_type_size(pd()->prb_.otype);
+ c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss;
+ (*kernel_)(&c);
+ });
+ }
+
+ void omp_driver_3d(int ithr, int nthr, int off, const char *in, char *out,
+ const float *scale) const {
+ const tr::node_t *ns = pd()->prb_.nodes + off;
+ for_nd(ithr, nthr, (ptrdiff_t)ns[2].n, (ptrdiff_t)ns[1].n,
+ (ptrdiff_t)ns[0].n,
+ [&](ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
+ auto c = tr::call_param_t();
+ c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is)
+ * data_type_size(pd()->prb_.itype);
+ c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os)
+ * data_type_size(pd()->prb_.otype);
+ c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss;
+ (*kernel_)(&c);
+ });
+ }
+
+ void omp_driver_4d(int ithr, int nthr, int off, const char *in, char *out,
+ const float *scale) const {
+ const tr::node_t *ns = pd()->prb_.nodes + off;
+ for_nd(ithr, nthr, (ptrdiff_t)ns[3].n, (ptrdiff_t)ns[2].n,
+ (ptrdiff_t)ns[1].n, (ptrdiff_t)ns[0].n,
+ [&](ptrdiff_t d3, ptrdiff_t d2, ptrdiff_t d1, ptrdiff_t d0) {
+ auto c = tr::call_param_t();
+ c.in = in + (d0 * ns[0].is + d1 * ns[1].is + d2 * ns[2].is
+ + d3 * ns[3].is) * data_type_size(pd()->prb_.itype);
+ c.out = out + (d0 * ns[0].os + d1 * ns[1].os + d2 * ns[2].os
+ + d3 * ns[3].os) * data_type_size(pd()->prb_.otype);
+ c.scale = scale + d0 * ns[0].ss + d1 * ns[1].ss + d2 * ns[2].ss
+ + d3 * ns[3].ss;
+ (*kernel_)(&c);
+ });
+ }
+
+ void omp_driver(const char *in, char *out, const float *scale) const {
+ in += pd()->prb_.ioff * data_type_size(pd()->prb_.itype);
+ out += pd()->prb_.ooff * data_type_size(pd()->prb_.otype);
+
+ DEBUG({ printf("prb : "); tr::prb_dump(pd()->prb_); });
+ DEBUG({ printf("ker : "); tr::prb_dump(pd()->ker_desc_.prb); });
+
+ int ndims = pd()->prb_.ndims;
+ int ndims_ker = pd()->ker_desc_.prb.ndims;
+ assert(ndims - ndims_ker <= ndims_driver_max);
+
+ if (ndims - ndims_ker == 0) {
+ omp_driver_0d(ndims_ker, in, out, scale);
+ } else {
+ parallel(0, [&](const int ithr, const int nthr) {
+ switch (ndims - ndims_ker) {
+ case 1: omp_driver_1d(ithr, nthr, ndims_ker, in, out, scale); break;
+ case 2: omp_driver_2d(ithr, nthr, ndims_ker, in, out, scale); break;
+ case 3: omp_driver_3d(ithr, nthr, ndims_ker, in, out, scale); break;
+ case 4: omp_driver_4d(ithr, nthr, ndims_ker, in, out, scale); break;
+ default: assert(!"unimplemented");
+ }
+ });
+ }
+ }
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto in = CTX_IN_MEM(const char *, MKLDNN_ARG_FROM);
+ auto out = CTX_OUT_MEM(char *, MKLDNN_ARG_TO);
+
+ omp_driver(in, out, pd()->attr()->output_scales_.scales_);
+
+ return status::success;
+ }
+
+ enum { ndims_driver_max = 4 };
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ tr::kernel_t *kernel_;
+};
+
+status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+ return jit_uni_reorder_t::pd_t::create(reorder_pd, engine, attr,
+ src_engine, src_md, dst_engine, dst_md);
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp
new file mode 100644
index 0000000000..0746ea61d3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder.hpp
@@ -0,0 +1,127 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef _JIT_UNI_REORDER_HPP
+#define _JIT_UNI_REORDER_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+
+#include "cpu_primitive.hpp"
+#include "cpu_reorder_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace tr {
+
+constexpr int max_ndims = MKLDNN_MAX_NDIMS;
+
+struct node_t {
+ size_t n;
+ ptrdiff_t is; // input stride
+ ptrdiff_t os; // output stride
+ ptrdiff_t ss; // scale stride
+};
+
+enum class scale_type_t { NONE, COMMON, MANY };
+
+struct prb_t {
+ data_type_t itype;
+ data_type_t otype;
+ int ndims;
+ node_t nodes[max_ndims];
+ ptrdiff_t ioff;
+ ptrdiff_t ooff;
+ scale_type_t scale_type;
+ float beta;
+};
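
A hedged illustration of how a reorder maps onto these nodes; the 32x48 matrix is hypothetical, only the struct layout above is taken as given:

    // Reordering a 32x48 f32 matrix from row-major to column-major could be
    // described, after prb_init()/prb_normalize(), roughly as:
    //   nodes[0] = { n = 32, is = 48, os =  1, ss = 0 }
    //   nodes[1] = { n = 48, is =  1, os = 32, ss = 0 }
    // with ndims = 2, scale_type = NONE and beta = 0.f. Each node is one logical
    // loop carrying a length plus input/output/scale strides, kept sorted by
    // ascending output stride so the innermost loop writes sequentially.
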
+
+status_t prb_init(prb_t &prb, const memory_desc_t &imd,
+ const memory_desc_t &omd, const primitive_attr_t *attr);
+
+/** sorts the problem nodes so that output strides come in ascending order */
+void prb_normalize(prb_t &p);
+
+/** folds nodes together if possible */
+void prb_simplify(prb_t &p);
+
+/** splits the node dim into two of sizes n1 and n / n1
+ * @warning n must be a multiple of n1 */
+void prb_node_split(prb_t &p, int dim, size_t n1);
+
+/** swaps d0 and d1 nodes */
+void prb_node_swap(prb_t &p, int d0, int d1);
+
+/** moves node d0 to the d1 position.
+ * nodes (d0, d1] are shifted to the left if d0 < d1 or
+ * to the right if d0 > d1 */
+void prb_node_move(prb_t &p, int d0, int d1);
+
+/** dumps the problem to stdout */
+void prb_dump(const prb_t &p);
+
+struct call_param_t {
+ const void *in;
+ void *out;
+ const float *scale;
+};
+
+struct kernel_t {
+ struct desc_t {
+ int id;
+ prb_t prb;
+ };
+
+ kernel_t(const desc_t &desc): desc_(desc), ker_(nullptr) {}
+ void operator()(const call_param_t *c) const { assert(ker_); ker_(c); }
+ virtual ~kernel_t() {}
+
+ /** inits kernel descriptor:
+ * desc -- kernel descriptor (output)
+ * prb -- transposition problem (input)
+     *      ndims_ker_max   -- limits the maximum number of dimensions the kernel
+ * will process (optional, 0 -- no limitation) */
+ static status_t desc_init(desc_t &desc, const prb_t &prb,
+ int ndims_ker_max = 0);
+
+ /** creates kernel for the problem described in desc */
+ static kernel_t *create(const desc_t &desc);
+
+protected:
+ const desc_t desc_;
+ const prb_t &prb_ = desc_.prb;
+ void (*ker_)(const call_param_t *);
+};
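
A minimal usage sketch of this interface, assuming a prb_t named prb already filled by the prb_* helpers above and valid src/dst/scales pointers (the real driver is jit_uni_reorder_t in the accompanying .cpp):

    tr::kernel_t::desc_t kd;
    if (tr::kernel_t::desc_init(kd, prb) == status::success) {
        tr::kernel_t *ker = tr::kernel_t::create(kd);
        tr::call_param_t cp = { src, dst, scales };  // in, out, scale
        (*ker)(&cp);  // processes the first kd.prb.ndims dimensions
        delete ker;   // any remaining outer dimensions are the driver's job
    }
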
+
+/* TODO: add trans_t class */
+
+}
+
+/* for cpu reorder list */
+status_t jit_uni_reorder_create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md);
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp
new file mode 100644
index 0000000000..69b7a33604
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp
@@ -0,0 +1,313 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_desc_wrapper.hpp"
+#include "mkldnn_debug.h"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "jit_uni_reorder.hpp"
+
+using namespace mkldnn::impl::types;
+using namespace mkldnn::impl::status;
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace tr {
+
+/** ad-hoc structure to describe blocked memory layout */
+struct layout_desc_t {
+ data_type_t dt;
+ int ndims;
+ dims_t id;
+ dims_t dims;
+ strides_t strides;
+};
+
+status_t cvt_mem_desc_to_layout_desc(const memory_desc_t &md_,
+ layout_desc_t &ld) {
+ const auto md = memory_desc_wrapper(md_);
+
+ bool ok = true
+ && md.is_blocking_desc()
+ && md.extra().flags == 0;
+ if (!ok) return invalid_arguments;
+
+ const auto &bd = md.blocking_desc();
+
+ ld.ndims = 0;
+ ld.dt = md.data_type();
+
+ auto P = [&ld](int id, int dim, ptrdiff_t stride) {
+ assert((size_t)ld.ndims < sizeof(ld.dims) / sizeof(ld.dims[0]));
+ ld.id[ld.ndims] = id;
+ ld.dims[ld.ndims] = dim;
+ ld.strides[ld.ndims] = stride;
+ ++ld.ndims;
+ };
+
+ dims_t blocks;
+ md.compute_blocks(blocks);
+
+ for (int d = 0; d < md.ndims(); ++d) {
+ const int ld_ndims_start = ld.ndims;
+ if (blocks[d] != 1) {
+ stride_t stride = 1;
+ for (int iblk = bd.inner_nblks - 1; iblk >= 0; --iblk) {
+ if (bd.inner_idxs[iblk] == d)
+ P(d, bd.inner_blks[iblk], stride);
+ stride *= bd.inner_blks[iblk];
+ }
+ }
+ P(d, md.padded_dims()[d] / blocks[d], bd.strides[d]);
+
+ // TODO: NOW: revisit, do we need a reverse?
+ // TODO: NOW: consider using strides instead of block sizes in md
+ // reverse the order of dims
+ for (int ld_d = 0; ld_d < (ld.ndims - ld_ndims_start) / 2; ++ld_d) {
+ const int idx0 = ld_ndims_start + ld_d;
+ const int idx1 = ld.ndims - 1 - ld_d;
+ nstl::swap(ld.dims[idx0], ld.dims[idx1]);
+ nstl::swap(ld.strides[idx0], ld.strides[idx1]);
+ }
+ }
+
+ return success;
+}
+
+status_t prb_init(prb_t &p, const memory_desc_t &imd, const memory_desc_t &omd,
+ const primitive_attr_t *attr) {
+ auto im_d = memory_desc_wrapper(imd);
+ auto om_d = memory_desc_wrapper(omd);
+
+ bool ok = true
+ && im_d.is_blocking_desc()
+ && om_d.is_blocking_desc()
+ && !im_d.has_zero_dim()
+ && !om_d.has_zero_dim();
+ if (!ok)
+ return unimplemented;
+
+ dims_t iblocks, oblocks;
+ im_d.compute_blocks(iblocks);
+ om_d.compute_blocks(oblocks);
+
+ /* padding_dim consistency check */
+ for (int d = 0; d < im_d.ndims(); ++d) {
+ const auto pdim = im_d.padded_dims()[d];
+ bool ok = true
+ && pdim == om_d.padded_dims()[d]
+ && pdim % iblocks[d] == 0
+ && pdim % oblocks[d] == 0;
+ if (!ok) return unimplemented;
+ }
+
+ layout_desc_t ild, old;
+ status_t status = cvt_mem_desc_to_layout_desc(imd, ild);
+ if (status != success) return status;
+ status = cvt_mem_desc_to_layout_desc(omd, old);
+ if (status != success) return status;
+
+ p.itype = ild.dt;
+ p.otype = old.dt;
+
+ p.scale_type = attr->output_scales_.has_default_values()
+ ? scale_type_t::NONE
+ : (attr->output_scales_.mask_ == 0
+ ? scale_type_t::COMMON
+ : scale_type_t::MANY);
+
+ ptrdiff_t ss[max_ndims] = {0};
+ if (p.scale_type == scale_type_t::MANY) {
+ ptrdiff_t last_ss = 1;
+ for (int d = old.ndims - 1; d >=0; --d) {
+ assert((d == 0 || old.id[d - 1] <= old.id[d])
+ && "logical dimensions should be in ascending order");
+ if (attr->output_scales_.mask_ & (1 << old.id[d])) {
+ ss[d] = last_ss;
+ last_ss *= old.dims[d];
+ }
+ }
+ }
+
+ int ndims = 0;
+
+ int i_pos = 0; /* state for input -- current dimension */
+ int o_pos = 0; /* state for output -- current dimension */
+
+ while (i_pos < ild.ndims && o_pos < old.ndims) {
+ assert(ild.id[i_pos] == old.id[o_pos]);
+ if (ild.id[i_pos] != old.id[o_pos])
+ return runtime_error;
+
+ assert(ndims < max_ndims);
+ if (ndims == max_ndims)
+ return runtime_error;
+
+ if (ild.dims[i_pos] == old.dims[o_pos]) {
+ p.nodes[ndims].n = ild.dims[i_pos];
+ p.nodes[ndims].is = ild.strides[i_pos];
+ p.nodes[ndims].os = old.strides[o_pos];
+ p.nodes[ndims].ss = ss[o_pos];
+ ++ndims;
+ ++i_pos;
+ ++o_pos;
+ } else if (ild.dims[i_pos] < old.dims[o_pos]) {
+ assert(old.dims[o_pos] % ild.dims[i_pos] == 0);
+ int factor = old.dims[o_pos] / ild.dims[i_pos];
+ p.nodes[ndims].n = ild.dims[i_pos];
+ p.nodes[ndims].is = ild.strides[i_pos];
+ p.nodes[ndims].os = old.strides[o_pos] * factor;
+ p.nodes[ndims].ss = ss[o_pos] * factor;
+ ++ndims;
+ ++i_pos;
+ old.dims[o_pos] = factor;
+ } else if (ild.dims[i_pos] > old.dims[o_pos]) {
+ assert(ild.dims[i_pos] % old.dims[o_pos] == 0);
+ int factor = ild.dims[i_pos] / old.dims[o_pos];
+ p.nodes[ndims].n = old.dims[o_pos];
+ p.nodes[ndims].is = ild.strides[i_pos] * factor;
+ p.nodes[ndims].os = old.strides[o_pos];
+ p.nodes[ndims].ss = ss[o_pos];
+ ++ndims;
+ ++o_pos;
+ ild.dims[i_pos] = factor;
+ }
+ }
+ p.ndims = ndims;
+
+ dims_t zero_pos = {0};
+ p.ioff = memory_desc_wrapper(imd).off_v(zero_pos);
+ p.ooff = memory_desc_wrapper(omd).off_v(zero_pos);
+
+ const int sum_idx = attr->post_ops_.find(primitive_kind::sum);
+ p.beta = sum_idx == -1 ? 0.f : attr->post_ops_.entry_[sum_idx].sum.scale;
+
+ return success;
+}
+
+void prb_normalize(prb_t &p) {
+ for (int d = 0; d < p.ndims; ++d) {
+ int min_pos = d;
+ for (int j = d + 1; j < p.ndims; ++j) {
+ bool new_min = false
+ || p.nodes[j].os < p.nodes[min_pos].os
+ || (true
+ && p.nodes[j].os == p.nodes[min_pos].os
+ && p.nodes[j].n < p.nodes[min_pos].n);
+ if (new_min) min_pos = j;
+ }
+ if (min_pos != d)
+ nstl::swap(p.nodes[d], p.nodes[min_pos]);
+ }
+}
+
+void prb_simplify(prb_t &p) {
+#if defined(__GNUC__) && __GNUC__ >= 4
+/* GCC produces a bogus "array subscript is above array bounds" warning for
+ * the `p.nodes[j - 1] = p.nodes[j]` line below, so disable it for now. */
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Warray-bounds"
+#endif
+ for (int d = 0; d < p.ndims - 1; ++d) {
+ auto &this_node = p.nodes[d + 0];
+ auto &next_node = p.nodes[d + 1];
+ const bool fold = false
+ || next_node.n == (size_t)1 // trivial case, just drop next node
+ || (true // or real folding if possible
+ && next_node.is == (ptrdiff_t)this_node.n * this_node.is
+ && next_node.os == (ptrdiff_t)this_node.n * this_node.os
+ && next_node.ss == (ptrdiff_t)this_node.n * this_node.ss);
+ if (fold) {
+ this_node.n *= next_node.n;
+ for (int j = d + 2; j < p.ndims; ++j)
+ p.nodes[j - 1] = p.nodes[j];
+ --p.ndims;
+ --d; // make another try
+ }
+ }
+#if defined(__GNUC__) && __GNUC__ >= 4
+#pragma GCC diagnostic pop
+#endif
+}
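
A small hypothetical example of the folding performed above (sizes and strides are illustrative only):

    // Two adjacent nodes describing one contiguous run,
    //   nodes[d + 0] = { n =  8, is = 1, os = 1, ss = 0 }
    //   nodes[d + 1] = { n = 16, is = 8, os = 8, ss = 0 }  // strides == 8 * node d
    // satisfy the `fold` condition and collapse into
    //   nodes[d]     = { n = 128, is = 1, os = 1, ss = 0 }
    // reducing p.ndims by one; a node with n == 1 is dropped the same way.
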
+
+void prb_node_split(prb_t &p, int dim, size_t n1) {
+ assert(dim < p.ndims);
+ assert(p.ndims < max_ndims);
+ assert(p.nodes[dim].n % n1 == 0);
+
+ p.ndims += 1;
+
+ for (int d = p.ndims; d > dim + 1; --d)
+ p.nodes[d] = p.nodes[d - 1];
+
+ p.nodes[dim + 1].n = p.nodes[dim].n / n1;
+ p.nodes[dim + 1].is = p.nodes[dim].is * n1;
+ p.nodes[dim + 1].os = p.nodes[dim].os * n1;
+ p.nodes[dim + 1].ss = p.nodes[dim].ss * n1;
+
+ p.nodes[dim].n = n1;
+}
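
A quick hypothetical example of this split:

    // prb_node_split(p, dim, 8) applied to nodes[dim] = { n = 32, is = 1, os = 64, ss = 0 }
    // yields
    //   nodes[dim]     = { n = 8, is = 1, os =  64, ss = 0 }  // inner part, original strides
    //   nodes[dim + 1] = { n = 4, is = 8, os = 512, ss = 0 }  // outer part, strides scaled by n1
    // which is what prb_thread_kernel_balance() and prb_block_for_cache() rely on
    // when handing part of a dimension to the driver or moving it forward.
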
+
+void prb_node_swap(prb_t &p, int d0, int d1) {
+ assert(d0 < p.ndims);
+ assert(d1 < p.ndims);
+ assert(p.ndims < max_ndims);
+
+ if (d0 == d1) return;
+
+ nstl::swap(p.nodes[d0], p.nodes[d1]);
+}
+
+void prb_node_move(prb_t &p, int d0, int d1) {
+ assert(d0 < p.ndims);
+ assert(d1 < p.ndims);
+ assert(p.ndims < max_ndims);
+
+ if (d0 == d1) return;
+
+ node_t node = p.nodes[d0];
+
+ if (d0 < d1)
+ for (int d = d0; d < d1; ++d)
+ p.nodes[d] = p.nodes[d + 1];
+ else
+ for (int d = d0; d > d1; --d)
+ p.nodes[d] = p.nodes[d - 1];
+
+ p.nodes[d1] = node;
+}
+
+void prb_dump(const prb_t &p) {
+ printf("@@@ type:%s:%s ndims:%d ", mkldnn_dt2str(p.itype),
+ mkldnn_dt2str(p.otype), p.ndims);
+ for (int d = 0; d < p.ndims; ++d)
+ printf("[%zu:%td:%td:%td]",
+ p.nodes[d].n, p.nodes[d].is, p.nodes[d].os, p.nodes[d].ss);
+ printf(" off:%zu:%zu\n", p.ioff, p.ooff);
+}
+
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp
new file mode 100644
index 0000000000..08747aa89c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.cpp
@@ -0,0 +1,115 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <mutex>
+
+#include "utils.hpp"
+
+#ifndef MKLDNN_ENABLE_JIT_PROFILING
+#define MKLDNN_ENABLE_JIT_PROFILING 1
+#endif
+
+#ifndef MKLDNN_ENABLE_JIT_DUMP
+#define MKLDNN_ENABLE_JIT_DUMP 1
+#endif
+
+#if MKLDNN_ENABLE_JIT_PROFILING
+#include "jitprofiling/jitprofiling.h"
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+namespace jit_utils {
+
+// WARNING: These functions are not thread safe and must be protected by a
+// mutex
+
+void dump_jit_code(const void *code, size_t code_size, const char *code_name)
+{
+#if MKLDNN_ENABLE_JIT_DUMP
+ if (code && jit_dump_enabled()) {
+ static int counter = 0;
+#define MAX_FNAME_LEN 256
+ char fname[MAX_FNAME_LEN + 1];
+ // TODO (Roma): support prefix for code / linux perf dumps
+ snprintf(fname, MAX_FNAME_LEN, "mkldnn_dump_%s.%d.bin", code_name,
+ counter);
+ counter++;
+
+ FILE *fp = fopen(fname, "w+");
+ // Failure to dump code is not fatal
+ if (fp) {
+ size_t unused = fwrite(code, code_size, 1, fp);
+ UNUSED(unused);
+ fclose(fp);
+ }
+ }
+#undef MAX_FNAME_LEN
+#else
+ UNUSED(code);
+ UNUSED(code_size);
+ UNUSED(code_name);
+#endif
+}
+
+void register_jit_code_vtune(const void *code, size_t code_size,
+ const char *code_name, const char *source_file_name)
+{
+#if MKLDNN_ENABLE_JIT_PROFILING
+ if (iJIT_IsProfilingActive() == iJIT_SAMPLING_ON) {
+ auto jmethod = iJIT_Method_Load();
+ jmethod.method_id = iJIT_GetNewMethodID(); // XXX: not thread-safe
+ jmethod.method_name = (char *)code_name; // XXX: dropping const
+ jmethod.class_file_name = NULL;
+ jmethod.source_file_name = (char *)source_file_name; // XXX: dropping const
+ jmethod.method_load_address = (void *)code;
+ jmethod.method_size = (unsigned int)code_size;
+
+ iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED,
+ (void*)&jmethod);
+ }
+#else
+ UNUSED(code);
+ UNUSED(code_size);
+ UNUSED(code_name);
+ UNUSED(source_file_name);
+#endif
+}
+
+void register_jit_code(const void *code, size_t code_size,
+ const char *code_name, const char *source_file_name)
+{
+ // The #ifdef guards are required to avoid generating a function that only
+ // consists of lock and unlock code
+#if MKLDNN_ENABLE_JIT_PROFILING || MKLDNN_ENABLE_JIT_DUMP
+ static std::mutex m;
+ std::lock_guard<std::mutex> guard(m);
+
+ dump_jit_code(code, code_size, code_name);
+ register_jit_code_vtune(code, code_size, code_name, source_file_name);
+#else
+ UNUSED(code);
+ UNUSED(code_size);
+ UNUSED(code_name);
+ UNUSED(source_file_name);
+#endif
+}
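
A minimal, hypothetical call-site sketch; the real caller (jit_generator) is outside this hunk, so the stub buffer, kernel name, and include path are assumptions:

    #include "jit_utils/jit_utils.hpp"

    static void register_example_stub() {
        // a single-byte `ret` stands in for real JIT output in this sketch
        static const unsigned char stub[] = { 0xc3 };
        mkldnn::impl::cpu::jit_utils::register_jit_code(
                stub, sizeof(stub), "example_stub_kernel", __FILE__);
    }
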
+
+}
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp
new file mode 100644
index 0000000000..2f52dba4ac
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jit_utils.hpp
@@ -0,0 +1,32 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef JIT_SUPPORT_HPP
+#define JIT_SUPPORT_HPP
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+namespace jit_utils {
+
+void register_jit_code(const void *code, size_t code_size,
+ const char *code_name, const char *source_file_name);
+
+}
+}
+}
+}
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD
new file mode 100644
index 0000000000..4fd21cea57
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/LICENSE.BSD
@@ -0,0 +1,27 @@
+Copyright (c) 2011, Intel Corporation
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md
new file mode 100644
index 0000000000..fc67c4f134
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/README.md
@@ -0,0 +1 @@
+This code is from the [Intel SEAPI library](https://github.com/intel/IntelSEAPI)
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h
new file mode 100644
index 0000000000..edbf4a15f0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_config.h
@@ -0,0 +1,595 @@
+/* <copyright>
+
+ Contact Information:
+ http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/
+
+ BSD LICENSE
+
+ Copyright (c) 2005-2014 Intel Corporation. All rights reserved.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+ * Neither the name of Intel Corporation nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+</copyright> */
+#ifndef _ITTNOTIFY_CONFIG_H_
+#define _ITTNOTIFY_CONFIG_H_
+
+/** @cond exclude_from_documentation */
+#ifndef ITT_OS_WIN
+# define ITT_OS_WIN 1
+#endif /* ITT_OS_WIN */
+
+#ifndef ITT_OS_LINUX
+# define ITT_OS_LINUX 2
+#endif /* ITT_OS_LINUX */
+
+#ifndef ITT_OS_MAC
+# define ITT_OS_MAC 3
+#endif /* ITT_OS_MAC */
+
+#ifndef ITT_OS_FREEBSD
+# define ITT_OS_FREEBSD 4
+#endif /* ITT_OS_FREEBSD */
+
+#ifndef ITT_OS
+# if defined WIN32 || defined _WIN32
+# define ITT_OS ITT_OS_WIN
+# elif defined( __APPLE__ ) && defined( __MACH__ )
+# define ITT_OS ITT_OS_MAC
+# elif defined( __FreeBSD__ )
+# define ITT_OS ITT_OS_FREEBSD
+# else
+# define ITT_OS ITT_OS_LINUX
+# endif
+#endif /* ITT_OS */
+
+#ifndef ITT_PLATFORM_WIN
+# define ITT_PLATFORM_WIN 1
+#endif /* ITT_PLATFORM_WIN */
+
+#ifndef ITT_PLATFORM_POSIX
+# define ITT_PLATFORM_POSIX 2
+#endif /* ITT_PLATFORM_POSIX */
+
+#ifndef ITT_PLATFORM_MAC
+# define ITT_PLATFORM_MAC 3
+#endif /* ITT_PLATFORM_MAC */
+
+#ifndef ITT_PLATFORM_FREEBSD
+# define ITT_PLATFORM_FREEBSD 4
+#endif /* ITT_PLATFORM_FREEBSD */
+
+#ifndef ITT_PLATFORM
+# if ITT_OS==ITT_OS_WIN
+# define ITT_PLATFORM ITT_PLATFORM_WIN
+# elif ITT_OS==ITT_OS_MAC
+# define ITT_PLATFORM ITT_PLATFORM_MAC
+# elif ITT_OS==ITT_OS_FREEBSD
+# define ITT_PLATFORM ITT_PLATFORM_FREEBSD
+# else
+# define ITT_PLATFORM ITT_PLATFORM_POSIX
+# endif
+#endif /* ITT_PLATFORM */
+
+#if defined(_UNICODE) && !defined(UNICODE)
+#define UNICODE
+#endif
+
+#include <stddef.h>
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+#include <tchar.h>
+#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+#include <stdint.h>
+#if defined(UNICODE) || defined(_UNICODE)
+#include <wchar.h>
+#endif /* UNICODE || _UNICODE */
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+
+#ifndef ITTAPI_CDECL
+# if ITT_PLATFORM==ITT_PLATFORM_WIN
+# define ITTAPI_CDECL __cdecl
+# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+# if defined _M_IX86 || defined __i386__
+# define ITTAPI_CDECL __attribute__ ((cdecl))
+# else /* _M_IX86 || __i386__ */
+# define ITTAPI_CDECL /* actual only on x86 platform */
+# endif /* _M_IX86 || __i386__ */
+# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+#endif /* ITTAPI_CDECL */
+
+#ifndef STDCALL
+# if ITT_PLATFORM==ITT_PLATFORM_WIN
+# define STDCALL __stdcall
+# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+# if defined _M_IX86 || defined __i386__
+# define STDCALL __attribute__ ((stdcall))
+# else /* _M_IX86 || __i386__ */
+# define STDCALL /* supported only on x86 platform */
+# endif /* _M_IX86 || __i386__ */
+# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+#endif /* STDCALL */
+
+#define ITTAPI ITTAPI_CDECL
+#define LIBITTAPI ITTAPI_CDECL
+
+/* TODO: Temporary for compatibility! */
+#define ITTAPI_CALL ITTAPI_CDECL
+#define LIBITTAPI_CALL ITTAPI_CDECL
+
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+/* use __forceinline (VC++ specific) */
+#define ITT_INLINE __forceinline
+#define ITT_INLINE_ATTRIBUTE /* nothing */
+#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+/*
+ * Generally, functions are not inlined unless optimization is specified.
+ * For functions declared inline, this attribute inlines the function even
+ * if no optimization level was specified.
+ */
+#ifdef __STRICT_ANSI__
+#define ITT_INLINE static
+#define ITT_INLINE_ATTRIBUTE __attribute__((unused))
+#else /* __STRICT_ANSI__ */
+#define ITT_INLINE static inline
+#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused))
+#endif /* __STRICT_ANSI__ */
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+/** @endcond */
+
+#ifndef ITT_ARCH_IA32
+# define ITT_ARCH_IA32 1
+#endif /* ITT_ARCH_IA32 */
+
+#ifndef ITT_ARCH_IA32E
+# define ITT_ARCH_IA32E 2
+#endif /* ITT_ARCH_IA32E */
+
+#ifndef ITT_ARCH_ARM
+# define ITT_ARCH_ARM 4
+#endif /* ITT_ARCH_ARM */
+
+#ifndef ITT_ARCH_PPC64
+# define ITT_ARCH_PPC64 5
+#endif /* ITT_ARCH_PPC64 */
+
+#ifndef ITT_ARCH
+# if defined _M_IX86 || defined __i386__
+# define ITT_ARCH ITT_ARCH_IA32
+# elif defined _M_X64 || defined _M_AMD64 || defined __x86_64__
+# define ITT_ARCH ITT_ARCH_IA32E
+# elif defined _M_IA64 || defined __ia64__
+# define ITT_ARCH ITT_ARCH_IA64
+# elif defined _M_ARM || defined __arm__
+# define ITT_ARCH ITT_ARCH_ARM
+# elif defined __powerpc64__
+# define ITT_ARCH ITT_ARCH_PPC64
+# endif
+#endif
+
+#ifdef __cplusplus
+# define ITT_EXTERN_C extern "C"
+# define ITT_EXTERN_C_BEGIN extern "C" {
+# define ITT_EXTERN_C_END }
+#else
+# define ITT_EXTERN_C /* nothing */
+# define ITT_EXTERN_C_BEGIN /* nothing */
+# define ITT_EXTERN_C_END /* nothing */
+#endif /* __cplusplus */
+
+#define ITT_TO_STR_AUX(x) #x
+#define ITT_TO_STR(x) ITT_TO_STR_AUX(x)
+
+#define __ITT_BUILD_ASSERT(expr, suffix) do { \
+ static char __itt_build_check_##suffix[(expr) ? 1 : -1]; \
+ __itt_build_check_##suffix[0] = 0; \
+} while(0)
+#define _ITT_BUILD_ASSERT(expr, suffix) __ITT_BUILD_ASSERT((expr), suffix)
+#define ITT_BUILD_ASSERT(expr) _ITT_BUILD_ASSERT((expr), __LINE__)
+
+#define ITT_MAGIC { 0xED, 0xAB, 0xAB, 0xEC, 0x0D, 0xEE, 0xDA, 0x30 }
+
+/* Replace with snapshot date YYYYMMDD for promotion build. */
+#define API_VERSION_BUILD 20151119
+
+#ifndef API_VERSION_NUM
+#define API_VERSION_NUM 0.0.0
+#endif /* API_VERSION_NUM */
+
+#define API_VERSION "ITT-API-Version " ITT_TO_STR(API_VERSION_NUM) \
+ " (" ITT_TO_STR(API_VERSION_BUILD) ")"
+
+/* OS communication functions */
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+#include <windows.h>
+typedef HMODULE lib_t;
+typedef DWORD TIDT;
+typedef CRITICAL_SECTION mutex_t;
+#define MUTEX_INITIALIZER { 0 }
+#define strong_alias(name, aliasname) /* empty for Windows */
+#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+#include <dlfcn.h>
+#if defined(UNICODE) || defined(_UNICODE)
+#include <wchar.h>
+#endif /* UNICODE */
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE 1 /* need for PTHREAD_MUTEX_RECURSIVE */
+#endif /* _GNU_SOURCE */
+#ifndef __USE_UNIX98
+#define __USE_UNIX98 1 /* needed for PTHREAD_MUTEX_RECURSIVE; on SLES11.1 with gcc 4.3.4, pthread.h is missing a dependency on __USE_XOPEN2K8 */
+#endif /*__USE_UNIX98*/
+#include <pthread.h>
+typedef void* lib_t;
+typedef pthread_t TIDT;
+typedef pthread_mutex_t mutex_t;
+#define MUTEX_INITIALIZER PTHREAD_MUTEX_INITIALIZER
+#define _strong_alias(name, aliasname) \
+ extern __typeof (name) aliasname __attribute__ ((alias (#name)));
+#define strong_alias(name, aliasname) _strong_alias(name, aliasname)
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+#define __itt_get_proc(lib, name) GetProcAddress(lib, name)
+#define __itt_mutex_init(mutex) InitializeCriticalSection(mutex)
+#define __itt_mutex_lock(mutex) EnterCriticalSection(mutex)
+#define __itt_mutex_unlock(mutex) LeaveCriticalSection(mutex)
+#define __itt_load_lib(name) LoadLibraryA(name)
+#define __itt_unload_lib(handle) FreeLibrary(handle)
+#define __itt_system_error() (int)GetLastError()
+#define __itt_fstrcmp(s1, s2) lstrcmpA(s1, s2)
+#define __itt_fstrnlen(s, l) strnlen_s(s, l)
+#define __itt_fstrcpyn(s1, b, s2, l) strncpy_s(s1, b, s2, l)
+#define __itt_fstrdup(s) _strdup(s)
+#define __itt_thread_id() GetCurrentThreadId()
+#define __itt_thread_yield() SwitchToThread()
+#ifndef ITT_SIMPLE_INIT
+ITT_INLINE long
+__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE;
+ITT_INLINE long __itt_interlocked_increment(volatile long* ptr)
+{
+ return InterlockedIncrement(ptr);
+}
+#endif /* ITT_SIMPLE_INIT */
+
+#define DL_SYMBOLS (1)
+#define PTHREAD_SYMBOLS (1)
+
+#else /* ITT_PLATFORM!=ITT_PLATFORM_WIN */
+#define __itt_get_proc(lib, name) dlsym(lib, name)
+#define __itt_mutex_init(mutex) {\
+ pthread_mutexattr_t mutex_attr; \
+ int error_code = pthread_mutexattr_init(&mutex_attr); \
+ if (error_code) \
+ __itt_report_error(__itt_error_system, "pthread_mutexattr_init", \
+ error_code); \
+ error_code = pthread_mutexattr_settype(&mutex_attr, \
+ PTHREAD_MUTEX_RECURSIVE); \
+ if (error_code) \
+ __itt_report_error(__itt_error_system, "pthread_mutexattr_settype", \
+ error_code); \
+ error_code = pthread_mutex_init(mutex, &mutex_attr); \
+ if (error_code) \
+ __itt_report_error(__itt_error_system, "pthread_mutex_init", \
+ error_code); \
+ error_code = pthread_mutexattr_destroy(&mutex_attr); \
+ if (error_code) \
+ __itt_report_error(__itt_error_system, "pthread_mutexattr_destroy", \
+ error_code); \
+}
+#define __itt_mutex_lock(mutex) pthread_mutex_lock(mutex)
+#define __itt_mutex_unlock(mutex) pthread_mutex_unlock(mutex)
+#define __itt_load_lib(name) dlopen(name, RTLD_LAZY)
+#define __itt_unload_lib(handle) dlclose(handle)
+#define __itt_system_error() errno
+#define __itt_fstrcmp(s1, s2) strcmp(s1, s2)
+
+/* makes customer code define safe APIs for SDL_STRNLEN_S and SDL_STRNCPY_S */
+#ifdef SDL_STRNLEN_S
+#define __itt_fstrnlen(s, l) SDL_STRNLEN_S(s, l)
+#else
+#define __itt_fstrnlen(s, l) strlen(s)
+#endif /* SDL_STRNLEN_S */
+#ifdef SDL_STRNCPY_S
+#define __itt_fstrcpyn(s1, b, s2, l) SDL_STRNCPY_S(s1, b, s2, l)
+#else
+#define __itt_fstrcpyn(s1, b, s2, l) strncpy(s1, s2, l)
+#endif /* SDL_STRNCPY_S */
+
+#define __itt_fstrdup(s) strdup(s)
+#define __itt_thread_id() pthread_self()
+#define __itt_thread_yield() sched_yield()
+#if ITT_ARCH==ITT_ARCH_IA64
+#ifdef __INTEL_COMPILER
+#define __TBB_machine_fetchadd4(addr, val) __fetchadd4_acq((void *)addr, val)
+#else /* __INTEL_COMPILER */
+/* TODO: Add support for non-Intel compilers on the IA-64 architecture */
+#endif /* __INTEL_COMPILER */
+#elif ITT_ARCH==ITT_ARCH_IA32 || ITT_ARCH==ITT_ARCH_IA32E /* ITT_ARCH!=ITT_ARCH_IA64 */
+ITT_INLINE long
+__TBB_machine_fetchadd4(volatile void* ptr, long addend) ITT_INLINE_ATTRIBUTE;
+ITT_INLINE long __TBB_machine_fetchadd4(volatile void* ptr, long addend)
+{
+ long result;
+ __asm__ __volatile__("lock\nxadd %0,%1"
+ : "=r"(result),"=m"(*(int*)ptr)
+ : "0"(addend), "m"(*(int*)ptr)
+ : "memory");
+ return result;
+}
+#elif ITT_ARCH==ITT_ARCH_ARM || ITT_ARCH==ITT_ARCH_PPC64
+#define __TBB_machine_fetchadd4(addr, val) __sync_fetch_and_add(addr, val)
+#endif /* ITT_ARCH==ITT_ARCH_IA64 */
+#ifndef ITT_SIMPLE_INIT
+ITT_INLINE long
+__itt_interlocked_increment(volatile long* ptr) ITT_INLINE_ATTRIBUTE;
+ITT_INLINE long __itt_interlocked_increment(volatile long* ptr)
+{
+ return __TBB_machine_fetchadd4(ptr, 1) + 1L;
+}
+#endif /* ITT_SIMPLE_INIT */
+
+void* dlopen(const char*, int) __attribute__((weak));
+void* dlsym(void*, const char*) __attribute__((weak));
+int dlclose(void*) __attribute__((weak));
+#define DL_SYMBOLS (dlopen && dlsym && dlclose)
+
+int pthread_mutex_init(pthread_mutex_t*, const pthread_mutexattr_t*) __attribute__((weak));
+int pthread_mutex_lock(pthread_mutex_t*) __attribute__((weak));
+int pthread_mutex_unlock(pthread_mutex_t*) __attribute__((weak));
+int pthread_mutex_destroy(pthread_mutex_t*) __attribute__((weak));
+int pthread_mutexattr_init(pthread_mutexattr_t*) __attribute__((weak));
+int pthread_mutexattr_settype(pthread_mutexattr_t*, int) __attribute__((weak));
+int pthread_mutexattr_destroy(pthread_mutexattr_t*) __attribute__((weak));
+pthread_t pthread_self(void) __attribute__((weak));
+#define PTHREAD_SYMBOLS (pthread_mutex_init && pthread_mutex_lock && pthread_mutex_unlock && pthread_mutex_destroy && pthread_mutexattr_init && pthread_mutexattr_settype && pthread_mutexattr_destroy && pthread_self)
+
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+
+typedef enum {
+ __itt_collection_normal = 0,
+ __itt_collection_paused = 1
+} __itt_collection_state;
+
+typedef enum {
+ __itt_thread_normal = 0,
+ __itt_thread_ignored = 1
+} __itt_thread_state;
+
+#pragma pack(push, 8)
+
+typedef struct ___itt_thread_info
+{
+ const char* nameA; /*!< Copy of original name in ASCII. */
+#if defined(UNICODE) || defined(_UNICODE)
+ const wchar_t* nameW; /*!< Copy of original name in UNICODE. */
+#else /* UNICODE || _UNICODE */
+ void* nameW;
+#endif /* UNICODE || _UNICODE */
+ TIDT tid;
+ __itt_thread_state state; /*!< Thread state (paused or normal) */
+ int extra1; /*!< Reserved to the runtime */
+ void* extra2; /*!< Reserved to the runtime */
+ struct ___itt_thread_info* next;
+} __itt_thread_info;
+
+#include "ittnotify_types.h" /* For __itt_group_id definition */
+
+typedef struct ___itt_api_info_20101001
+{
+ const char* name;
+ void** func_ptr;
+ void* init_func;
+ __itt_group_id group;
+} __itt_api_info_20101001;
+
+typedef struct ___itt_api_info
+{
+ const char* name;
+ void** func_ptr;
+ void* init_func;
+ void* null_func;
+ __itt_group_id group;
+} __itt_api_info;
+
+typedef struct __itt_counter_info
+{
+ const char* nameA; /*!< Copy of original name in ASCII. */
+#if defined(UNICODE) || defined(_UNICODE)
+ const wchar_t* nameW; /*!< Copy of original name in UNICODE. */
+#else /* UNICODE || _UNICODE */
+ void* nameW;
+#endif /* UNICODE || _UNICODE */
+ const char* domainA; /*!< Copy of original name in ASCII. */
+#if defined(UNICODE) || defined(_UNICODE)
+ const wchar_t* domainW; /*!< Copy of original name in UNICODE. */
+#else /* UNICODE || _UNICODE */
+ void* domainW;
+#endif /* UNICODE || _UNICODE */
+ int type;
+ long index;
+ int extra1; /*!< Reserved to the runtime */
+ void* extra2; /*!< Reserved to the runtime */
+ struct __itt_counter_info* next;
+} __itt_counter_info_t;
+
+struct ___itt_domain;
+struct ___itt_string_handle;
+
+typedef struct ___itt_global
+{
+ unsigned char magic[8];
+ unsigned long version_major;
+ unsigned long version_minor;
+ unsigned long version_build;
+ volatile long api_initialized;
+ volatile long mutex_initialized;
+ volatile long atomic_counter;
+ mutex_t mutex;
+ lib_t lib;
+ void* error_handler;
+ const char** dll_path_ptr;
+ __itt_api_info* api_list_ptr;
+ struct ___itt_global* next;
+ /* Joinable structures below */
+ __itt_thread_info* thread_list;
+ struct ___itt_domain* domain_list;
+ struct ___itt_string_handle* string_list;
+ __itt_collection_state state;
+ __itt_counter_info_t* counter_list;
+} __itt_global;
+
+#pragma pack(pop)
+
+#define NEW_THREAD_INFO_W(gptr,h,h_tail,t,s,n) { \
+ h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \
+ if (h != NULL) { \
+ h->tid = t; \
+ h->nameA = NULL; \
+ h->nameW = n ? _wcsdup(n) : NULL; \
+ h->state = s; \
+ h->extra1 = 0; /* reserved */ \
+ h->extra2 = NULL; /* reserved */ \
+ h->next = NULL; \
+ if (h_tail == NULL) \
+ (gptr)->thread_list = h; \
+ else \
+ h_tail->next = h; \
+ } \
+}
+
+#define NEW_THREAD_INFO_A(gptr,h,h_tail,t,s,n) { \
+ h = (__itt_thread_info*)malloc(sizeof(__itt_thread_info)); \
+ if (h != NULL) { \
+ h->tid = t; \
+ h->nameA = n ? __itt_fstrdup(n) : NULL; \
+ h->nameW = NULL; \
+ h->state = s; \
+ h->extra1 = 0; /* reserved */ \
+ h->extra2 = NULL; /* reserved */ \
+ h->next = NULL; \
+ if (h_tail == NULL) \
+ (gptr)->thread_list = h; \
+ else \
+ h_tail->next = h; \
+ } \
+}
+
+#define NEW_DOMAIN_W(gptr,h,h_tail,name) { \
+ h = (__itt_domain*)malloc(sizeof(__itt_domain)); \
+ if (h != NULL) { \
+ h->flags = 1; /* domain is enabled by default */ \
+ h->nameA = NULL; \
+ h->nameW = name ? _wcsdup(name) : NULL; \
+ h->extra1 = 0; /* reserved */ \
+ h->extra2 = NULL; /* reserved */ \
+ h->next = NULL; \
+ if (h_tail == NULL) \
+ (gptr)->domain_list = h; \
+ else \
+ h_tail->next = h; \
+ } \
+}
+
+#define NEW_DOMAIN_A(gptr,h,h_tail,name) { \
+ h = (__itt_domain*)malloc(sizeof(__itt_domain)); \
+ if (h != NULL) { \
+ h->flags = 1; /* domain is enabled by default */ \
+ h->nameA = name ? __itt_fstrdup(name) : NULL; \
+ h->nameW = NULL; \
+ h->extra1 = 0; /* reserved */ \
+ h->extra2 = NULL; /* reserved */ \
+ h->next = NULL; \
+ if (h_tail == NULL) \
+ (gptr)->domain_list = h; \
+ else \
+ h_tail->next = h; \
+ } \
+}
+
+#define NEW_STRING_HANDLE_W(gptr,h,h_tail,name) { \
+ h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \
+ if (h != NULL) { \
+ h->strA = NULL; \
+ h->strW = name ? _wcsdup(name) : NULL; \
+ h->extra1 = 0; /* reserved */ \
+ h->extra2 = NULL; /* reserved */ \
+ h->next = NULL; \
+ if (h_tail == NULL) \
+ (gptr)->string_list = h; \
+ else \
+ h_tail->next = h; \
+ } \
+}
+
+#define NEW_STRING_HANDLE_A(gptr,h,h_tail,name) { \
+ h = (__itt_string_handle*)malloc(sizeof(__itt_string_handle)); \
+ if (h != NULL) { \
+ h->strA = name ? __itt_fstrdup(name) : NULL; \
+ h->strW = NULL; \
+ h->extra1 = 0; /* reserved */ \
+ h->extra2 = NULL; /* reserved */ \
+ h->next = NULL; \
+ if (h_tail == NULL) \
+ (gptr)->string_list = h; \
+ else \
+ h_tail->next = h; \
+ } \
+}
+
+#define NEW_COUNTER_W(gptr,h,h_tail,name,domain,type) { \
+ h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \
+ if (h != NULL) { \
+ h->nameA = NULL; \
+ h->nameW = name ? _wcsdup(name) : NULL; \
+ h->domainA = NULL; \
+        h->domainW = domain ? _wcsdup(domain) : NULL; \
+ h->type = type; \
+ h->index = 0; \
+ h->next = NULL; \
+ if (h_tail == NULL) \
+ (gptr)->counter_list = h; \
+ else \
+ h_tail->next = h; \
+ } \
+}
+
+#define NEW_COUNTER_A(gptr,h,h_tail,name,domain,type) { \
+ h = (__itt_counter_info_t*)malloc(sizeof(__itt_counter_info_t)); \
+ if (h != NULL) { \
+ h->nameA = name ? __itt_fstrdup(name) : NULL; \
+ h->nameW = NULL; \
+ h->domainA = domain ? __itt_fstrdup(domain) : NULL; \
+ h->domainW = NULL; \
+ h->type = type; \
+ h->index = 0; \
+ h->next = NULL; \
+ if (h_tail == NULL) \
+ (gptr)->counter_list = h; \
+ else \
+ h_tail->next = h; \
+ } \
+}
+
+#endif /* _ITTNOTIFY_CONFIG_H_ */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h
new file mode 100644
index 0000000000..99fbc24054
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/ittnotify_types.h
@@ -0,0 +1,94 @@
+/* <copyright>
+
+ Contact Information:
+ http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/
+
+ BSD LICENSE
+
+ Copyright (c) 2005-2014 Intel Corporation. All rights reserved.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+ * Neither the name of Intel Corporation nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+</copyright> */
+
+#ifndef _ITTNOTIFY_TYPES_H_
+#define _ITTNOTIFY_TYPES_H_
+
+typedef enum ___itt_group_id
+{
+ __itt_group_none = 0,
+ __itt_group_legacy = 1<<0,
+ __itt_group_control = 1<<1,
+ __itt_group_thread = 1<<2,
+ __itt_group_mark = 1<<3,
+ __itt_group_sync = 1<<4,
+ __itt_group_fsync = 1<<5,
+ __itt_group_jit = 1<<6,
+ __itt_group_model = 1<<7,
+ __itt_group_splitter_min = 1<<7,
+ __itt_group_counter = 1<<8,
+ __itt_group_frame = 1<<9,
+ __itt_group_stitch = 1<<10,
+ __itt_group_heap = 1<<11,
+ __itt_group_splitter_max = 1<<12,
+ __itt_group_structure = 1<<12,
+ __itt_group_suppress = 1<<13,
+ __itt_group_arrays = 1<<14,
+ __itt_group_all = -1
+} __itt_group_id;
+
+#pragma pack(push, 8)
+
+typedef struct ___itt_group_list
+{
+ __itt_group_id id;
+ const char* name;
+} __itt_group_list;
+
+#pragma pack(pop)
+
+#define ITT_GROUP_LIST(varname) \
+ static __itt_group_list varname[] = { \
+ { __itt_group_all, "all" }, \
+ { __itt_group_control, "control" }, \
+ { __itt_group_thread, "thread" }, \
+ { __itt_group_mark, "mark" }, \
+ { __itt_group_sync, "sync" }, \
+ { __itt_group_fsync, "fsync" }, \
+ { __itt_group_jit, "jit" }, \
+ { __itt_group_model, "model" }, \
+ { __itt_group_counter, "counter" }, \
+ { __itt_group_frame, "frame" }, \
+ { __itt_group_stitch, "stitch" }, \
+ { __itt_group_heap, "heap" }, \
+ { __itt_group_structure, "structure" }, \
+ { __itt_group_suppress, "suppress" }, \
+ { __itt_group_arrays, "arrays" }, \
+ { __itt_group_none, NULL } \
+ }
+
+#endif /* _ITTNOTIFY_TYPES_H_ */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c
new file mode 100644
index 0000000000..15f4b9929b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c
@@ -0,0 +1,293 @@
+/* <copyright>
+
+ Contact Information:
+ http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/
+
+ BSD LICENSE
+
+ Copyright (c) 2005-2014 Intel Corporation. All rights reserved.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+ * Neither the name of Intel Corporation nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+</copyright> */
+
+#include "ittnotify_config.h"
+
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+#include <windows.h>
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+#if ITT_PLATFORM != ITT_PLATFORM_MAC && ITT_PLATFORM != ITT_PLATFORM_FREEBSD
+#include <malloc.h>
+#endif
+#include <stdlib.h>
+
+#include "jitprofiling.h"
+
+static const char rcsid[] = "\n@(#) $Revision: 471937 $\n";
+
+#define DLL_ENVIRONMENT_VAR "VS_PROFILER"
+
+#ifndef NEW_DLL_ENVIRONMENT_VAR
+#if ITT_ARCH==ITT_ARCH_IA32
+#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER32"
+#else
+#define NEW_DLL_ENVIRONMENT_VAR "INTEL_JIT_PROFILER64"
+#endif
+#endif /* NEW_DLL_ENVIRONMENT_VAR */
+
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+#define DEFAULT_DLLNAME "JitPI.dll"
+HINSTANCE m_libHandle = NULL;
+#elif ITT_PLATFORM==ITT_PLATFORM_MAC
+#define DEFAULT_DLLNAME "libJitPI.dylib"
+void* m_libHandle = NULL;
+#else
+#define DEFAULT_DLLNAME "libJitPI.so"
+void* m_libHandle = NULL;
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+
+/* default location of JIT profiling agent on Android */
+#define ANDROID_JIT_AGENT_PATH "/data/intel/libittnotify.so"
+
+/* the function pointers */
+typedef unsigned int(JITAPI *TPInitialize)(void);
+static TPInitialize FUNC_Initialize=NULL;
+
+typedef unsigned int(JITAPI *TPNotify)(unsigned int, void*);
+static TPNotify FUNC_NotifyEvent=NULL;
+
+static iJIT_IsProfilingActiveFlags executionMode = iJIT_NOTHING_RUNNING;
+
+/* end collector dll part. */
+
+/* loadiJIT_Funcs() : this function is called once, at the beginning,
+ * and is responsible for loading the functions from BistroJavaCollector.dll
+ * result:
+ * on success: the functions are loaded, iJIT_DLL_is_missing=0, return value = 1
+ * on failure: the functions are NULL, iJIT_DLL_is_missing=1, return value = 0
+ */
+static int loadiJIT_Funcs(void);
+
+/* global flag indicating that the collector DLL could not be loaded */
+static int iJIT_DLL_is_missing = 0;
+
+ITT_EXTERN_C int JITAPI
+iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData)
+{
+ int ReturnValue = 0;
+
+ /* initialization part - the collector has not been loaded yet. */
+ if (!FUNC_NotifyEvent)
+ {
+ if (iJIT_DLL_is_missing)
+ return 0;
+
+ if (!loadiJIT_Funcs())
+ return 0;
+ }
+
+ if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED ||
+ event_type == iJVM_EVENT_TYPE_METHOD_UPDATE)
+ {
+ if (((piJIT_Method_Load)EventSpecificData)->method_id == 0)
+ return 0;
+ }
+ else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2)
+ {
+ if (((piJIT_Method_Load_V2)EventSpecificData)->method_id == 0)
+ return 0;
+ }
+ else if (event_type == iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3)
+ {
+ if (((piJIT_Method_Load_V3)EventSpecificData)->method_id == 0)
+ return 0;
+ }
+ else if (event_type == iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED)
+ {
+ if (((piJIT_Method_Inline_Load)EventSpecificData)->method_id == 0 ||
+ ((piJIT_Method_Inline_Load)EventSpecificData)->parent_method_id == 0)
+ return 0;
+ }
+
+ ReturnValue = (int)FUNC_NotifyEvent(event_type, EventSpecificData);
+
+ return ReturnValue;
+}
+
+ITT_EXTERN_C iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive()
+{
+ if (!iJIT_DLL_is_missing)
+ {
+ loadiJIT_Funcs();
+ }
+
+ return executionMode;
+}
+
+/* This function loads the collector dll and the relevant functions.
+ * on success: all functions load, iJIT_DLL_is_missing = 0, return value = 1
+ * on failure: all functions are NULL, iJIT_DLL_is_missing = 1, return value = 0
+ */
+static int loadiJIT_Funcs()
+{
+ static int bDllWasLoaded = 0;
+ char *dllName = (char*)rcsid; /* !! Just to avoid unused code elimination */
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+ DWORD dNameLength = 0;
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+
+ if(bDllWasLoaded)
+ {
+        /* dll was already loaded, no need to do it a second time */
+ return 1;
+ }
+
+ /* Assumes that the DLL will not be found */
+ iJIT_DLL_is_missing = 1;
+ FUNC_NotifyEvent = NULL;
+
+ if (m_libHandle)
+ {
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+ FreeLibrary(m_libHandle);
+#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+ dlclose(m_libHandle);
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+ m_libHandle = NULL;
+ }
+
+ /* Try to get the dll name from the environment */
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+ dNameLength = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR, NULL, 0);
+ if (dNameLength)
+ {
+ DWORD envret = 0;
+ dllName = (char*)malloc(sizeof(char) * (dNameLength + 1));
+ if(dllName != NULL)
+ {
+ envret = GetEnvironmentVariableA(NEW_DLL_ENVIRONMENT_VAR,
+ dllName, dNameLength);
+ if (envret)
+ {
+ /* Try to load the dll from the PATH... */
+ m_libHandle = LoadLibraryExA(dllName,
+ NULL, LOAD_WITH_ALTERED_SEARCH_PATH);
+ }
+ free(dllName);
+ }
+ } else {
+ /* Try to use old VS_PROFILER variable */
+ dNameLength = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR, NULL, 0);
+ if (dNameLength)
+ {
+ DWORD envret = 0;
+ dllName = (char*)malloc(sizeof(char) * (dNameLength + 1));
+ if(dllName != NULL)
+ {
+ envret = GetEnvironmentVariableA(DLL_ENVIRONMENT_VAR,
+ dllName, dNameLength);
+ if (envret)
+ {
+ /* Try to load the dll from the PATH... */
+ m_libHandle = LoadLibraryA(dllName);
+ }
+ free(dllName);
+ }
+ }
+ }
+#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+ dllName = getenv(NEW_DLL_ENVIRONMENT_VAR);
+ if (!dllName)
+ dllName = getenv(DLL_ENVIRONMENT_VAR);
+#if defined(__ANDROID__) || defined(ANDROID)
+ if (!dllName)
+ dllName = ANDROID_JIT_AGENT_PATH;
+#endif
+ if (dllName)
+ {
+ /* Try to load the dll from the PATH... */
+ m_libHandle = dlopen(dllName, RTLD_LAZY);
+ }
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+
+ if (!m_libHandle)
+ {
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+ m_libHandle = LoadLibraryA(DEFAULT_DLLNAME);
+#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+ m_libHandle = dlopen(DEFAULT_DLLNAME, RTLD_LAZY);
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+ }
+
+ /* if the dll wasn't loaded - exit. */
+ if (!m_libHandle)
+ {
+        iJIT_DLL_is_missing = 1; /* don't try to initialize
+                                  * the JIT agent a second time
+                                  */
+ return 0;
+ }
+
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+ FUNC_NotifyEvent = (TPNotify)GetProcAddress(m_libHandle, "NotifyEvent");
+#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+ FUNC_NotifyEvent = (TPNotify)dlsym(m_libHandle, "NotifyEvent");
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+ if (!FUNC_NotifyEvent)
+ {
+ FUNC_Initialize = NULL;
+ return 0;
+ }
+
+#if ITT_PLATFORM==ITT_PLATFORM_WIN
+ FUNC_Initialize = (TPInitialize)GetProcAddress(m_libHandle, "Initialize");
+#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+ FUNC_Initialize = (TPInitialize)dlsym(m_libHandle, "Initialize");
+#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */
+ if (!FUNC_Initialize)
+ {
+ FUNC_NotifyEvent = NULL;
+ return 0;
+ }
+
+ executionMode = (iJIT_IsProfilingActiveFlags)FUNC_Initialize();
+
+ bDllWasLoaded = 1;
+ iJIT_DLL_is_missing = 0; /* DLL is ok. */
+
+ return 1;
+}
+
+ITT_EXTERN_C unsigned int JITAPI iJIT_GetNewMethodID()
+{
+ static unsigned int methodID = 1;
+
+ if (methodID == 0)
+ return 0; /* ERROR : this is not a valid value */
+
+ return methodID++;
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h
new file mode 100644
index 0000000000..bf0489b1a1
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.h
@@ -0,0 +1,673 @@
+/* <copyright>
+
+ Contact Information:
+ http://software.intel.com/en-us/articles/intel-vtune-amplifier-xe/
+
+ BSD LICENSE
+
+ Copyright (c) 2005-2014 Intel Corporation. All rights reserved.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+ * Neither the name of Intel Corporation nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+</copyright> */
+
+#ifndef __JITPROFILING_H__
+#define __JITPROFILING_H__
+
+/**
+ * @brief JIT Profiling APIs
+ *
+ * The JIT Profiling API is used to report information about just-in-time
+ * generated code that can be used by performance tools. The user inserts
+ * calls in the code generator to report information before JIT-compiled
+ * code goes to execution. This information is collected at runtime and used
+ * by tools like Intel(R) VTune(TM) Amplifier to display performance metrics
+ * associated with JIT-compiled code.
+ *
+ * These APIs can be used to\n
+ * - **Profile trace-based and method-based JIT-compiled
+ * code**. Some examples of environments that you can profile with these APIs:
+ * dynamic JIT compilation of JavaScript code traces, JIT execution in OpenCL(TM)
+ * software technology, Java/.NET managed execution environments, and custom
+ * ISV JIT engines.
+ * @code
+ * #include <jitprofiling.h>
+ *
+ *  if (iJIT_IsProfilingActive() != iJIT_SAMPLING_ON) {
+ * return;
+ * }
+ *
+ * iJIT_Method_Load jmethod = {0};
+ * jmethod.method_id = iJIT_GetNewMethodID();
+ * jmethod.method_name = "method_name";
+ * jmethod.class_file_name = "class_name";
+ * jmethod.source_file_name = "source_file_name";
+ * jmethod.method_load_address = code_addr;
+ * jmethod.method_size = code_size;
+ *
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod);
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_SHUTDOWN, NULL);
+ * @endcode
+ *
+ * * Expected behavior:
+ * * If any iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an
+ * already reported method, then such a method becomes invalid and its
+ * memory region is treated as unloaded. VTune Amplifier displays the metrics
+ * collected by the method until it is overwritten.
+ * * If supplied line number information contains multiple source lines for
+ * the same assembly instruction (code location), then VTune Amplifier picks up
+ * the first line number.
+ * * Dynamically generated code can be associated with a module name.
+ * Use the iJIT_Method_Load_V2 structure.\n
+ * Clarification of some cases:
+ * * If you register a function with the same method ID multiple times,
+ * specifying different module names, then the VTune Amplifier picks up
+ * the module name registered first. If you want to distinguish the same
+ * function between different JIT engines, supply different method IDs for
+ * each function. Other symbolic information (for example, source file)
+ * can be identical.
+ *
+ * - **Analyze split functions** (multiple joint or disjoint code regions
+ * belonging to the same function) **including re-JIT**
+ * with potential overlapping of code regions in time, which is common in
+ * resource-limited environments.
+ * @code
+ * #include <jitprofiling.h>
+ *
+ * unsigned int method_id = iJIT_GetNewMethodID();
+ *
+ * iJIT_Method_Load a = {0};
+ * a.method_id = method_id;
+ * a.method_load_address = 0x100;
+ * a.method_size = 0x20;
+ *
+ * iJIT_Method_Load b = {0};
+ * b.method_id = method_id;
+ * b.method_load_address = 0x200;
+ * b.method_size = 0x30;
+ *
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a);
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&b);
+ * @endcode
+ *
+ * * Expected behaviour:
+ *      * If an iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an
+ * already reported method, then such a method becomes invalid and
+ * its memory region is treated as unloaded.
+ * * All code regions reported with the same method ID are considered as
+ * belonging to the same method. Symbolic information (method name,
+ * source file name) will be taken from the first notification, and all
+ * subsequent notifications with the same method ID will be processed
+ * only for line number table information. So, the VTune Amplifier will map
+ * samples to a source line using the line number table from the current
+ * notification while taking the source file name from the very first one.\n
+ * Clarification of some cases:\n
+ * * If you register a second code region with a different source file
+ * name and the same method ID, then this information will be saved and
+ * will not be considered as an extension of the first code region, but
+ * VTune Amplifier will use the source file of the first code region and map
+ * performance metrics incorrectly.
+ * * If you register a second code region with the same source file as
+ * for the first region and the same method ID, then the source file will be
+ * discarded but VTune Amplifier will map metrics to the source file correctly.
+ * * If you register a second code region with a null source file and
+ * the same method ID, then provided line number info will be associated
+ * with the source file of the first code region.
+ *
+ * - **Explore inline functions** including a multi-level hierarchy of
+ *   nested inline methods, which shows how performance metrics are distributed across them.
+ * @code
+ * #include <jitprofiling.h>
+ *
+ * // method_id parent_id
+ * // [-- c --] 3000 2000
+ * // [---- d -----] 2001 1000
+ * // [---- b ----] 2000 1000
+ * // [------------ a ----------------] 1000 n/a
+ *
+ * iJIT_Method_Load a = {0};
+ * a.method_id = 1000;
+ *
+ * iJIT_Method_Inline_Load b = {0};
+ * b.method_id = 2000;
+ * b.parent_method_id = 1000;
+ *
+ * iJIT_Method_Inline_Load c = {0};
+ * c.method_id = 3000;
+ * c.parent_method_id = 2000;
+ *
+ * iJIT_Method_Inline_Load d = {0};
+ * d.method_id = 2001;
+ * d.parent_method_id = 1000;
+ *
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a);
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&b);
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&c);
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&d);
+ * @endcode
+ *
+ * * Requirements:
+ * * Each inline (iJIT_Method_Inline_Load) method should be associated
+ * with two method IDs: one for itself; one for its immediate parent.
+ * * Address regions of inline methods of the same parent method cannot
+ * overlap each other.
+ * * Execution of the parent method must not be started until it and all
+ * its inline methods are reported.
+ * * Expected behaviour:
+ *      * In case of nested inline methods, the order of
+ *        iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED events is not important.
+ * * If any event overwrites either inline method or top parent method,
+ * then the parent, including inline methods, becomes invalid and its memory
+ * region is treated as unloaded.
+ *
+ * **Life time of allocated data**\n
+ * The client sends an event notification to the agent with event-specific
+ * data, which is a structure. The pointers in the structure refer to memory
+ * allocated by the client, which is responsible for releasing it. The pointers are
+ * used by the iJIT_NotifyEvent method to copy the client's data into a trace file,
+ * and they are not used after the iJIT_NotifyEvent method returns.
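+ *
+ * A minimal sketch of this ownership model (illustrative only; code_addr and
+ * code_size are placeholders, as in the examples above):
+ * @code
+ * char *name = strdup("jitted_method");    // allocated by the client
+ * iJIT_Method_Load jmethod = {0};
+ * jmethod.method_id = iJIT_GetNewMethodID();
+ * jmethod.method_name = name;
+ * jmethod.method_load_address = code_addr;
+ * jmethod.method_size = code_size;
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod);
+ * free(name); // safe: the data was copied during the iJIT_NotifyEvent call
+ * @endcode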
+ */
+
+/**
+ * @defgroup jitapi JIT Profiling
+ * @ingroup internal
+ * @{
+ */
+
+/**
+ * @brief Enumerator for the types of notifications
+ */
+typedef enum iJIT_jvm_event
+{
+ iJVM_EVENT_TYPE_SHUTDOWN = 2, /**<\brief Send this to shutdown the agent.
+ * Use NULL for event data. */
+
+ iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED = 13, /**<\brief Send when dynamic code is
+ * JIT compiled and loaded into
+ * memory by the JIT engine, but
+ * before the code is executed.
+ * Use iJIT_Method_Load as event
+ * data. */
+/** @cond exclude_from_documentation */
+ iJVM_EVENT_TYPE_METHOD_UNLOAD_START, /**<\brief Send when compiled dynamic
+ * code is being unloaded from memory.
+ * Use iJIT_Method_Load as event data.*/
+/** @endcond */
+
+ iJVM_EVENT_TYPE_METHOD_UPDATE, /**<\brief Send to provide new content for
+ * a previously reported dynamic code.
+ * The previous content will be invalidated
+ * starting from the time of the notification.
+ * Use iJIT_Method_Load as event data, but
+ * only the following fields are required:
+ * - method_id identifies the code to update.
+ * - method_load_address specifies the start address
+ * within the identified code range
+ * where the update should be applied.
+ * - method_size specifies the length of the updated
+ * code range. */
+
+
+ iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, /**<\brief Send when an inline dynamic
+ * code is JIT compiled and loaded
+ * into memory by the JIT engine,
+ * but before the parent code region
+ * starts executing.
+ * Use iJIT_Method_Inline_Load as event data.*/
+
+/** @cond exclude_from_documentation */
+ iJVM_EVENT_TYPE_METHOD_UPDATE_V2,
+/** @endcond */
+
+ iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 = 21, /**<\brief Send when a dynamic code is
+ * JIT compiled and loaded into
+ * memory by the JIT engine, but
+ * before the code is executed.
+ * Use iJIT_Method_Load_V2 as event data. */
+
+ iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 /**<\brief Send when a dynamic code is
+ * JIT compiled and loaded into
+ * memory by the JIT engine, but
+ * before the code is executed.
+ * Use iJIT_Method_Load_V3 as event data. */
+} iJIT_JVM_EVENT;
+
+/**
+ * @brief Enumerator for the agent's mode
+ */
+typedef enum _iJIT_IsProfilingActiveFlags
+{
+ iJIT_NOTHING_RUNNING = 0x0000, /**<\brief The agent is not running;
+ * iJIT_NotifyEvent calls will
+ * not be processed. */
+ iJIT_SAMPLING_ON = 0x0001, /**<\brief The agent is running and
+ * ready to process notifications. */
+} iJIT_IsProfilingActiveFlags;
+
+/**
+ * @brief Description of a single entry in the line number information of a code region.
+ * @details A table of line number entries gives information about how the reported code region
+ * is mapped to source file.
+ * Intel(R) VTune(TM) Amplifier uses line number information to attribute
+ * the samples (virtual address) to a line number. \n
+ * It is acceptable to report different code addresses for the same source line:
+ * @code
+ * Offset LineNumber
+ * 1 2
+ * 12 4
+ * 15 2
+ * 18 1
+ * 21 30
+ *
+ * VTune Amplifier constructs the following table using the client data
+ *
+ * Code subrange Line number
+ * 0-1 2
+ * 1-12 4
+ * 12-15 2
+ * 15-18 1
+ * 18-21 30
+ * @endcode
+ */
+typedef struct _LineNumberInfo
+{
+    unsigned int Offset;     /**<\brief Offset from the beginning of the code region. */
+ unsigned int LineNumber; /**<\brief Matching source line number offset (from beginning of source file). */
+
+} *pLineNumberInfo, LineNumberInfo;
+
+/**
+ * @brief Enumerator for the code architecture.
+ */
+typedef enum _iJIT_CodeArchitecture
+{
+ iJIT_CA_NATIVE = 0, /**<\brief Native to the process architecture that is calling it. */
+
+ iJIT_CA_32, /**<\brief 32-bit machine code. */
+
+ iJIT_CA_64 /**<\brief 64-bit machine code. */
+
+} iJIT_CodeArchitecture;
+
+#pragma pack(push, 8)
+
+/**
+ * @brief Description of a JIT-compiled method
+ * @details When you use the iJIT_Method_Load structure to describe
+ * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED
+ * as an event type to report it.
+ */
+typedef struct _iJIT_Method_Load
+{
+ unsigned int method_id; /**<\brief Unique method ID. Cannot be 0.
+ * You must either use the API function
+ * iJIT_GetNewMethodID to get a valid and unique
+ * method ID, or else manage ID uniqueness
+ * and correct range by yourself.\n
+ * You must use the same method ID for all code
+ * regions of the same method, otherwise different
+ * method IDs specify different methods. */
+
+ char* method_name; /**<\brief The name of the method. It can be optionally
+ * prefixed with its class name and appended with
+ * its complete signature. Can't be NULL. */
+
+ void* method_load_address; /**<\brief The start virtual address of the method code
+ * region. If NULL, data provided with
+ * event are not accepted. */
+
+ unsigned int method_size; /**<\brief The code size of the method in memory.
+ * If 0, then data provided with the event are not
+ * accepted. */
+
+ unsigned int line_number_size; /**<\brief The number of entries in the line number
+ * table. 0 if none. */
+
+ pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info
+ * array. Can be NULL if
+ * line_number_size is 0. See
+ * LineNumberInfo Structure for a
+ * description of a single entry in
+ * the line number info array */
+
+ unsigned int class_id; /**<\brief This field is obsolete. */
+
+ char* class_file_name; /**<\brief Class name. Can be NULL.*/
+
+ char* source_file_name; /**<\brief Source file name. Can be NULL.*/
+
+} *piJIT_Method_Load, iJIT_Method_Load;
+
+/**
+ * @brief Description of a JIT-compiled method
+ * @details When you use the iJIT_Method_Load_V2 structure to describe
+ * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2
+ * as an event type to report it.
+ */
+typedef struct _iJIT_Method_Load_V2
+{
+ unsigned int method_id; /**<\brief Unique method ID. Cannot be 0.
+ * You must either use the API function
+ * iJIT_GetNewMethodID to get a valid and unique
+ * method ID, or else manage ID uniqueness
+ * and correct range by yourself.\n
+ * You must use the same method ID for all code
+ * regions of the same method, otherwise different
+ * method IDs specify different methods. */
+
+ char* method_name; /**<\brief The name of the method. It can be optionally
+ * prefixed with its class name and appended with
+ * its complete signature. Can't be NULL. */
+
+ void* method_load_address; /**<\brief The start virtual address of the method code
+ * region. If NULL, then data provided with the
+ * event are not accepted. */
+
+ unsigned int method_size; /**<\brief The code size of the method in memory.
+ * If 0, then data provided with the event are not
+ * accepted. */
+
+ unsigned int line_number_size; /**<\brief The number of entries in the line number
+ * table. 0 if none. */
+
+ pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info
+ * array. Can be NULL if
+ * line_number_size is 0. See
+ * LineNumberInfo Structure for a
+ * description of a single entry in
+ * the line number info array. */
+
+ char* class_file_name; /**<\brief Class name. Can be NULL. */
+
+ char* source_file_name; /**<\brief Source file name. Can be NULL. */
+
+ char* module_name; /**<\brief Module name. Can be NULL.
+ The module name can be useful for distinguishing among
+ different JIT engines. VTune Amplifier will display
+ reported methods grouped by specific module. */
+
+} *piJIT_Method_Load_V2, iJIT_Method_Load_V2;
+
+/**
+ * @brief Description of a JIT-compiled method
+ * @details The iJIT_Method_Load_V3 structure is the same as iJIT_Method_Load_V2
+ * with a newly introduced 'arch' field that specifies architecture of the code region.
+ * When you use the iJIT_Method_Load_V3 structure to describe
+ * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3
+ * as an event type to report it.
+ */
+typedef struct _iJIT_Method_Load_V3
+{
+ unsigned int method_id; /**<\brief Unique method ID. Cannot be 0.
+ * You must either use the API function
+ * iJIT_GetNewMethodID to get a valid and unique
+ * method ID, or manage ID uniqueness
+ * and correct range by yourself.\n
+ * You must use the same method ID for all code
+ * regions of the same method, otherwise they are
+ * treated as regions of different methods. */
+
+ char* method_name; /**<\brief The name of the method. It can be optionally
+ * prefixed with its class name and appended with
+ * its complete signature. Cannot be NULL. */
+
+ void* method_load_address; /**<\brief The start virtual address of the method code
+ * region. If NULL, then data provided with the
+ * event are not accepted. */
+
+ unsigned int method_size; /**<\brief The code size of the method in memory.
+ * If 0, then data provided with the event are not
+ * accepted. */
+
+ unsigned int line_number_size; /**<\brief The number of entries in the line number
+ * table. 0 if none. */
+
+ pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info
+ * array. Can be NULL if
+ * line_number_size is 0. See
+ * LineNumberInfo Structure for a
+ * description of a single entry in
+ * the line number info array. */
+
+ char* class_file_name; /**<\brief Class name. Can be NULL. */
+
+ char* source_file_name; /**<\brief Source file name. Can be NULL. */
+
+ char* module_name; /**<\brief Module name. Can be NULL.
+ * The module name can be useful for distinguishing among
+ * different JIT engines. VTune Amplifier will display
+ * reported methods grouped by specific module. */
+
+ iJIT_CodeArchitecture module_arch; /**<\brief Architecture of the method's code region.
+ * By default, it is the same as the process
+ * architecture that is calling it.
+ * For example, you can use it if your 32-bit JIT
+ * engine generates 64-bit code.
+ *
+ * If the JIT engine reports both 32-bit and 64-bit types
+ * of methods, then VTune Amplifier splits the methods
+ * with the same module name but different
+ * architectures into two different modules. VTune Amplifier
+ * modifies the original name provided with a 64-bit method
+ * version by appending '(64)' to it. */
+
+} *piJIT_Method_Load_V3, iJIT_Method_Load_V3;
+
+/**
+ * @brief Description of an inline JIT-compiled method
+ * @details When you use the_iJIT_Method_Inline_Load structure to describe
+ * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED
+ * as an event type to report it.
+ */
+typedef struct _iJIT_Method_Inline_Load
+{
+ unsigned int method_id; /**<\brief Unique method ID. Cannot be 0.
+ * You must either use the API function
+ * iJIT_GetNewMethodID to get a valid and unique
+ * method ID, or else manage ID uniqueness
+ * and correct range by yourself. */
+
+ unsigned int parent_method_id; /**<\brief Unique immediate parent's method ID.
+ * Cannot be 0.
+ * You must either use the API function
+ * iJIT_GetNewMethodID to get a valid and unique
+ * method ID, or else manage ID uniqueness
+ * and correct range by yourself. */
+
+ char* method_name; /**<\brief The name of the method. It can be optionally
+ * prefixed with its class name and appended with
+ * its complete signature. Can't be NULL. */
+
+    void* method_load_address;         /**<\brief The virtual address on which the method
+ * is inlined. If NULL, then data provided with
+ * the event are not accepted. */
+
+ unsigned int method_size; /**<\brief The code size of the method in memory.
+ * If 0, then data provided with the event are not
+ * accepted. */
+
+ unsigned int line_number_size; /**<\brief The number of entries in the line number
+ * table. 0 if none. */
+
+ pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info
+ * array. Can be NULL if
+ * line_number_size is 0. See
+ * LineNumberInfo Structure for a
+ * description of a single entry in
+ * the line number info array */
+
+ char* class_file_name; /**<\brief Class name. Can be NULL.*/
+
+ char* source_file_name; /**<\brief Source file name. Can be NULL.*/
+
+} *piJIT_Method_Inline_Load, iJIT_Method_Inline_Load;
+
+/** @cond exclude_from_documentation */
+/**
+ * @brief Description of a segment type
+ * @details Use the segment type to specify a type of data supplied
+ * with the iJVM_EVENT_TYPE_METHOD_UPDATE_V2 event to be applied to
+ * a certain code trace.
+ */
+typedef enum _iJIT_SegmentType
+{
+ iJIT_CT_UNKNOWN = 0,
+
+ iJIT_CT_CODE, /**<\brief Executable code. */
+
+ iJIT_CT_DATA, /**<\brief Data (not executable code).
+ * VTune Amplifier uses the format string
+ * (see iJIT_Method_Update) to represent
+ * this data in the VTune Amplifier GUI */
+
+ iJIT_CT_KEEP, /**<\brief Use the previous markup for the trace.
+ * Can be used for the following
+ * iJVM_EVENT_TYPE_METHOD_UPDATE_V2 events,
+ * if the type of the previously reported segment
+ * type is the same. */
+ iJIT_CT_EOF
+} iJIT_SegmentType;
+
+/**
+ * @brief Description of a dynamic update of the content within JIT-compiled method
+ * @details The JIT engine may generate methods that are partially updated at runtime
+ * with mixed (data + executable code) content. When you use the iJIT_Method_Update
+ * structure to describe the update of the content within a JIT-compiled method,
+ * use iJVM_EVENT_TYPE_METHOD_UPDATE_V2 as an event type to report it.
+ *
+ * On the first Update event, VTune Amplifier copies the original code range reported by
+ * the iJVM_EVENT_TYPE_METHOD_LOAD event, then modifies it with the supplied bytes and
+ * adds the modified range to the original method. For next update events, VTune Amplifier
+ * does the same but it uses the latest modified version of a code region for update.
+ * Eventually, VTune Amplifier GUI displays multiple code ranges for the method reported by
+ * the iJVM_EVENT_TYPE_METHOD_LOAD event.
+ * Notes:
+ * - Multiple update events with different types for the same trace are allowed
+ * but they must be reported for the same code ranges.
+ * Example,
+ * @code
+ * [-- data---] Allowed
+ * [-- code --] Allowed
+ * [code] Ignored
+ * [-- data---] Allowed
+ * [-- code --] Allowed
+ * [------------ trace ---------]
+ * @endcode
+ * - The types of previously reported events can be changed but they must be reported
+ * for the same code ranges.
+ * Example,
+ * @code
+ * [-- data---] Allowed
+ * [-- code --] Allowed
+ * [-- data---] Allowed
+ * [-- code --] Allowed
+ * [------------ trace ---------]
+ * @endcode
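+ *
+ * A hypothetical sketch of filling the structure (illustrative only; update_addr
+ * and update_size are placeholders):
+ * @code
+ * iJIT_Method_Update mupdate = {0};
+ * mupdate.load_address = update_addr;  // start address of the changed range
+ * mupdate.size = update_size;          // length of the changed range
+ * mupdate.type = iJIT_CT_CODE;
+ * mupdate.data_format = "";            // format string, used for iJIT_CT_CODE
+ * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_UPDATE_V2, (void*)&mupdate);
+ * @endcode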
+ */
+
+typedef struct _iJIT_Method_Update
+{
+ void* load_address; /**<\brief Start address of the update within a method */
+
+ unsigned int size; /**<\brief The update size */
+
+ iJIT_SegmentType type; /**<\brief Type of the update */
+
+ const char* data_format; /**<\brief C string that contains a format string
+ * that follows the same specifications as format in printf.
+ * The format string is used for iJIT_CT_CODE only
+ * and cannot be NULL.
+ * Format can be changed on the fly. */
+} *piJIT_Method_Update, iJIT_Method_Update;
+
+/** @endcond */
+
+#pragma pack(pop)
+
+/** @cond exclude_from_documentation */
+#ifdef __cplusplus
+extern "C" {
+#endif /* __cplusplus */
+
+#ifndef JITAPI_CDECL
+# if defined WIN32 || defined _WIN32
+# define JITAPI_CDECL __cdecl
+# else /* defined WIN32 || defined _WIN32 */
+# if defined _M_IX86 || defined __i386__
+# define JITAPI_CDECL __attribute__ ((cdecl))
+# else /* _M_IX86 || __i386__ */
+#   define JITAPI_CDECL /* empty: the cdecl attribute is not needed on the x86_64 platform */
+# endif /* _M_IX86 || __i386__ */
+# endif /* defined WIN32 || defined _WIN32 */
+#endif /* JITAPI_CDECL */
+
+#define JITAPI JITAPI_CDECL
+/** @endcond */
+
+/**
+ * @brief Generates a new unique method ID.
+ *
+ * You must use this API to obtain unique and valid method IDs for methods or
+ * traces reported to the agent if you don't have your own mechanism to generate
+ * unique method IDs.
+ *
+ * @return a new unique method ID. When out of unique method IDs, this API
+ * returns 0, which is not an accepted value.
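+ *
+ * A brief defensive-use sketch (illustrative only):
+ * @code
+ * unsigned int mid = iJIT_GetNewMethodID();
+ * if (mid == 0) {
+ *     // out of unique IDs: skip reporting instead of sending method_id == 0
+ * }
+ * @endcode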
+ */
+unsigned int JITAPI iJIT_GetNewMethodID(void);
+
+/**
+ * @brief Returns the current mode of the agent.
+ *
+ * @return iJIT_SAMPLING_ON, indicating that the agent is running, or
+ * iJIT_NOTHING_RUNNING if no agent is running.
+ */
+iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive(void);
+
+/**
+ * @brief Reports information about JIT-compiled code to the agent.
+ *
+ * The reported information is used to attribute samples obtained from any
+ * Intel(R) VTune(TM) Amplifier collector. This API needs to be called
+ * after JIT compilation and before the first entry into the JIT-compiled
+ * code.
+ *
+ * @param[in] event_type - type of the data sent to the agent
+ * @param[in] EventSpecificData - pointer to event-specific data
+ *
+ * @returns 1 on success, otherwise 0.
+ */
+int JITAPI iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData);
+
+#ifdef __cplusplus
+}
+#endif /* __cplusplus */
+/** @endcond */
+
+/** @} jitapi group */
+
+#endif /* __JITPROFILING_H__ */
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp
new file mode 100644
index 0000000000..ef4c42bacf
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.cpp
@@ -0,0 +1,317 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+
+#include "nchw_pooling.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type>
+void nchw_pooling_fwd_t<data_type>::execute_forward(
+ const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE);
+
+ const memory_desc_wrapper ws_d(pd()->workspace_md());
+ const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
+
+ const int MB = pd()->MB();
+ const int C = pd()->C();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ const int SD = pd()->KSD();
+ const int SH = pd()->KSH();
+ const int SW = pd()->KSW();
+ const int padF = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+
+ auto alg = pd()->desc()->alg_kind;
+
+ auto apply_offset = [=](int index, int offset) {
+ return (index > offset) ? index - offset : 0;
+ };
+
+ auto set_ws = [=](int mb, int c, int od, int oh, int ow, int value) {
+ if (ws) {
+ assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
+ size_t ws_offset
+ = (size_t)OW * OH * OD * C * mb
+ + (size_t)OW * OH * OD * c
+ + (size_t)OW * OH * od
+ + (size_t)OW * oh
+ + (size_t)ow;
+ if (ws_dt == data_type::u8) {
+ assert(0 <= value && value <= 255);
+ ws[ws_offset] = value;
+ } else
+ reinterpret_cast<int *>(ws)[ws_offset] = value;
+ }
+ };
+
+ auto ker_max = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
+ for (int kd = 0; kd < KD; ++kd) {
+ for (int kh = 0; kh < KH; ++kh) {
+ for (int kw = 0; kw < KW; ++kw) {
+ const int id = od * SD - padF + kd;
+ const int ih = oh * SH - padT + kh;
+ const int iw = ow * SW - padL + kw;
+
+ if (id < 0 || id >= ID) continue;
+ if (ih < 0 || ih >= IH) continue;
+ if (iw < 0 || iw >= IW) continue;
+
+ auto src_offset
+ = (size_t)IW * IH * ID * C * mb
+ + (size_t)IW * IH * ID * c
+ + (size_t)IW * IH * id
+ + (size_t)IW * ih
+ + (size_t)iw;
+ auto s = src[src_offset];
+ if (s > d[0]) {
+ d[0] = s;
+ set_ws(mb, c, od, oh, ow, kd*KH*KW + kh*KW + kw);
+ }
+ }
+ }
+ }
+ };
+
+ auto ker_avg = [=](data_t *d, int mb, int c, int od, int oh, int ow) {
+ auto id_start = apply_offset(od*SD, padF);
+ auto ih_start = apply_offset(oh*SH, padT);
+ auto iw_start = apply_offset(ow*SW, padL);
+ auto id_end = nstl::min(od*SD - padF + KD, ID);
+ auto ih_end = nstl::min(oh*SH - padT + KH, IH);
+ auto iw_end = nstl::min(ow*SW - padL + KW, IW);
+
+ auto num_summands = (alg == pooling_avg_include_padding) ? KD*KW*KH
+ : (id_end - id_start)*(ih_end - ih_start)*(iw_end - iw_start);
+
+ for (int id = id_start; id < id_end; ++id) {
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+ for (int iw = iw_start; iw < iw_end; ++iw) {
+ auto src_offset
+ = (size_t)IW * IH * ID * C * mb
+ + (size_t)IW * IH * ID * c
+ + (size_t)IW * IH * id
+ + (size_t)IW * ih
+ + (size_t)iw;
+ d[0] += src[src_offset];
+ }
+ }
+ }
+
+ d[0] = math::out_round<data_t>((float)d[0] / num_summands);
+ };
+
+
+ if (pd()->desc()->alg_kind == pooling_max) {
+ parallel_nd(MB, C, OD, OH, OW,
+ [&](int mb, int c, int od, int oh, int ow) {
+ size_t dst_offset
+ = (size_t)OW * OH * OD * C * mb
+ + (size_t)OW * OH * OD * c
+ + (size_t)OW * OH * od
+ + (size_t)OW * oh
+ + (size_t)ow;
+ data_t *d = &dst[dst_offset];
+ d[0] = nstl::numeric_limits<data_t>::lowest();
+ set_ws(mb, c, od, oh, ow, 0);
+ ker_max(d, mb, c, od, oh, ow);
+ });
+ } else {
+ parallel_nd(MB, C, OD, OH, OW,
+ [&](int mb, int c, int od, int oh, int ow) {
+ size_t dst_offset
+ = (size_t)OW * OH * OD * C * mb
+ + (size_t)OW * OH * OD * c
+ + (size_t)OW * OH * od
+ + (size_t)OW * oh
+ + (size_t)ow;
+ data_t *d = &dst[dst_offset];
+ d[0] = 0;
+ ker_avg(d, mb, c, od, oh, ow);
+ });
+ }
+}
+
+template <impl::data_type_t data_type>
+void nchw_pooling_bwd_t<data_type>::execute_backward(
+ const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper ws_d(pd()->workspace_md());
+
+ const int MB = pd()->MB();
+ const int C = pd()->C();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ const int SD = pd()->KSD();
+ const int SH = pd()->KSH();
+ const int SW = pd()->KSW();
+ const int padF = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+
+ const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
+
+ auto alg = pd()->desc()->alg_kind;
+
+ auto apply_offset = [=](int index, int offset) {
+ return (index > offset) ? index - offset : 0;
+ };
+
+ auto ker_zero = [=](int mb, int c) {
+ size_t diff_src_offset = (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW;
+ for (int id = 0; id < ID; ++id) {
+ for (int ih = 0; ih < IH; ++ih) {
+ for (int iw = 0; iw < IW; ++iw) {
+ diff_src[diff_src_offset++] = 0;
+ }
+ }
+ }
+ };
+
+ auto ker_max = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
+ auto b_c = ws_d.blocking_desc().inner_nblks == 0
+ ? 1 : ws_d.blocking_desc().inner_blks[0];
+ auto ws_offset = is_3d
+ ? ws_d.blk_off(mb, c / b_c, od, oh, ow) + c % b_c
+ : ws_d.blk_off(mb, c / b_c, oh, ow) + c % b_c;
+
+ const int index = ws_d.data_type() == data_type::u8
+ ? (int)ws[ws_offset] : ((const int *)ws)[ws_offset];
+ const int kw = index % KW;
+ const int kh = (index / KW) % KH;
+ const int kd = (index / KW) / KH;
+
+ const int id = od * SD - padF + kd;
+ const int ih = oh * SH - padT + kh;
+ const int iw = ow * SW - padL + kw;
+
+        // If the kernel window lies entirely inside the padding area,
+        // the computed input index is out of bounds.
+        // No need to back-propagate there, since padding is
+        // virtual in the pooling_max case.
+ if (id < 0 || id >= ID)
+ return;
+ if (ih < 0 || ih >= IH)
+ return;
+ if (iw < 0 || iw >= IW)
+ return;
+
+ size_t diff_src_offset =
+ (size_t)mb*C*ID*IH*IW + (size_t)c*ID*IH*IW + (size_t)id*IH*IW
+ + (size_t)ih*IW + (size_t)iw;
+ diff_src[diff_src_offset] += d[0];
+ };
+
+ auto ker_avg = [=](const data_t *d, int mb, int c, int od, int oh, int ow) {
+ auto id_start = apply_offset(od*SD, padF);
+ auto ih_start = apply_offset(oh*SH, padT);
+ auto iw_start = apply_offset(ow*SW, padL);
+ auto id_end = nstl::min(od*SD - padF + KD, ID);
+ auto ih_end = nstl::min(oh*SH - padT + KH, IH);
+ auto iw_end = nstl::min(ow*SW - padL + KW, IW);
+
+ size_t num_summands = (alg == pooling_avg_include_padding)
+ ? (size_t)KW*KH*KD
+ : (size_t)(id_end - id_start)*(ih_end - ih_start)
+ *(iw_end - iw_start);
+
+ for (int id = id_start; id < id_end; ++id) {
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+ for (int iw = iw_start; iw < iw_end; ++iw) {
+ size_t diff_src_offset = (size_t)mb*C*ID*IH*IW
+ + (size_t)c*ID*IH*IW + (size_t)id*IH*IW
+ + (size_t)ih*IW + (size_t)iw;
+ diff_src[diff_src_offset] += d[0] / num_summands;
+ }
+ }
+ }
+ };
+
+ if (pd()->desc()->alg_kind == pooling_max) {
+ parallel_nd(MB, C, [&](int mb, int c) {
+ size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW
+ + (size_t)c*OD*OH*OW;
+ ker_zero(mb, c);
+ for (int od = 0; od < OD; ++od) {
+ for (int oh = 0; oh < OH; ++oh) {
+ for (int ow = 0; ow < OW; ++ow) {
+ const data_t *d = &diff_dst[diff_dst_offset++];
+ ker_max(d, mb, c, od, oh, ow);
+ }
+ }
+ }
+ });
+ } else {
+ parallel_nd(MB, C, [&](int mb, int c) {
+ size_t diff_dst_offset = (size_t)mb*C*OD*OH*OW
+ + (size_t)c*OD*OH*OW;
+ ker_zero(mb, c);
+ for (int od = 0; od < OD; ++od) {
+ for (int oh = 0; oh < OH; ++oh) {
+ for (int ow = 0; ow < OW; ++ow) {
+ const data_t *d = &diff_dst[diff_dst_offset++];
+ ker_avg(d, mb, c, od, oh, ow);
+ }
+ }
+ }
+ });
+ }
+}
+
+template struct nchw_pooling_fwd_t<data_type::f32>;
+template struct nchw_pooling_bwd_t<data_type::f32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp
new file mode 100644
index 0000000000..bbdd04f6b9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/nchw_pooling.hpp
@@ -0,0 +1,147 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_NCHW_POOLING_HPP
+#define CPU_NCHW_POOLING_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_pooling_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type>
+struct nchw_pooling_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_pooling_fwd_pd_t {
+ using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T("nchw_pooling:any", nchw_pooling_fwd_t);
+
+ status_t init() {
+ const format_tag_t desired_fmt_tag =
+ ndims() == 4 ? format_tag::nchw : format_tag::ncdhw;
+
+ bool ok = true
+ && set_default_params() == status::success
+ && is_fwd()
+ && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
+ alg_kind::pooling_avg_include_padding,
+ alg_kind::pooling_avg_exclude_padding)
+ && !has_zero_dim_memory()
+ && utils::everyone_is(data_type, src_md()->data_type,
+ dst_md()->data_type)
+ && attr()->has_default_values()
+ && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
+ && memory_desc_matches_tag(*dst_md(), desired_fmt_tag);
+ if (!ok) return status::unimplemented;
+
+ bool is_training = desc_.prop_kind == prop_kind::forward_training;
+ if (desc()->alg_kind == alg_kind::pooling_max && is_training)
+ init_default_ws();
+
+ return status::success;
+ }
+ };
+
+ nchw_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t data_type>
+struct nchw_pooling_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_pooling_bwd_pd_t {
+ using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T("nchw:any", nchw_pooling_bwd_t);
+
+ status_t init() {
+ const format_tag_t desired_fmt_tag =
+ ndims() == 4 ? format_tag::nchw : format_tag::ncdhw;
+
+ bool ok = true
+ && set_default_params() == status::success
+ && !is_fwd()
+ && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
+ alg_kind::pooling_avg_include_padding,
+ alg_kind::pooling_avg_exclude_padding)
+ && !has_zero_dim_memory()
+ && utils::everyone_is(data_type,
+ diff_dst_md()->data_type,
+ diff_src_md()->data_type)
+ && attr()->has_default_values()
+ && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag)
+ && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag);
+ if (!ok) return status::unimplemented;
+
+ if (desc()->alg_kind == alg_kind::pooling_max) {
+ bool ws_ok = true
+ && hint_fwd_pd_
+ && hint_fwd_pd_->workspace_md();
+ if (!ws_ok)
+ return status::unimplemented;
+
+ const auto &ws_blk =
+ hint_fwd_pd_->workspace_md()->format_desc.blocking;
+ ws_ok = ws_ok
+                && ws_blk.inner_nblks <= 1
+ && IMPLICATION(ws_blk.inner_nblks == 1,
+ ws_blk.inner_idxs[0] == 1);
+ if (!ws_ok)
+ return status::unimplemented;
+
+ ws_md_ = *hint_fwd_pd_->workspace_md();
+ }
+
+ return status::success;
+ }
+ };
+
+ nchw_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp
new file mode 100644
index 0000000000..c0e93fefe4
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.cpp
@@ -0,0 +1,382 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+
+#include "cpu_batch_normalization_utils.hpp"
+#include "jit_generator.hpp"
+
+#include "ncsp_batch_normalization.hpp"
+
+// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular
+// cases, so disable it for clang 6 and newer to be safe
+#if (defined __clang_major__) && (__clang_major__ >= 6)
+#define SAFE_TO_USE_OMP_SIMD 0
+#else
+#define SAFE_TO_USE_OMP_SIMD 1
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace memory_tracking::names;
+
+void ncsp_batch_normalization_fwd_t::execute_forward(
+ const exec_ctx_t &ctx) const {
+ const bool calculate_stats = !pd()->stats_is_src();
+ const bool save_stats = pd()->is_training();
+ const bool is_training = pd()->is_training();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+
+ auto scratchpad = this->scratchpad(ctx);
+ auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
+
+ data_t *mean, *variance;
+ if (!calculate_stats) {
+ mean = const_cast<data_t *>(
+ CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN));
+ variance = const_cast<data_t *>(
+ CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE));
+ } else {
+ if (save_stats) {
+ mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN);
+ variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE);
+ } else {
+ mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
+ variance = scratchpad.get<data_t>(key_bnorm_tmp_var);
+ }
+ }
+
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();
+ const bool with_relu = pd()->with_relu_post_op();
+ auto maybe_post_op
+ = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
+ const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
+ dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
+ dim_t N = pd()->MB();
+ dim_t C = pd()->C();
+
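+    // Heuristic: block over channels when the tensor is too large to stay
+    // resident in the (aggregate) L3 cache.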
+ int nthr = mkldnn_get_max_threads();
+ size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
+ size_t data_size = N * C * SP * sizeof(data_t);
+ bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ int C_ithr = 0, C_nthr = 0;
+ int N_ithr = 0, N_nthr = 0;
+ int S_ithr = 0, S_nthr = 0;
+
+ dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
+ dim_t N_s = 0, N_e = 0;
+ dim_t S_s = 0, S_e = 0;
+
+ dim_t C_blks_per_iter = 1;
+ int64_t iters = 1;
+
+ if (do_blocking) {
+ size_t working_set_size = N * SP * sizeof(data_t);
+ bnorm_utils::cache_balance(
+ working_set_size, C, C_blks_per_iter, iters);
+ } else
+ C_blks_per_iter = C;
+ int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter;
+ bool spatial_thr_allowed
+ = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N,
+ C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
+ N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
+ balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
+ int SP_N_ithr = N_ithr * S_nthr + S_ithr;
+ int SP_N_nthr = N_nthr * S_nthr;
+ for (int64_t it = 0; it < iters; ++it) {
+ if (it == iters - 1 && iters > 1) {
+ // On the last iteration the access pattern to ws_reduce
+ // might change (due to re-balance on C). So sync the
+ // threads if they are not synced by the algorithm.
+ if (SP_N_nthr == 1 && mkldnn_thr_syncable())
+ mkldnn_thr_barrier();
+
+ S_s = S_e = C_blk_s = C_blk_e = N_s = N_e = 0;
+ spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
+ spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
+ C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
+ N_e, S_ithr, S_nthr, S_s, S_e);
+ balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
+ SP_N_ithr = N_ithr * S_nthr + S_ithr;
+ SP_N_nthr = N_nthr * S_nthr;
+ }
+ size_t C_off = it * C_blks_per_iter;
+ // On the last iteration the access pattern to ws_reduce
+ // might change (due to re-balance on C). Since sync is not always
+ // possible (in case of TBB) use different parts of ws for each
+ // iteration if threads are not synced by the algorithm.
+ size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * C_off;
+
+ if (calculate_stats) {
+ data_t *mean_blk = mean + C_off;
+ data_t *variance_blk = variance + C_off;
+ for (dim_t c = C_blk_s; c < C_blk_e; c++) {
+ size_t off = (c + C_off) * SP;
+ data_t sum = 0;
+ for (dim_t n = N_s; n < N_e; ++n)
+ PRAGMA_OMP_SIMD(reduction(+ : sum))
+ for (dim_t sp = S_s; sp < S_e; ++sp) {
+ sum += src[off + n * C * SP + sp];
+ }
+ ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
+ = sum;
+ }
+
+ if (SP_N_nthr > 1) mkldnn_thr_barrier();
+
+ for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
+ mean_blk[c] = 0.;
+ for (dim_t n = 0; n < SP_N_nthr; n++)
+ mean_blk[c] += ws_reduce[ws_iter_off
+ + n * C_blks_per_iter + c];
+ mean_blk[c] /= (N * SP);
+ }
+
+ if (SP_N_nthr > 1) mkldnn_thr_barrier();
+
+ for (dim_t c = C_blk_s; c < C_blk_e; c++) {
+ size_t off = c + C_off;
+ data_t sum = 0.;
+ for (dim_t n = N_s; n < N_e; ++n)
+ PRAGMA_OMP_SIMD(reduction(+ : sum))
+ for (dim_t sp = S_s; sp < S_e; ++sp) {
+ data_t m = src[off * SP + n * C * SP + sp]
+ - mean[off];
+ sum += m * m;
+ }
+ ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
+ = sum;
+ }
+
+ if (SP_N_nthr > 1) mkldnn_thr_barrier();
+
+ for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
+ variance_blk[c] = 0.;
+ for (dim_t n = 0; n < SP_N_nthr; n++)
+ variance_blk[c] += ws_reduce[ws_iter_off
+ + n * C_blks_per_iter + c];
+ variance_blk[c] /= (N * SP);
+ }
+
+ if (SP_N_nthr > 1) mkldnn_thr_barrier();
+ }
+
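+            // Normalization pass: per channel, sm = gamma / sqrt(variance + eps)
+            // and sv = beta (or 1 and 0 without scaleshift), so the spatial
+            // loop reduces to dst = sm * (src - mean) + sv plus the optional
+            // fused ReLU.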
+ for (dim_t c = C_blk_s; c < C_blk_e; c++) {
+ size_t off = c + C_off;
+ data_t sqrt_variance
+ = static_cast<data_t>(sqrtf(variance[off] + eps));
+ data_t sm = (use_scaleshift ? scaleshift[off] : 1.0f) / sqrt_variance;
+ data_t sv = use_scaleshift ? scaleshift[C + off] : 0;
+ for (dim_t n = N_s; n < N_e; ++n)
+#if SAFE_TO_USE_OMP_SIMD
+ PRAGMA_OMP_SIMD()
+#endif
+ for (dim_t sp = S_s; sp < S_e; ++sp) {
+ size_t d_off = off * SP + n * C * SP + sp;
+ data_t bn_res
+ = sm * (src[d_off] - mean[off]) + sv;
+ if (fuse_bn_relu) {
+ if (bn_res <= 0) {
+ bn_res = 0;
+ if (is_training)
+ ws[d_off] = 0;
+ } else {
+ if (is_training)
+ ws[d_off] = 1;
+ }
+ }
+ dst[d_off] = maybe_post_op(bn_res);
+ }
+ }
+ }
+ });
+}
+
+void ncsp_batch_normalization_bwd_t::execute_backward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
+ auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+ auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+ auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
+
+ auto scratchpad = this->scratchpad(ctx);
+ auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
+
+ if (diff_scaleshift == nullptr)
+ diff_scaleshift = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
+
+ const bool has_spatial = utils::one_of(pd()->ndims(), 4, 5);
+ dim_t SP = (has_spatial) ? pd()->H() * pd()->W() * pd()->D() : 1;
+ dim_t C = pd()->C(), N = pd()->MB();
+ const bool use_scaleshift = pd()->use_scaleshift();
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool calculate_diff_stats = !pd()->use_global_stats();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+
+ int nthr = mkldnn_get_max_threads();
+ size_t l3_size_ = get_cache_size(3, true) * nthr / 2;
+ size_t data_size = N * C * SP * sizeof(data_t);
+ bool do_blocking = (data_size >= l3_size_ / 2 && l3_size_ > 0);
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ int C_ithr = 0, C_nthr = 0;
+ int N_ithr = 0, N_nthr = 0;
+ int S_ithr = 0, S_nthr = 0;
+
+ dim_t C_blk_gl_s = 0, C_blk_gl_e = 0, C_blk_s = 0, C_blk_e = 0;
+ dim_t N_s = 0, N_e = 0;
+ dim_t S_s = 0, S_e = 0;
+
+ dim_t C_blks_per_iter = 1;
+ int64_t iters = 1;
+
+ if (do_blocking) {
+ size_t working_set_size = 2 * N * SP * sizeof(data_t);
+ bnorm_utils::cache_balance(
+ working_set_size, C, C_blks_per_iter, iters);
+ } else
+ C_blks_per_iter = C;
+ int64_t last_iter_blks = C - (iters - 1) * C_blks_per_iter;
+ bool spatial_thr_allowed
+ = bnorm_utils::thread_balance(do_blocking, true, ithr, nthr, N,
+ C_blks_per_iter, SP, C_ithr, C_nthr, C_blk_s, C_blk_e,
+ N_ithr, N_nthr, N_s, N_e, S_ithr, S_nthr, S_s, S_e);
+ balance211(C_blks_per_iter, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
+ int SP_N_ithr = N_ithr * S_nthr + S_ithr;
+ int SP_N_nthr = N_nthr * S_nthr;
+
+ for (int64_t it = 0; it < iters; ++it) {
+ if (it == iters - 1 && iters > 1) {
+ // On the last iteration the access pattern to ws_reduce
+ // might change (due to re-balance on C). So sync the
+ // threads if they are not synced by the algorithm.
+ if (SP_N_nthr == 1 && mkldnn_thr_syncable())
+ mkldnn_thr_barrier();
+
+ C_blk_s = C_blk_e = N_s = N_e = 0;
+ spatial_thr_allowed = bnorm_utils::thread_balance(do_blocking,
+ spatial_thr_allowed, ithr, nthr, N, last_iter_blks, SP,
+ C_ithr, C_nthr, C_blk_s, C_blk_e, N_ithr, N_nthr, N_s,
+ N_e, S_ithr, S_nthr, S_s, S_e);
+ balance211(last_iter_blks, nthr, ithr, C_blk_gl_s, C_blk_gl_e);
+ SP_N_ithr = N_ithr * S_nthr + S_ithr;
+ SP_N_nthr = N_nthr * S_nthr;
+ }
+ size_t C_off = it * C_blks_per_iter;
+ // On the last iteration the access pattern to ws_reduce
+ // might change (due to re-balance on C). Since sync is not always
+ // possible (in case of TBB) use different parts of ws for each
+ // iteration if threads are not synced by the algorithm.
+ size_t ws_iter_off = (mkldnn_thr_syncable() ? 0 : 1) * 2 * C_off;
+
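+            // First reduction pass: accumulate the per-channel
+            // diff_gamma = sum((src - mean) * diff_dst) and
+            // diff_beta = sum(diff_dst) (ReLU-masked when fused) into
+            // per-thread slots of ws_reduce.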
+ data_t *diff_gamma_blk = diff_scaleshift + C_off;
+ data_t *diff_beta_blk = diff_scaleshift + C + C_off;
+ for (dim_t c = C_blk_s; c < C_blk_e; c++) {
+ size_t off = c + C_off;
+ data_t diff_gamma = 0.0, diff_beta = 0.0;
+ data_t v_mean = mean[off];
+ for (dim_t n = N_s; n < N_e; ++n)
+ PRAGMA_OMP_SIMD(reduction(+ : diff_gamma, diff_beta))
+ for (dim_t sp = S_s; sp < S_e; ++sp) {
+ const size_t d_off = off * SP + n * C * SP + sp;
+ data_t dd;
+ if (fuse_bn_relu)
+ dd = (!ws[d_off]) ? 0 : diff_dst[d_off];
+ else
+ dd = diff_dst[d_off];
+ diff_gamma += (src[d_off] - v_mean) * dd;
+ diff_beta += dd;
+ }
+ ws_reduce[ws_iter_off + SP_N_ithr * C_blks_per_iter + c]
+ = diff_gamma;
+ ws_reduce[ws_iter_off + SP_N_nthr * C_blks_per_iter
+ + SP_N_ithr * C_blks_per_iter + c] = diff_beta;
+ }
+
+ if (SP_N_nthr > 1) mkldnn_thr_barrier();
+
+ for (dim_t c = C_blk_gl_s; c < C_blk_gl_e; c++) {
+ data_t sqrt_variance = static_cast<data_t>(
+ 1.0f / sqrtf(variance[c + C_off] + eps));
+ diff_gamma_blk[c] = 0.;
+ diff_beta_blk[c] = 0.;
+ for (dim_t n = 0; n < SP_N_nthr; n++) {
+ diff_gamma_blk[c] += ws_reduce[ws_iter_off
+ + n * C_blks_per_iter + c];
+ diff_beta_blk[c] += ws_reduce[ws_iter_off
+ + SP_N_nthr * C_blks_per_iter + n * C_blks_per_iter
+ + c];
+ }
+ diff_gamma_blk[c] *= sqrt_variance;
+ }
+
+ if (SP_N_nthr > 1) mkldnn_thr_barrier();
+
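+            // Second pass: apply the standard batch-norm backward formula;
+            // the mean/variance gradient terms are skipped when global stats
+            // are used (calculate_diff_stats == false).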
+ for (dim_t c = C_blk_s; c < C_blk_e; c++) {
+ size_t off = c + C_off;
+ data_t gamma = use_scaleshift ? scaleshift[off] : 1;
+ data_t sqrt_variance
+ = static_cast<data_t>(1.0f / sqrtf(variance[off] + eps));
+ data_t v_mean = mean[off];
+ for (dim_t n = N_s; n < N_e; ++n)
+#if SAFE_TO_USE_OMP_SIMD
+ PRAGMA_OMP_SIMD()
+#endif
+ for (dim_t sp = S_s; sp < S_e; ++sp) {
+ const size_t d_off = off * SP + n * C * SP + sp;
+
+ data_t v_diff_src;
+ if (fuse_bn_relu)
+ v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off];
+ else
+ v_diff_src = diff_dst[d_off];
+ if (calculate_diff_stats) {
+ v_diff_src -= diff_beta_blk[c] / (SP * N)
+ + (src[d_off] - v_mean) * diff_gamma_blk[c]
+ * sqrt_variance / (SP * N);
+ }
+ v_diff_src *= gamma * sqrt_variance;
+ diff_src[d_off] = v_diff_src;
+ }
+ }
+ }
+ });
+}
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp
new file mode 100644
index 0000000000..97ca3b003f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ncsp_batch_normalization.hpp
@@ -0,0 +1,160 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_NCSP_BATCH_NORMALIZATION_HPP
+#define CPU_NCSP_BATCH_NORMALIZATION_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_batch_normalization_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct ncsp_batch_normalization_fwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_batch_normalization_fwd_pd_t {
+ using cpu_batch_normalization_fwd_pd_t::cpu_batch_normalization_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_fwd_t);
+
+ status_t init() {
+ using namespace data_type;
+ using namespace prop_kind;
+ using namespace format_tag;
+
+ bool ok = true
+ && is_fwd()
+ && !has_zero_dim_memory()
+ && src_md()->data_type == f32
+ && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32)
+ && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc)
+ && (attr()->has_default_values() || this->with_relu_post_op());
+ if (!ok) return status::unimplemented;
+
+ if (is_training() && fuse_bn_relu()) init_default_ws(8);
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ private:
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ if (!stats_is_src()) {
+ scratchpad.book(key_bnorm_reduction,
+ sizeof(data_t) * C() * mkldnn_get_max_threads());
+
+ if (!is_training()) {
+ scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * C());
+ scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * C());
+ }
+ }
+ }
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ ncsp_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ ~ncsp_batch_normalization_fwd_t() {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+struct ncsp_batch_normalization_bwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_batch_normalization_bwd_pd_t {
+ using cpu_batch_normalization_bwd_pd_t::cpu_batch_normalization_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ncsp_bnorm:any", ncsp_batch_normalization_bwd_t);
+
+ status_t init() {
+ using namespace data_type;
+ using namespace format_tag;
+
+ bool ok = true
+ && is_bwd()
+ && !has_zero_dim_memory()
+ && utils::everyone_is(f32, src_md()->data_type,
+ diff_src_md()->data_type)
+ && IMPLICATION(use_scaleshift(),
+ utils::everyone_is(f32,
+ weights_md()->data_type,
+ diff_weights_md()->data_type))
+ && memory_desc_matches_one_of_tag(*src_md(), ncdhw, nchw, nc)
+ && memory_desc_matches_one_of_tag(*diff_src_md(), ncdhw, nchw, nc)
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ if (fuse_bn_relu()) {
+ init_default_ws(8);
+ if (!compare_ws(hint_fwd_pd_))
+ return status::unimplemented;
+ }
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ private:
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ scratchpad.book(key_bnorm_reduction,
+ sizeof(data_t) * 2 * C() * mkldnn_get_max_threads());
+ if (!(use_scaleshift() && desc()->prop_kind == prop_kind::backward))
+ scratchpad.book(key_bnorm_tmp_diff_ss,
+ sizeof(data_t) * 2 * C());
+ }
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ ncsp_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ ~ncsp_batch_normalization_bwd_t() {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp
new file mode 100644
index 0000000000..38cfb28dce
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.cpp
@@ -0,0 +1,392 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+
+#include "nhwc_pooling.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+#define MEM_D(name) name##_d
+
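+// Helper macros: MEM_D(x) names the memory descriptor wrapper of x, and
+// DECLARE_READ_STRIDES reads its N/D/H/W strides; for 4D (2D spatial)
+// tensors the depth stride is unused and set to zero.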
+#define DECLARE_READ_STRIDES(name) \
+ const size_t name##_n_stride = MEM_D(name).blocking_desc().strides[0]; \
+ const size_t name##_d_stride = (!is_3d) \
+ ? 0 \
+ : MEM_D(name).blocking_desc().strides[2]; \
+ const size_t name##_h_stride = (!is_3d) \
+ ? MEM_D(name).blocking_desc().strides[2] \
+ : MEM_D(name).blocking_desc().strides[3]; \
+ const size_t name##_w_stride = (!is_3d) \
+ ? MEM_D(name).blocking_desc().strides[3] \
+ : MEM_D(name).blocking_desc().strides[4];
+
+namespace nhwc_pooling {
+ size_t strided_offset(const int _n, const size_t _sn,
+ const int _d, const size_t _sd,
+ const int _h, const size_t _sh,
+ const int _w, const size_t _sw)
+ {
+ return _n * _sn
+ + _d * _sd
+ + _h * _sh
+ + _w * _sw;
+ }
+}
+
+template <impl::data_type_t data_type>
+void nhwc_pooling_fwd_t<data_type>::array_div_by_const(const int n,
+ const data_t *src, const size_t num, data_t *dst) const
+{
+ for (int i = 0; i < n; ++i)
+ {
+ float ftmp = (float)src[i];
+ ftmp = ftmp / num;
+ dst[i] = math::out_round<data_t>(ftmp);
+ }
+}
+
+template <impl::data_type_t data_type>
+void nhwc_pooling_fwd_t<data_type>::array_add(const int n, const data_t *src,
+ data_t *dst) const
+{
+ for (int i = 0; i < n; ++i)
+ {
+ dst[i] += src[i];
+ }
+}
+
+template <impl::data_type_t data_type>
+void nhwc_pooling_fwd_t<data_type>::execute_forward(
+ const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+ using namespace prop_kind;
+ using namespace nhwc_pooling;
+
+ auto alg = pd()->desc()->alg_kind;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE);
+
+ const memory_desc_wrapper MEM_D(src)(pd()->src_md());
+ const memory_desc_wrapper MEM_D(dst)(pd()->dst_md());
+ const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());
+
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ const int SD = pd()->KSD();
+ const int SH = pd()->KSH();
+ const int SW = pd()->KSW();
+ const int padF = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+ const int MB = pd()->MB();
+ const int OC = pd()->C();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+
+ const bool is_3d = pd()->desc()->src_desc.ndims == 5;
+ const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
+
+ DECLARE_READ_STRIDES(src);
+ DECLARE_READ_STRIDES(dst);
+
+ auto apply_offset = [=](int index, int offset) {
+ return (index > offset) ? index - offset : 0;
+ };
+
+ parallel_nd(MB, OD, OH, OW,
+ [&](int mb, int od, int oh, int ow) {
+ size_t dst_offset_init = strided_offset(mb, dst_n_stride,
+ od, dst_d_stride,
+ oh, dst_h_stride,
+ ow, dst_w_stride);
+ if (alg == pooling_max) {
+ size_t ws_offset_init = 0;
+ if (ws)
+ {
+ DECLARE_READ_STRIDES(ws);
+ ws_offset_init = strided_offset(mb, ws_n_stride,
+ od, ws_d_stride,
+ oh, ws_h_stride,
+ ow, ws_w_stride);
+ }
+            // Note: GCC 4.8.5 won't vectorize the simple loops below unless
+            // they are singled out into the separate helper routines
+            // array_nhwc_initialize and array_nhwc_max.
+ if (!ws)
+ array_nhwc_initialize<false>(OC, dst + dst_offset_init,
+ ws, ws_offset_init, ws_dt);
+ else
+ array_nhwc_initialize<true>(OC, dst + dst_offset_init,
+ ws, ws_offset_init, ws_dt);
+
+
+ for (int kd = 0; kd < KD; ++kd)
+ for (int kh = 0; kh < KH; ++kh)
+ for (int kw = 0; kw < KW; ++kw) {
+ const int id = od * SD - padF + kd;
+ const int ih = oh * SH - padT + kh;
+ const int iw = ow * SW - padL + kw;
+
+ if (id < 0 || id >= ID)
+ continue;
+ if (ih < 0 || ih >= IH)
+ continue;
+ if (iw < 0 || iw >= IW)
+ continue;
+
+ size_t src_offset_init = strided_offset(mb, src_n_stride,
+ id, src_d_stride,
+ ih, src_h_stride,
+ iw, src_w_stride);
+
+ if (!ws)
+ array_nhwc_max<false>(OC,
+ dst + dst_offset_init,
+ src + src_offset_init,
+ ws, ws_offset_init,
+ ws_dt,
+ kd * KH * KW + kh * KW + kw
+ );
+ else
+ array_nhwc_max<true>(OC,
+ dst + dst_offset_init,
+ src + src_offset_init,
+ ws, ws_offset_init,
+ ws_dt,
+ kd * KH * KW + kh * KW + kw
+ );
+ }
+ } else {
+ // pooling_avg
+ auto d = dst + dst_offset_init;
+
+ utils::array_set(d, 0, OC);
+
+ auto id_start = apply_offset(od * SD, padF);
+ auto ih_start = apply_offset(oh * SH, padT);
+ auto iw_start = apply_offset(ow * SW, padL);
+ auto id_end = nstl::min(od * SD - padF + KD, ID);
+ auto ih_end = nstl::min(oh * SH - padT + KH, IH);
+ auto iw_end = nstl::min(ow * SW - padL + KW, IW);
+
+ // it is cheaper to actually count this in a loop
+ // as the typical kernel is small
+ size_t num_summands = 0;
+
+ for (int id = id_start; id < id_end; ++id)
+ for (int ih = ih_start; ih < ih_end; ++ih)
+ for (int iw = iw_start; iw < iw_end; ++iw) {
+ size_t src_offset_init = strided_offset(mb, src_n_stride,
+ id, src_d_stride,
+ ih, src_h_stride,
+ iw, src_w_stride);
+ auto s = src + src_offset_init;
+
+                // the loop is kept in a separate function so that
+                // GCC 4.8.5 can vectorize it
+ array_add(OC, s, d);
+
+ num_summands++;
+ }
+
+ num_summands = (alg == pooling_avg_include_padding) ?
+ KW * KH * KD : num_summands;
+
+            // the loop is kept in a separate function so that
+            // GCC 4.8.5 can vectorize it
+ array_div_by_const(OC, d, num_summands, d);
+ }
+ });
+}
+
+template <impl::data_type_t data_type>
+void nhwc_pooling_bwd_t<data_type>::execute_backward(
+ const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+ using namespace nhwc_pooling;
+
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper MEM_D(diff_src)(pd()->diff_src_md());
+ const memory_desc_wrapper MEM_D(diff_dst)(pd()->diff_dst_md());
+ const memory_desc_wrapper MEM_D(ws)(pd()->workspace_md());
+
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ const int SD = pd()->KSD();
+ const int SH = pd()->KSH();
+ const int SW = pd()->KSW();
+ const int OC = pd()->C();
+ const int padF = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+
+ const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
+ auto alg = pd()->desc()->alg_kind;
+
+ DECLARE_READ_STRIDES(diff_src);
+ DECLARE_READ_STRIDES(diff_dst);
+
+ auto apply_offset = [=](int index, int offset) {
+ return (index > offset) ? index - offset : 0;
+ };
+
+ const int MB = pd()->MB();
+
+ parallel_nd(MB, ID, IH, IW,
+ [&](int mb, int id, int ih, int iw) {
+ size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
+ id, diff_src_d_stride,
+ ih, diff_src_h_stride,
+ iw, diff_src_w_stride);
+
+        // Check whether the kernel windows are disjoint. If they are, each
+        // element is written exactly once and no zero-initialization is
+        // required; otherwise initialize so updates can accumulate.
+ if (!(KD == SD && KH == SH && KW == SW))
+ for (int oc = 0; oc < OC; ++oc)
+ diff_src[src_offset_init + oc] = data_type_t(0);
+
+        // Find out which output cells may correspond to the current
+        // input position. The current input position divided by the
+        // stride, rounding down, gives the right-most such output.
+ // Left-most output may be computed if we decrement input
+ // by (kernel_size - 1) and then do the same division by
+ // stride.
+ int od_left = nstl::max((id + padF - KD + 1) / SD, 0);
+ int oh_left = nstl::max((ih + padT - KH + 1) / SH, 0);
+ int ow_left = nstl::max((iw + padL - KW + 1) / SW, 0);
+        // Note the +1 here, which preserves the "less than" end
+        // condition of the loops below.
+ int od_right = nstl::min((id + padF) / SD + 1 , OD);
+ int oh_right = nstl::min((ih + padT) / SH + 1 , OH);
+ int ow_right = nstl::min((iw + padL) / SW + 1 , OW);
+
+ for (int od = od_left; od < od_right; ++od)
+ for (int oh = oh_left; oh < oh_right; ++oh)
+ for (int ow = ow_left; ow < ow_right; ++ow) {
+ const int kd = id - od*SD + padF;
+ const int kh = ih - oh*SH + padT;
+ const int kw = iw - ow*SW + padL;
+
+ if (kd < 0 || kd >= KD)
+ continue;
+ if (kh < 0 || kh >= KH)
+ continue;
+ if (kw < 0 || kw >= KW)
+ continue;
+
+ size_t dst_offset_init = strided_offset(mb, diff_dst_n_stride,
+ od, diff_dst_d_stride,
+ oh, diff_dst_h_stride,
+ ow, diff_dst_w_stride);
+
+ if (alg == pooling_max) {
+ DECLARE_READ_STRIDES(ws);
+ size_t ws_offset_init = strided_offset(mb, ws_n_stride,
+ od, ws_d_stride,
+ oh, ws_h_stride,
+ ow, ws_w_stride);
+ const int index = kd * KH * KW + kh * KW + kw;
+
+ PRAGMA_OMP_SIMD()
+ for (int oc = 0; oc < OC; ++oc) {
+ const int index_from_ws =
+ (MEM_D(ws).data_type() == data_type::u8)
+ ? (int)ws[ws_offset_init + oc]
+ : ((int *)ws)[ws_offset_init + oc];
+
+ const data_t d = diff_dst[dst_offset_init + oc];
+
+                // Check whether the kernel windows are disjoint. If they are,
+                // no accumulation is needed and the value is written once;
+                // otherwise it is added to the current contents.
+ if (!(KD == SD && KH == SH && KW == SW))
+ diff_src[src_offset_init + oc] +=
+ (index_from_ws == index)
+ ? d
+ : data_type_t(0);
+ else
+ diff_src[src_offset_init + oc] =
+ (index_from_ws == index)
+ ? d
+ : data_type_t(0);
+ }
+ } else {
+ // pooling_avg
+ auto id_start = apply_offset(od*SD, padF);
+ auto ih_start = apply_offset(oh*SH, padT);
+ auto iw_start = apply_offset(ow*SW, padL);
+ auto id_end = nstl::min(od*SD - padF + KD, ID);
+ auto ih_end = nstl::min(oh*SH - padT + KH, IH);
+ auto iw_end = nstl::min(ow*SW - padL + KW, IW);
+
+ auto num_summands = (alg == pooling_avg_include_padding)
+ ? KW*KH*KD
+ : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
+
+ PRAGMA_OMP_SIMD()
+ for (int oc = 0; oc < OC; ++oc) {
+ const data_t d = diff_dst[dst_offset_init + oc];
+                // Check whether the kernel windows are disjoint. If they are,
+                // no accumulation is needed and the value is written once;
+                // otherwise it is added to the current contents.
+ if (!(KD == SD && KH == SH && KW == SW))
+ diff_src[src_offset_init + oc] += d / num_summands;
+ else
+ diff_src[src_offset_init + oc] = d / num_summands;
+ }
+ }
+ }
+ });
+}
+
+template struct nhwc_pooling_fwd_t<data_type::f32>;
+template struct nhwc_pooling_bwd_t<data_type::f32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp
new file mode 100644
index 0000000000..7e33b6869f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/nhwc_pooling.hpp
@@ -0,0 +1,210 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_NHWC_POOLING_HPP
+#define CPU_NHWC_POOLING_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_pooling_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace nhwc_pooling {
+size_t strided_offset(const int _n, const size_t _sn, const int _d,
+ const size_t _sd, const int _h, const size_t _sh, const int _w,
+ const size_t _sw);
+}
+
+template <impl::data_type_t data_type>
+struct nhwc_pooling_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_pooling_fwd_pd_t {
+ using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T("nhwc_pooling:any", nhwc_pooling_fwd_t);
+
+ status_t init() {
+ const format_tag_t desired_fmt_tag =
+ ndims() == 4 ? format_tag::nhwc : format_tag::ndhwc;
+
+ bool ok = true
+ && set_default_params() == status::success
+ && is_fwd()
+ && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
+ alg_kind::pooling_avg_include_padding,
+ alg_kind::pooling_avg_exclude_padding)
+ && utils::everyone_is(data_type,
+ src_md()->data_type,
+ dst_md()->data_type)
+ && attr()->has_default_values()
+ && memory_desc_matches_tag(*src_md(), desired_fmt_tag)
+ && memory_desc_matches_tag(*dst_md(), desired_fmt_tag);
+ if (!ok) return status::unimplemented;
+
+ bool is_training = desc_.prop_kind == prop_kind::forward_training;
+ if (desc()->alg_kind == alg_kind::pooling_max && is_training)
+ init_default_ws();
+
+ return status::success;
+ }
+ };
+
+ nhwc_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ void array_div_by_const(const int n, const data_t *src, const size_t num,
+ data_t *dst) const;
+ void array_add(const int n, const data_t *src, data_t *dst) const;
+
+ template <bool use_workspace>
+ void array_nhwc_max(const int n, data_t *dst, const data_t *src,
+ unsigned char *ws, const size_t ws_offset, const data_type_t ws_dt,
+ const int index) const {
+        assert(!((use_workspace == false) ^ (!ws))); // ws must be provided iff use_workspace is set
+ PRAGMA_OMP_SIMD()
+ for (int oc = 0; oc < n; ++oc) {
+ auto s = src[oc];
+ data_t mv = dst[oc];
+
+ // update index of maximum
+#if defined __INTEL_COMPILER
+ if ((use_workspace) && (s > mv)) {
+ assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
+ if (ws_dt == data_type::u8) {
+ assert(0 <= index && index <= 255);
+ ws[ws_offset + oc] = index;
+ } else
+ reinterpret_cast<int *>(ws)[ws_offset + oc] = index;
+ }
+#else
+            // Explicit predicates are needed for GCC to vectorize this.
+            // Although the resulting code is ugly, it is still about four
+            // times faster than the scalar version.
+ if (use_workspace) {
+ assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
+
+ if (ws_dt == data_type::u8) {
+ assert(0 <= index && index <= 255);
+ unsigned char predicate = (s > mv) ? 0xff : 0;
+ unsigned char current_value = ws[ws_offset + oc];
+ current_value = (predicate & (unsigned char)index)
+ | ((~predicate) & current_value);
+ ws[ws_offset + oc] = current_value;
+ } else {
+ auto wint = reinterpret_cast<int *>(ws);
+ unsigned int predicate = (s > mv) ? 0xffffffff : 0;
+ unsigned int current_value = wint[ws_offset + oc];
+ current_value = (predicate & (unsigned int)index)
+ | ((~predicate) & current_value);
+ wint[ws_offset + oc] = current_value;
+ }
+ }
+#endif
+ // update maximum
+ dst[oc] = nstl::max(s, mv);
+ }
+ }
+
+ template <bool use_workspace>
+ void array_nhwc_initialize(const int n, data_t *dst, unsigned char *ws,
+ const size_t ws_offset, const data_type_t ws_dt) const {
+        assert(!((use_workspace == false) ^ (!ws))); // ws must be provided iff use_workspace is set
+ for (int oc = 0; oc < n; ++oc) {
+ if (use_workspace) {
+ assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
+ if (ws_dt == data_type::u8) {
+ ws[ws_offset + oc] = 0;
+ } else
+ reinterpret_cast<int *>(ws)[ws_offset + oc] = 0;
+ }
+ dst[oc] = nstl::numeric_limits<data_t>::lowest();
+ }
+ }
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t data_type>
+struct nhwc_pooling_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_pooling_bwd_pd_t {
+ using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T("nhwc:any", nhwc_pooling_bwd_t);
+
+ status_t init() {
+ const format_tag_t desired_fmt_tag =
+                ndims() == 4 ? format_tag::nhwc : format_tag::ndhwc;
+
+ bool ok = true
+ && set_default_params() == status::success
+ && !is_fwd()
+ && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
+ alg_kind::pooling_avg_include_padding,
+ alg_kind::pooling_avg_exclude_padding)
+ && utils::everyone_is(data_type,
+ diff_dst_md()->data_type,
+ diff_src_md()->data_type)
+ && attr()->has_default_values()
+ && memory_desc_matches_tag(*diff_dst_md(), desired_fmt_tag)
+ && memory_desc_matches_tag(*diff_src_md(), desired_fmt_tag);
+ if (!ok) return status::unimplemented;
+
+ if (desc()->alg_kind == alg_kind::pooling_max) {
+ init_default_ws();
+ if (!compare_ws(hint_fwd_pd_))
+ return status::unimplemented;
+ }
+
+ return status::success;
+ }
+ };
+
+ nhwc_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}// namespace cpu
+}// namespace impl
+}// namespace mkldnn
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
new file mode 100644
index 0000000000..e20333e66f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.cpp
@@ -0,0 +1,288 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+
+#include "cpu_batch_normalization_utils.hpp"
+#include "jit_generator.hpp"
+
+#include "nspc_batch_normalization.hpp"
+
+// clang 6 and 7 generate incorrect code with OMP_SIMD in some particular
+// cases, so disable it for clang 6 and newer to be safe
+#if (defined __clang_major__) && (__clang_major__ >= 6)
+#define SAFE_TO_USE_OMP_SIMD 0
+#else
+#define SAFE_TO_USE_OMP_SIMD 1
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace memory_tracking::names;
+
+void nspc_batch_normalization_fwd_t::execute_forward(
+ const exec_ctx_t &ctx) const {
+ const bool save_stats = pd()->is_training();
+ const bool is_training = pd()->is_training();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+ const bool calculate_stats = !pd()->stats_is_src();
+ const bool with_relu = pd()->with_relu_post_op();
+
+ auto scratchpad = this->scratchpad(ctx);
+ auto tmp_mean = scratchpad.get<data_t>(key_bnorm_tmp_mean);
+ auto tmp_var = scratchpad.get<data_t>(key_bnorm_tmp_var);
+ auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+
+ data_t *mean, *variance;
+ if (!calculate_stats) {
+ mean = const_cast<data_t *>(
+ CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN));
+ variance = const_cast<data_t *>(
+ CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE));
+ } else {
+ if (save_stats) {
+ mean = CTX_OUT_MEM(data_t *, MKLDNN_ARG_MEAN);
+ variance = CTX_OUT_MEM(data_t *, MKLDNN_ARG_VARIANCE);
+ } else {
+ mean = tmp_mean;
+ variance = tmp_var;
+ }
+ }
+
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ const dim_t N = pd()->MB();
+ const dim_t C = pd()->C();
+ const dim_t SP = pd()->H() * pd()->W() * pd()->D();
+
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();
+ auto maybe_post_op
+ = [&](data_t res) { return (with_relu && res < 0) ? 0 : res; };
+
+ assert(mkldnn_thr_syncable());
+ parallel(0, [&](const int ithr, const int nthr) {
+ dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0;
+ balance211(N, nthr, ithr, N_s, N_e);
+ balance211(C, nthr, ithr, C_s, C_e);
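+        // Each thread uses its own slice of the temporary statistics
+        // buffers; the per-thread stride of at least 16 floats keeps the
+        // slices on separate cache lines.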
+ data_t *mean_loc = tmp_mean + nstl::max(C, (dim_t)16) * ithr;
+ data_t *variance_loc = tmp_var + nstl::max(C, (dim_t)16) * ithr;
+
+ if (calculate_stats) {
+ for (dim_t c = 0; c < C; c++)
+ ws_reduce[C * ithr + c] = 0.;
+
+ for (dim_t n = N_s; n < N_e; n++)
+ for (dim_t sp = 0; sp < SP; sp++)
+ PRAGMA_OMP_SIMD()
+ for (dim_t c = 0; c < C; c++)
+ ws_reduce[C * ithr + c] += src[(size_t)n * SP * C
+ + sp * C + c];
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = C_s; c < C_e; c++) {
+ mean[c] = 0;
+ for (dim_t n = 0; n < nthr; n++)
+ mean[c] += ws_reduce[C * n + c];
+ mean[c] /= SP * N;
+ }
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = 0; c < C; c++) {
+ mean_loc[c] = mean[c];
+ ws_reduce[C * ithr + c] = 0.;
+ }
+
+ for (dim_t n = N_s; n < N_e; n++)
+ for (dim_t sp = 0; sp < SP; sp++)
+ PRAGMA_OMP_SIMD()
+ for (dim_t c = 0; c < C; c++) {
+ data_t m = src[(size_t)n * SP * C + sp * C + c]
+ - mean_loc[c];
+ ws_reduce[C * ithr + c] += m * m;
+ }
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = C_s; c < C_e; c++) {
+ variance[c] = 0;
+ for (dim_t n = 0; n < nthr; n++)
+ variance[c] += ws_reduce[C * n + c];
+ variance[c] /= SP * N;
+ }
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = 0; c < C; c++)
+ variance_loc[c] = variance[c];
+ } else {
+ variance_loc = variance;
+ mean_loc = mean;
+ }
+
+ for (dim_t n = N_s; n < N_e; n++) {
+ for (dim_t sp = 0; sp < SP; sp++) {
+#if SAFE_TO_USE_OMP_SIMD
+ PRAGMA_OMP_SIMD()
+#endif
+ for (dim_t c = 0; c < C; c++) {
+ data_t sqrt_variance = static_cast<data_t>(
+ sqrtf(variance_loc[c] + eps));
+ data_t sm = (use_scaleshift ? scaleshift[c] : 1.0f) / sqrt_variance;
+ data_t sv = use_scaleshift ? scaleshift[C + c] : 0;
+ size_t d_off = (size_t)n * SP * C + sp * C + c;
+ data_t bn_res = sm * (src[d_off] - mean_loc[c]) + sv;
+ if (fuse_bn_relu) {
+ if (bn_res <= 0) {
+ bn_res = 0;
+ if (is_training)
+ ws[d_off] = 0;
+ } else {
+ if (is_training)
+ ws[d_off] = 1;
+ }
+ }
+ dst[d_off] = maybe_post_op(bn_res);
+ }
+ }
+ }
+ });
+}
+
+void nspc_batch_normalization_bwd_t::execute_backward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
+ auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+ auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+ auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
+
+ auto scratchpad = this->scratchpad(ctx);
+ auto tmp_diff_ss = scratchpad.get<data_t>(key_bnorm_tmp_diff_ss);
+
+ if (diff_scaleshift == nullptr)
+ diff_scaleshift = tmp_diff_ss;
+
+ const dim_t N = pd()->MB();
+ const dim_t C = pd()->C();
+ const dim_t SP = pd()->D() * pd()->H() * pd()->W();
+ data_t *diff_gamma = diff_scaleshift, *diff_beta = diff_scaleshift + C;
+ auto *ws_reduce = scratchpad.get<data_t>(key_bnorm_reduction);
+
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();
+ const bool calculate_diff_stats = !pd()->use_global_stats();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+
+ assert(mkldnn_thr_syncable());
+ parallel(0, [&](const int ithr, const int nthr) {
+ dim_t N_s = 0, N_e = 0, C_s = 0, C_e = 0;
+ balance211(N, nthr, ithr, N_s, N_e);
+ balance211(C, nthr, ithr, C_s, C_e);
+
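+        // The first 2*C entries of tmp_diff_ss can back diff_scaleshift;
+        // per-thread copies of diff_gamma and diff_beta are laid out after
+        // them, one C-sized slot per thread for each.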
+ data_t *diff_gamma_loc = tmp_diff_ss + 2 * C + C * ithr;
+ data_t *diff_beta_loc = tmp_diff_ss + 2 * C + C * (nthr + ithr);
+
+ for (dim_t c = 0; c < C; c++) {
+ ws_reduce[C * ithr + c] = 0.;
+ ws_reduce[C * nthr + C * ithr + c] = 0.;
+ }
+
+ for (dim_t n = N_s; n < N_e; n++)
+ for (dim_t sp = 0; sp < SP; sp++)
+#if SAFE_TO_USE_OMP_SIMD
+ PRAGMA_OMP_SIMD()
+#endif
+ for (dim_t c = 0; c < C; c++) {
+ const size_t d_off = (size_t)n * SP * C + sp * C + c;
+ data_t dd;
+ if (fuse_bn_relu)
+ dd = (!ws[d_off]) ? 0 : diff_dst[d_off];
+ else
+ dd = diff_dst[d_off];
+ ws_reduce[C * ithr + c] += (src[d_off] - mean[c]) * dd;
+ ws_reduce[C * nthr + C * ithr + c] += dd;
+ }
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = C_s; c < C_e; c++) {
+ data_t sqrt_variance
+ = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps));
+ diff_gamma[c] = 0;
+ diff_beta[c] = 0;
+ for (dim_t n = 0; n < nthr; n++) {
+ diff_gamma[c] += ws_reduce[C * n + c];
+ diff_beta[c] += ws_reduce[C * nthr + C * n + c];
+ }
+ diff_gamma[c] *= sqrt_variance;
+ }
+
+ mkldnn_thr_barrier();
+
+ for (dim_t c = 0; c < C; c++) {
+ diff_gamma_loc[c] = diff_gamma[c];
+ diff_beta_loc[c] = diff_beta[c];
+ }
+
+ for (dim_t n = N_s; n < N_e; n++) {
+ for (dim_t sp = 0; sp < SP; sp++) {
+#if SAFE_TO_USE_OMP_SIMD
+ PRAGMA_OMP_SIMD()
+#endif
+ for (dim_t c = 0; c < C; c++) {
+ const size_t d_off = (size_t)n * SP * C + sp * C + c;
+ data_t gamma = use_scaleshift ? scaleshift[c] : 1;
+ data_t sqrt_variance
+ = static_cast<data_t>(1.0f / sqrtf(variance[c] + eps));
+ data_t v_diff_src;
+ if (fuse_bn_relu)
+ v_diff_src = (!ws[d_off]) ? 0 : diff_dst[d_off];
+ else
+ v_diff_src = diff_dst[d_off];
+ if (calculate_diff_stats) {
+ v_diff_src -= diff_beta_loc[c] / (SP * N)
+ + (src[d_off] - mean[c]) * diff_gamma_loc[c]
+ * sqrt_variance / (SP * N);
+ }
+ v_diff_src *= gamma * sqrt_variance;
+ diff_src[d_off] = v_diff_src;
+ }
+ }
+ }
+ });
+}
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp
new file mode 100644
index 0000000000..aad86b05a7
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/nspc_batch_normalization.hpp
@@ -0,0 +1,169 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_NSPC_BATCH_NORMALIZATION_HPP
+#define CPU_NSPC_BATCH_NORMALIZATION_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_batch_normalization_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct nspc_batch_normalization_fwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_batch_normalization_fwd_pd_t {
+ pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const batch_normalization_fwd_pd_t *hint_fwd_pd)
+ : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_fwd_t);
+
+ status_t init() {
+ using namespace data_type;
+ using namespace prop_kind;
+
+ bool ok = true
+ /* the algorithm requires barriers while switching
+ * between parallelization over N and C dimensions */
+ && mkldnn_thr_syncable()
+ && is_fwd()
+ && !has_zero_dim_memory()
+ && src_md()->data_type == f32
+ && IMPLICATION(use_scaleshift(), weights_md()->data_type == f32)
+ && memory_desc_matches_tag(*src_md(), format_tag::nhwc)
+ && (attr()->has_default_values() || this->with_relu_post_op());
+ if (!ok) return status::unimplemented;
+
+ if (is_training() && fuse_bn_relu()) init_default_ws(8);
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ private:
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ if (!stats_is_src()) {
+ dim_t sz = nstl::max<dim_t>(C(), 16) * mkldnn_get_max_threads();
+ scratchpad.book(key_bnorm_reduction, sizeof(data_t) * sz);
+ scratchpad.book(key_bnorm_tmp_mean, sizeof(data_t) * sz);
+ scratchpad.book(key_bnorm_tmp_var, sizeof(data_t) * sz);
+ }
+ }
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ nspc_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ ~nspc_batch_normalization_fwd_t() {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+struct nspc_batch_normalization_bwd_t : public cpu_primitive_t {
+ struct pd_t : public cpu_batch_normalization_bwd_pd_t {
+ pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const batch_normalization_fwd_pd_t *hint_fwd_pd)
+ : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ DECLARE_COMMON_PD_T("nspc_bnorm:any", nspc_batch_normalization_bwd_t);
+
+ status_t init() {
+ using namespace data_type;
+ using namespace prop_kind;
+
+ bool ok = true
+ /* the algorithm requires barriers while switching
+ * between parallelization over N and C dimensions */
+ && mkldnn_thr_syncable()
+ && is_bwd()
+ && !has_zero_dim_memory()
+ && utils::everyone_is(f32, src_md()->data_type,
+ diff_src_md()->data_type)
+ && IMPLICATION(use_scaleshift(),
+ utils::everyone_is(f32,
+ weights_md()->data_type,
+ diff_weights_md()->data_type))
+ && memory_desc_matches_tag(*src_md(), format_tag::nhwc)
+ && memory_desc_matches_tag(*diff_src_md(), format_tag::nhwc)
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ if (fuse_bn_relu()) {
+ init_default_ws(8);
+ if (!compare_ws(hint_fwd_pd_))
+ return status::unimplemented;
+ }
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ private:
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ scratchpad.book(key_bnorm_reduction,
+ sizeof(data_t) * 2 * C() * mkldnn_get_max_threads());
+ scratchpad.book(key_bnorm_tmp_diff_ss, sizeof(data_t) * 2 * C()
+ * (mkldnn_get_max_threads() + 1));
+ }
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ nspc_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ ~nspc_batch_normalization_bwd_t() {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp
new file mode 100644
index 0000000000..d79b1a034b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.cpp
@@ -0,0 +1,265 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+#include "simple_q10n.hpp"
+
+#include "ref_batch_normalization.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type>
+void ref_batch_normalization_fwd_t<data_type>::execute_forward(
+ const exec_ctx_t &ctx) const {
+ /* fast return */
+ if (this->pd()->has_zero_dim_memory()) return;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto scaleshift = CTX_IN_MEM(const float *, MKLDNN_ARG_SCALE_SHIFT);
+
+ auto mean = pd()->stats_is_src()
+ ? const_cast<float *>(CTX_IN_MEM(const float *, MKLDNN_ARG_MEAN))
+ : CTX_OUT_MEM(float *, MKLDNN_ARG_MEAN);
+ auto variance = pd()->stats_is_src()
+ ? const_cast<float *>(CTX_IN_MEM(const float *, MKLDNN_ARG_VARIANCE))
+ : CTX_OUT_MEM(float *, MKLDNN_ARG_VARIANCE);
+
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+ const memory_desc_wrapper scaleshift_d(pd()->weights_md());
+
+ const dim_t N = pd()->MB();
+ const dim_t C = pd()->C();
+ dim_t H = 1, W = 1, D = 1;
+ const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5);
+ if (has_spatial) {
+ D = pd()->D();
+ H = pd()->H();
+ W = pd()->W();
+ }
+
+ const float eps = pd()->desc()->batch_norm_epsilon;
+    const bool use_scaleshift = pd()->use_scaleshift();
+ const bool save_stats = pd()->is_training();
+ const bool is_training = pd()->is_training();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+ const bool calculate_stats = !pd()->stats_is_src();
+
+ const bool with_relu = pd()->with_relu_post_op();
+ auto maybe_post_op = [&](float res) {
+ return (with_relu && res < 0.0f) ? 0.0f : res;
+ };
+ const bool is_3d = data_d.ndims() == 5;
+
+ auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c,
+ dim_t d, dim_t h, dim_t w) {
+ if (has_spatial) {
+ if (is_3d)
+ return data_d.off(n, c, d, h, w);
+ else
+ return data_d.off(n, c, h, w);
+ } else
+ return data_d.off(n, c);
+ };
+
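+    /* Channels are processed independently: compute (or load) the per-channel
+     * mean and variance, then apply y = sm * (x - mean) + sv with
+     * sm = scale / sqrt(variance + eps) and sv = shift, optionally followed
+     * by a fused ReLU whose mask is stored in the workspace during training. */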
+ parallel_nd(C, [&](dim_t c) {
+ float v_mean = calculate_stats ? 0 : mean[c];
+ float v_variance = calculate_stats ? 0 : variance[c];
+
+ if (calculate_stats) {
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w)
+ v_mean += src[data_offset(data_d, n, c, d, h, w)];
+ v_mean /= W*N*H*D;
+
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w) {
+ float m = src[data_offset(data_d, n, c, d, h, w)] - v_mean;
+ v_variance += m*m;
+ }
+ v_variance /= W*H*N*D;
+ }
+
+ float sqrt_variance = sqrtf(v_variance + eps);
+ float sm = (use_scaleshift
+ ? scaleshift[scaleshift_d.off(0, c)]
+ : 1.0f) / sqrt_variance;
+ float sv = use_scaleshift ? scaleshift[scaleshift_d.off(1, c)] : 0;
+
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w) {
+ auto d_off = data_offset(data_d,n,c,d,h,w);
+ float bn_res = sm * ((float)src[d_off] - v_mean) + sv;
+ if (fuse_bn_relu) {
+ if (bn_res <= 0) {
+ bn_res = 0;
+ if (is_training)
+ ws[d_off] = 0;
+ } else {
+ if (is_training)
+ ws[d_off] = 1;
+ }
+ }
+ if (data_type == data_type::s8) {
+ dst[d_off] = qz_a1b0<float, data_t>()(maybe_post_op(bn_res));
+ } else {
+ dst[d_off] = static_cast<data_t>(maybe_post_op(bn_res));
+ }
+ }
+
+ if (calculate_stats) {
+ if (save_stats) {
+ mean[c] = v_mean;
+ variance[c] = v_variance;
+ }
+ }
+ });
+}
+
+template struct ref_batch_normalization_fwd_t<data_type::f32>;
+template struct ref_batch_normalization_fwd_t<data_type::s8>;
+
+template <impl::data_type_t data_type>
+void ref_batch_normalization_bwd_t<data_type>::execute_backward(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto mean = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MEAN);
+ auto variance = CTX_IN_MEM(const data_t *, MKLDNN_ARG_VARIANCE);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto scaleshift = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SCALE_SHIFT);
+ auto ws = CTX_IN_MEM(const uint8_t *, MKLDNN_ARG_WORKSPACE);
+
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+ auto diff_scaleshift = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SCALE_SHIFT);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+ const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
+ const memory_desc_wrapper scaleshift_d(pd()->weights_md());
+ const memory_desc_wrapper diff_scaleshift_d(pd()->diff_weights_md());
+
+ const dim_t C = pd()->C();
+
+ /* fast return */
+ if (this->pd()->has_zero_dim_memory()) {
+ if (diff_scaleshift) {
+ for (dim_t c = 0; c < C; ++c) {
+ diff_scaleshift[diff_scaleshift_d.off(0, c)] = 0;
+ diff_scaleshift[diff_scaleshift_d.off(1, c)] = 0;
+ }
+ }
+ return;
+ }
+
+ const dim_t N = pd()->MB();
+ dim_t H = 1, W = 1, D = 1;
+ const bool has_spatial = utils::one_of(data_d.ndims(), 4, 5);
+ if (has_spatial) {
+ D = pd()->D();
+ H = pd()->H();
+ W = pd()->W();
+ }
+
+ const float eps = pd()->desc()->batch_norm_epsilon;
+ const bool use_scaleshift = pd()->use_scaleshift();
+ const bool calculate_diff_stats = !pd()->use_global_stats();
+ const bool fuse_bn_relu = pd()->fuse_bn_relu();
+
+ const bool is_3d = data_d.ndims() == 5;
+
+ auto data_offset = [&](const memory_desc_wrapper &data_d, dim_t n, dim_t c,
+ dim_t d, dim_t h, dim_t w) {
+ if (has_spatial) {
+ if (is_3d)
+ return data_d.off(n, c, d, h, w);
+ else
+ return data_d.off(n, c, h, w);
+ } else
+ return data_d.off(n, c);
+ };
+
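+    /* Per channel: accumulate diff_gamma = sum((x - mean) * dd) / sqrt(var + eps)
+     * and diff_beta = sum(dd), then compute
+     * diff_src = gamma / sqrt(var + eps)
+     *          * (dd - (diff_beta + (x - mean) * diff_gamma / sqrt(var + eps)) / NDHW),
+     * where the subtracted term is dropped when global stats are used. */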
+ parallel_nd(C, [&](dim_t c) {
+ data_t v_mean = mean[c];
+ data_t v_variance = variance[c];
+ data_t sqrt_variance = static_cast<data_t>(1.0f / sqrtf(v_variance + eps));
+ data_t gamma = use_scaleshift ? scaleshift[scaleshift_d.off(0, c)] : 1;
+        data_t diff_gamma = data_t(0);
+        data_t diff_beta = data_t(0);
+
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w) {
+ const size_t s_off = data_offset(data_d, n, c, d, h, w);
+ data_t dd = diff_dst[data_offset(diff_data_d, n, c, d, h, w)];
+ if (fuse_bn_relu && !ws[s_off])
+ dd = 0;
+
+ diff_gamma += (src[s_off] - v_mean) * dd;
+ diff_beta += dd;
+ }
+ diff_gamma *= sqrt_variance;
+
+ if (diff_scaleshift) {
+ diff_scaleshift[diff_scaleshift_d.off(0, c)] = diff_gamma;
+ diff_scaleshift[diff_scaleshift_d.off(1, c)] = diff_beta;
+ }
+
+ for (dim_t n = 0; n < N; ++n)
+ for (dim_t d = 0; d < D; ++d)
+ for (dim_t h = 0; h < H; ++h)
+ for (dim_t w = 0; w < W; ++w) {
+ const size_t s_off = data_offset(data_d, n, c, d, h, w);
+ const size_t dd_off = data_offset(diff_data_d, n, c, d, h, w);
+ data_t dd = diff_dst[dd_off];
+ if (fuse_bn_relu && !ws[s_off])
+ dd = 0;
+
+ data_t v_diff_src = dd;
+ if (calculate_diff_stats) {
+ v_diff_src -= diff_beta/(D*W*H*N) +
+ (src[s_off] - v_mean) *
+ diff_gamma*sqrt_variance/(D*W*H*N);
+ }
+ v_diff_src *= gamma*sqrt_variance;
+ diff_src[dd_off] = v_diff_src;
+ }
+ });
+}
+
+template struct ref_batch_normalization_bwd_t<data_type::f32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp
new file mode 100644
index 0000000000..aa9f74125a
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_batch_normalization.hpp
@@ -0,0 +1,127 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_BATCH_NORMALIZATION_HPP
+#define CPU_REF_BATCH_NORMALIZATION_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_batch_normalization_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type>
+struct ref_batch_normalization_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_batch_normalization_fwd_pd_t {
+ pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const batch_normalization_fwd_pd_t *hint_fwd_pd)
+ : cpu_batch_normalization_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && src_md()->data_type == data_type
+ && IMPLICATION(use_scaleshift(),
+ weights_md()->data_type == data_type::f32)
+ && (attr()->has_default_values() || with_relu_post_op());
+ if (!ok) return status::unimplemented;
+
+ if (src_md()->data_type == data_type::s8 && !stats_is_src())
+ return status::unimplemented;
+
+ if (is_training() && fuse_bn_relu()) init_default_ws(8);
+
+ return status::success;
+ }
+ };
+
+ ref_batch_normalization_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t data_type>
+struct ref_batch_normalization_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_batch_normalization_bwd_pd_t {
+ pd_t(engine_t *engine, const batch_normalization_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const batch_normalization_fwd_pd_t *hint_fwd_pd)
+ : cpu_batch_normalization_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ {}
+
+ DECLARE_COMMON_PD_T("ref:any", ref_batch_normalization_bwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_bwd()
+ && utils::everyone_is(data_type, src_md()->data_type,
+ diff_src_md()->data_type)
+ && IMPLICATION(use_scaleshift(), utils::everyone_is(data_type,
+ weights_md()->data_type,
+ diff_weights_md()->data_type))
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ if (fuse_bn_relu()) {
+ init_default_ws(8);
+ if (!compare_ws(hint_fwd_pd_))
+ return status::unimplemented;
+ }
+
+ return status::success;
+ }
+ };
+
+ ref_batch_normalization_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp
new file mode 100644
index 0000000000..4c534b5508
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_concat.hpp
@@ -0,0 +1,97 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef REF_CONCAT_HPP
+#define REF_CONCAT_HPP
+
+#include "reorder_pd.hpp"
+
+#include "cpu_concat_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct ref_concat_t: public cpu_primitive_t {
+ struct pd_t: public cpu_concat_pd_t {
+ using cpu_concat_pd_t::cpu_concat_pd_t;
+
+ pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) {
+ for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i)
+ reorder_pds_.push_back(
+ (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
+ }
+ ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; }
+
+ DECLARE_CONCAT_PD_T("ref:any", ref_concat_t);
+
+ status_t init() {
+ bool ok = cpu_concat_pd_t::init() == status::success;
+ if (!ok) return status::unimplemented;
+
+ for (int i = 0; i < n_; ++i) {
+ auto r_impls = engine_->get_reorder_implementation_list();
+ for (auto r = r_impls; *r; ++r) {
+ const primitive_attr_t attr; /* alpha == 1. */
+ reorder_pd_t *r_pd = nullptr;
+ if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i),
+ engine_, src_image_md(i)) == status::success) {
+ r_pd->init_info();
+ reorder_pds_.push_back(r_pd);
+ break;
+ }
+ }
+ }
+
+ ok = reorder_pds_.size() == (size_t)n_;
+ return ok ? status::success : status::unimplemented;
+ }
+
+ nstl::vector<const reorder_pd_t *> reorder_pds_;
+ };
+
+ ref_concat_t(const pd_t *apd): cpu_primitive_t(apd) {
+ const int n = pd()->n_inputs();
+ reorders_.resize(n);
+ for (int i = 0; i < n; ++i)
+ pd()->reorder_pds_[i]->create_primitive(&reorders_[i]);
+ }
+
+ ~ref_concat_t() { for (auto &r: reorders_) delete r; }
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ const auto n = pd()->n_inputs();
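+        /* Each source is copied into its slice of the destination by the
+         * reorder primitive created for it in pd_t::init(). */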
+ for (int i = 0; i < n; ++i) {
+ exec_args_t r_args;
+ r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i);
+ r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST);
+ exec_ctx_t r_ctx(ctx.stream(), std::move(r_args));
+ reorders_[i]->execute(r_ctx);
+ }
+ return status::success;
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ nstl::vector<primitive_t *> reorders_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp
new file mode 100644
index 0000000000..c0a979c4cf
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.cpp
@@ -0,0 +1,395 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+#include "mkldnn_traits.hpp"
+#include "type_helpers.hpp"
+
+#include "ref_convolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using math::saturate;
+using math::get_bias;
+
+template <data_type_t src_type, data_type_t wei_type,
+ data_type_t dst_type, data_type_t acc_type>
+void ref_convolution_fwd_t<src_type, wei_type, dst_type, acc_type>::
+execute_forward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const bool with_groups = pd()->with_groups();
+
+ const int G = pd()->G();
+ const int MB = pd()->MB();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+
+ const int OC = pd()->OC() / G;
+ const int IC = pd()->IC() / G;
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+
+ const int KSD = pd()->KSD();
+ const int KSH = pd()->KSH();
+ const int KSW = pd()->KSW();
+
+ const int KDD = pd()->KDD();
+ const int KDH = pd()->KDH();
+ const int KDW = pd()->KDW();
+
+ const int padFront = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+
+    const bool with_relu = false; // TODO: change once post_ops are supported
+ const float nslope = 0.f;
+
+ const int ndims = pd()->desc()->src_desc.ndims;
+
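+    /* Direct convolution: for each output point accumulate
+     * src[id][ih][iw] * weights[kd][kh][kw] over the group's input channels
+     * and the kernel volume, skipping taps that fall outside the input. */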
+ auto ker = [=](int g, int mb, int oc, int od, int oh,
+ int ow) {
+ acc_data_t d = 0;
+ for (int ic = 0; ic < IC; ++ic)
+ for (int kd = 0; kd < KD; ++kd)
+ for (int kh = 0; kh < KH; ++kh)
+ for (int kw = 0; kw < KW; ++kw) {
+ const int id = od * KSD - padFront + kd * (1 + KDD);
+ const int ih = oh * KSH - padT + kh * (1 + KDH);
+ const int iw = ow * KSW - padL + kw * (1 + KDW);
+
+ if (id < 0 || id >= ID) continue;
+ if (ih < 0 || ih >= IH) continue;
+ if (iw < 0 || iw >= IW) continue;
+
+ if (ndims == 5)
+ d += (acc_data_t)src[src_d.off(mb, g*IC + ic, id, ih, iw)]
+ * (with_groups
+ ? weights[weights_d.off(g, oc, ic, kd, kh, kw)]
+ : weights[weights_d.off(oc, ic, kd, kh, kw)]);
+ else if (ndims == 4)
+ d += (acc_data_t)src[src_d.off(mb, g*IC + ic, ih, iw)]
+ * (with_groups
+ ? weights[weights_d.off(g, oc, ic, kh, kw)]
+ : weights[weights_d.off(oc, ic, kh, kw)]);
+ else if (ndims == 3)
+ d += (acc_data_t)src[src_d.off(mb, g*IC + ic, iw)]
+ * (with_groups
+ ? weights[weights_d.off(g, oc, ic, kw)]
+ : weights[weights_d.off(oc, ic, kw)]);
+ else
+ assert(false);
+
+ }
+ return d;
+ };
+
+ parallel_nd(G, MB, OC, OD, OH, OW,
+ [&](int g, int mb, int oc, int od, int oh, int ow) {
+ float a = bias
+ ? get_bias(bias, bias_d.off(g * OC + oc),
+ pd()->desc()->bias_desc.data_type)
+ : 0;
+ a += ker(g, mb, oc, od, oh, ow);
+ if (with_relu && a < 0)
+ a = a * nslope;
+ if (ndims == 5)
+ dst[dst_d.off(mb, g*OC + oc, od, oh, ow)] = saturate<dst_data_t>(a);
+ else if (ndims == 4)
+ dst[dst_d.off(mb, g*OC + oc, oh, ow)] = saturate<dst_data_t>(a);
+ else if (ndims == 3)
+ dst[dst_d.off(mb, g*OC + oc, ow)] = saturate<dst_data_t>(a);
+ else
+ assert(false);
+ });
+}
+
+template <data_type_t diff_src_type, data_type_t wei_type,
+ data_type_t diff_dst_type, data_type_t acc_type>
+void ref_convolution_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
+ acc_type>::execute_backward_data(const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const bool with_groups = pd()->with_groups();
+
+ const int G = pd()->G();
+ const int MB = pd()->MB();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+
+ const int OC = pd()->OC() / G;
+ const int IC = pd()->IC() / G;
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+
+ const int KSD = pd()->KSD();
+ const int KSH = pd()->KSH();
+ const int KSW = pd()->KSW();
+
+ const int KDD = pd()->KDD();
+ const int KDH = pd()->KDH();
+ const int KDW = pd()->KDW();
+
+ const int padFront = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+
+ const int ndims = pd()->desc()->diff_src_desc.ndims;
+
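+    /* Backward by data: for each input point visit the output points it
+     * contributed to; a tap is valid only if (i + pad - k * (1 + dilation))
+     * is non-negative, divisible by the stride, and maps inside the output. */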
+ auto ker = [=](int g, int mb, int ic, int id, int ih,
+ int iw) {
+ acc_data_t d = 0;
+ for (int oc = 0; oc < OC; ++oc)
+ for (int kd = 0; kd < KD; ++kd)
+ for (int kh = 0; kh < KH; ++kh)
+ for (int kw = 0; kw < KW; ++kw) {
+ if (iw + padL < kw * (1 + KDW)
+ || ih + padT < kh * (1 + KDH)
+ || id + padFront < kd * (1 + KDD))
+ continue;
+ int ow = iw - kw * (1 + KDW) + padL;
+ int oh = ih - kh * (1 + KDH) + padT;
+ int od = id - kd * (1 + KDD) + padFront;
+ if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0)
+ continue;
+
+ ow /= KSW;
+ oh /= KSH;
+ od /= KSD;
+
+ if (od < OD && oh < OH && ow < OW) {
+ if (ndims == 5)
+ d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
+ + oc, od, oh, ow)] * (with_groups
+ ? weights[weights_d.off(g, oc, ic, kd, kh, kw)]
+ : weights[weights_d.off(oc, ic, kd, kh, kw)]);
+ else if (ndims == 4)
+ d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
+ + oc, oh, ow)] * (with_groups
+ ? weights[weights_d.off(g, oc, ic, kh, kw)]
+ : weights[weights_d.off(oc, ic, kh, kw)]);
+ else if (ndims == 3)
+ d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC
+ + oc, ow)] * (with_groups
+ ? weights[weights_d.off(g, oc, ic, kw)]
+ : weights[weights_d.off(oc, ic, kw)]);
+ else
+ assert(false);
+ }
+ }
+ return d;
+ };
+
+ parallel_nd(G, MB, IC, ID, IH, IW,
+ [&](int g, int mb, int ic, int id, int ih, int iw) {
+ auto ds_idx = (ndims == 5)
+ ? diff_src_d.off(mb, g*IC + ic, id, ih, iw)
+ : (ndims == 4)
+ ? diff_src_d.off(mb, g*IC + ic, ih, iw)
+ : diff_src_d.off(mb, g*IC + ic, iw);
+ float a = bias
+ ? get_bias(bias, bias_d.off(g * IC + ic),
+ pd()->desc()->bias_desc.data_type)
+ : 0;
+ a += ker(g, mb, ic, id, ih, iw);
+ diff_src[ds_idx] = saturate<diff_src_data_t>(a);
+ });
+}
+
+template <data_type_t src_type, data_type_t diff_wei_type,
+ data_type_t diff_dst_type, data_type_t acc_type>
+void ref_convolution_bwd_weights_t<src_type, diff_wei_type, diff_dst_type,
+ acc_type>::execute_backward_weights(const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto diff_weights = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto diff_bias = CTX_OUT_MEM(diff_wei_data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
+ const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
+
+ const bool with_groups = pd()->with_groups();
+
+ const int G = pd()->G();
+ const int MB = pd()->MB();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+
+ const int OC = pd()->OC() / G;
+ const int IC = pd()->IC() / G;
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+
+ const int KSD = pd()->KSD();
+ const int KSH = pd()->KSH();
+ const int KSW = pd()->KSW();
+
+ const int KDD = pd()->KDD();
+ const int KDH = pd()->KDH();
+ const int KDW = pd()->KDW();
+
+ const int padFront = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+
+ const int ndims = pd()->desc()->src_desc.ndims;
+
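+    /* Backward by weights: each weight tap accumulates diff_dst * src over
+     * the mini-batch and the output spatial volume, skipping positions whose
+     * corresponding input point lies in the padding area. */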
+    auto ker = [=](acc_data_t &d, int g, int oc, int ic, int kd, int kh, int kw) {
+ for (int mb = 0; mb < MB; ++mb)
+ for (int od = 0; od < OD; ++od)
+ for (int oh = 0; oh < OH; ++oh)
+ for (int ow = 0; ow < OW; ++ow) {
+ if (ow*KSW + kw * (1 + KDW) < padL
+ || oh*KSH + kh * (1 + KDH) < padT
+ || od*KSD + kd * (1 + KDD) < padFront
+ || ow*KSW + kw * (1 + KDW) >= IW + padL
+ || oh*KSH + kh * (1 + KDH) >= IH + padT
+ || od*KSD + kd * (1 + KDD) >= ID + padFront)
+ continue;
+
+ int id = od*KSD - padFront + kd * (1 + KDD);
+ int ih = oh*KSH - padT + kh * (1 + KDH);
+ int iw = ow*KSW - padL + kw * (1 + KDW);
+ if (ndims == 5)
+ d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od,
+ oh, ow)] * src[src_d.off(mb, g*IC + ic, id, ih, iw)];
+ else if (ndims == 4)
+ d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh, ow)]
+ * src[src_d.off(mb, g*IC + ic, ih, iw)];
+ else if (ndims == 3)
+ d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)]
+ * src[src_d.off(mb, g*IC + ic, iw)];
+ else
+ assert(false);
+ }
+ };
+
+ auto ker_bias = [=](acc_data_t &d, int g, int oc) {
+ for (int mb = 0; mb < MB; ++mb)
+ for (int od = 0; od < OD; ++od)
+ for (int oh = 0; oh < OH; ++oh)
+ for (int ow = 0; ow < OW; ++ow) {
+ if (ndims == 5)
+ d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, od, oh,
+ ow)];
+ else if (ndims == 4)
+ d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, oh,
+ ow)];
+ else if (ndims == 3)
+ d += (acc_data_t)diff_dst[diff_dst_d.off(mb, g*OC + oc, ow)];
+ else
+ assert(false);
+ }
+ };
+
+ parallel_nd(G, OC, [&](int g, int oc) {
+ if (diff_bias) {
+ // XXX: loss of precision when bias is a float...
+ acc_data_t db = 0;
+ ker_bias(db, g, oc);
+ diff_bias[diff_bias_d.off(g*OC+oc)]
+ = saturate<diff_wei_data_t>(db);
+ }
+
+ for (int ic = 0; ic < IC; ++ic)
+ for (int kd = 0; kd < KD; ++kd)
+ for (int kh = 0; kh < KH; ++kh)
+ for (int kw = 0; kw < KW; ++kw) {
+ acc_data_t dw = 0;
+ ker(dw, g, oc, ic, kd, kh, kw);
+
+ if (ndims == 5) {
+ auto idx = with_groups
+ ? diff_weights_d.off(g, oc, ic, kd, kh, kw)
+ : diff_weights_d.off(oc, ic, kd, kh, kw);
+ diff_weights[idx] = saturate<diff_wei_data_t>(dw);
+ } else if (ndims == 4) {
+ auto idx = with_groups
+ ? diff_weights_d.off(g, oc, ic, kh, kw)
+ : diff_weights_d.off(oc, ic, kh, kw);
+ diff_weights[idx] = saturate<diff_wei_data_t>(dw);
+ } else if (ndims == 3) {
+ auto idx = with_groups
+ ? diff_weights_d.off(g, oc, ic, kw)
+ : diff_weights_d.off(oc, ic, kw);
+ diff_weights[idx] = saturate<diff_wei_data_t>(dw);
+ } else {
+ assert(false);
+ }
+ }
+ });
+}
+
+using namespace data_type;
+
+template struct ref_convolution_fwd_t<f32>;
+
+template struct ref_convolution_fwd_t<u8, s8, f32, s32>;
+template struct ref_convolution_fwd_t<u8, s8, s32, s32>;
+template struct ref_convolution_fwd_t<u8, s8, s8, s32>;
+template struct ref_convolution_fwd_t<u8, s8, u8, s32>;
+
+template struct ref_convolution_bwd_data_t<f32, f32, f32, f32>;
+
+template struct ref_convolution_bwd_data_t<f32, s8, u8, s32>;
+template struct ref_convolution_bwd_data_t<s32, s8, u8, s32>;
+template struct ref_convolution_bwd_data_t<s8, s8, u8, s32>;
+template struct ref_convolution_bwd_data_t<u8, s8, u8, s32>;
+
+template struct ref_convolution_bwd_weights_t<f32, f32, f32, f32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp
new file mode 100644
index 0000000000..7c83d0c6d4
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_convolution.hpp
@@ -0,0 +1,194 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_CONVOLUTION_HPP
+#define CPU_REF_CONVOLUTION_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t src_type,
+ impl::data_type_t wei_type = src_type,
+ impl::data_type_t dst_type = src_type,
+ impl::data_type_t acc_type = dst_type>
+struct ref_convolution_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_convolution_fwd_t);
+
+ status_t init() {
+ using namespace data_type;
+
+ bool ok = true
+ && is_fwd()
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(src_type, wei_type, data_type::undef,
+ dst_type, acc_type)
+ && IMPLICATION(with_bias(), true
+ && IMPLICATION(src_type == u8,
+ utils::one_of(bias_md_.data_type, f32, s32, s8, u8))
+ && IMPLICATION(src_type == f32,
+ bias_md_.data_type == f32))
+ && set_default_formats()
+ && attr()->has_default_values();
+ return ok ? status::success : status::unimplemented;
+ }
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
+ auto wei_tag = with_groups()
+ ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
+ : utils::pick(ndims() - 3, oiw, oihw, oidhw);
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ ref_convolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<wei_type>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+ typedef typename prec_traits<acc_type>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t diff_src_type, impl::data_type_t wei_type,
+ impl::data_type_t diff_dst_type,
+ impl::data_type_t acc_type = diff_src_type>
+struct ref_convolution_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_data_pd_t {
+ using cpu_convolution_bwd_data_pd_t::cpu_convolution_bwd_data_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(diff_src_type, wei_type, data_type::undef,
+ diff_dst_type, acc_type)
+ && set_default_formats()
+ && attr()->has_default_values();
+
+ return ok ? status::success : status::unimplemented;
+ }
+
+ virtual bool support_bias() const override { return true; }
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
+ auto wei_tag = with_groups()
+ ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
+ : utils::pick(ndims() - 3, oiw, oihw, oidhw);
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ ref_convolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
+ typedef typename prec_traits<wei_type>::type wei_data_t;
+ typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
+ typedef typename prec_traits<acc_type>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t src_type, impl::data_type_t diff_wei_type,
+ impl::data_type_t diff_dst_type,
+ impl::data_type_t acc_type = diff_wei_type>
+struct ref_convolution_bwd_weights_t: public cpu_primitive_t {
+ struct pd_t: public cpu_convolution_bwd_weights_pd_t {
+ using cpu_convolution_bwd_weights_pd_t::cpu_convolution_bwd_weights_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_convolution_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_weights
+ && set_default_alg_kind(alg_kind::convolution_direct)
+ && expect_data_types(src_type, diff_wei_type, diff_wei_type,
+ diff_dst_type, acc_type)
+ && set_default_formats()
+ && attr()->has_default_values();
+ return ok ? status::success : status::unimplemented;
+ }
+
+ protected:
+ bool set_default_formats() {
+ using namespace format_tag;
+ auto dat_tag = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
+ auto wei_tag = with_groups()
+ ? utils::pick(ndims() - 3, goiw, goihw, goidhw)
+ : utils::pick(ndims() - 3, oiw, oihw, oidhw);
+ return set_default_formats_common(dat_tag, wei_tag, dat_tag);
+ }
+ };
+
+ ref_convolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<diff_wei_type>::type diff_wei_data_t;
+ typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
+ typedef typename prec_traits<acc_type>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_weights(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_weights(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp
new file mode 100644
index 0000000000..541a303aab
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.cpp
@@ -0,0 +1,199 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+#include "mkldnn_traits.hpp"
+#include "math_utils.hpp"
+
+#include "ref_deconvolution.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+void ref_deconvolution_fwd_t::compute_fwd_bias(const data_t *bias,
+ data_t *dst) const {
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+
+ const int G = pd()->G();
+ const int MB = pd()->MB();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int OD = pd()->OD();
+ const int OC = pd()->OC() / G;
+ const int ndims = pd()->desc()->src_desc.ndims;
+
+ parallel_nd(MB, G, OC, OD, OH, OW,
+ [&](int mb, int g, int oc, int od, int oh, int ow) {
+ auto b = bias[g * OC + oc];
+ switch (ndims) {
+ case 5: dst[dst_d.off(mb, g * OC + oc, od, oh, ow)] += b; break;
+ case 4: dst[dst_d.off(mb, g * OC + oc, oh, ow)] += b; break;
+ case 3: dst[dst_d.off(mb, g * OC + oc, ow)] += b; break;
+ default: assert(!"invalid dimension size");
+ }
+ });
+}
+
+void ref_deconvolution_fwd_t::compute_fwd_bias_ncdhw(const data_t *bias,
+ data_t *dst) const {
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int SP = pd()->OW()*pd()->OH()*pd()->OD();
+
+ parallel_nd(MB, OC, [&](int mb, int oc) {
+ PRAGMA_OMP_SIMD()
+ for (int sp = 0; sp < SP; ++sp) {
+ auto offset = (size_t)(mb * OC + oc) * SP + sp;
+ dst[offset] += bias[oc];
+ }
+ });
+}
+
+template <int blksize>
+void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc(const data_t *bias,
+ data_t *dst) const {
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int SP = pd()->OW() * pd()->OH() * pd()->OD();
+
+ const ptrdiff_t stride_mb = dst_d.blocking_desc().strides[0];
+
+ parallel_nd(MB, utils::div_up(OC, blksize), SP,
+ [&](int mb, int oc_blk, int sp) {
+ int oc = oc_blk * blksize;
+ auto offset = mb * stride_mb + oc * SP + sp * blksize;
+ const int blk = nstl::min(blksize, OC - oc);
+
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < blk; ++i)
+ dst[offset + i] += bias[oc + i];
+ });
+}
+
+void ref_deconvolution_bwd_weights_t::compute_bwd_bias(const data_t *diff_dst,
+ data_t *diff_bias) const {
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+
+ const int G = pd()->G();
+ const int MB = pd()->MB();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+ const int OC = pd()->OC() / G;
+ const int OD = pd()->OD();
+ const int ndims = pd()->desc()->src_desc.ndims;
+
+ parallel_nd(G, OC, [&](int g, int oc) {
+ data_t db = 0;
+ for (int mb = 0; mb < MB; ++mb) {
+ for (int od = 0; od < OD; ++od) {
+ for (int oh = 0; oh < OH; ++oh) {
+ for (int ow = 0; ow < OW; ++ow) {
+ switch (ndims) {
+ case 5:
+ db += diff_dst[diff_dst_d.off(
+ mb, g * OC + oc, od, oh, ow)];
+ break;
+ case 4:
+ db += diff_dst[diff_dst_d.off(
+ mb, g * OC + oc, oh, ow)];
+ break;
+ case 3:
+ db += diff_dst[diff_dst_d.off(mb, g * OC + oc, ow)];
+ break;
+ default: assert(!"invalid dimension size");
+ }
+ }
+ }
+ }
+ }
+ diff_bias[g * OC + oc] = db;
+ });
+}
+
+void ref_deconvolution_bwd_weights_t::compute_bwd_bias_ncdhw(
+ const data_t *diff_dst, data_t *diff_bias) const {
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+
+ const int OC = pd()->OC();
+ const int MB = pd()->MB();
+ const int SP = pd()->OH()*pd()->OW()*pd()->OD();
+
+ parallel_nd(OC, [&](int oc) {
+ data_t db = 0;
+ for (int mb = 0; mb < MB; ++mb) {
+ PRAGMA_OMP_SIMD()
+ for (int sp = 0; sp < SP; ++sp) {
+ auto offset = (size_t)(mb * OC + oc) * SP + sp;
+ db += diff_dst[offset];
+ }
+ }
+ diff_bias[oc] = db;
+ });
+}
+
+template <int blksize>
+void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc(
+ const data_t *diff_dst, data_t *diff_bias) const {
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+
+ const int OC = pd()->OC();
+ const int MB = pd()->MB();
+ const int SP = pd()->OH() * pd()->OW() * pd()->OD();
+
+ const ptrdiff_t stride_mb = diff_dst_d.blocking_desc().strides[0];
+
+ parallel_nd(utils::div_up(OC, blksize), [&](int ocb) {
+ data_t db[blksize] = {0};
+
+ for (int mb = 0; mb < MB; ++mb) {
+ for (int sp = 0; sp < SP; ++sp) {
+ auto offset = mb * stride_mb + (ocb * SP + sp) * blksize;
+
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < blksize; ++i)
+ db[i] += diff_dst[offset+i];
+ }
+ }
+
+ const int blk = nstl::min(blksize, OC - ocb * blksize);
+
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < blk; ++i)
+ diff_bias[ocb * blksize + i] = db[i];
+ });
+}
+
+template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<8>(
+ const data_t *diff_dst, data_t *diff_bias) const;
+template void ref_deconvolution_fwd_t::compute_fwd_bias_nCdhwXc<16>(
+ const data_t *diff_dst, data_t *diff_bias) const;
+template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<8>(
+ const data_t *diff_dst, data_t *diff_bias) const;
+template void ref_deconvolution_bwd_weights_t::compute_bwd_bias_nCdhwXc<16>(
+ const data_t *diff_dst, data_t *diff_bias) const;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp
new file mode 100644
index 0000000000..d61903c32d
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_deconvolution.hpp
@@ -0,0 +1,502 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_DECONVOLUTION_HPP
+#define CPU_REF_DECONVOLUTION_HPP
+
+#include <assert.h>
+#include <string.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "primitive_iterator.hpp"
+
+#include "cpu_convolution_pd.hpp"
+#include "cpu_deconvolution_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+static status_t compute_blocked_format(bool with_groups,
+ const memory_desc_t *oi_md, memory_desc_t *io_md)
+{
+ /* Computes blocking for *i*o* format from *o*i* format */
+
+ bool sanity_check_ok = true
+ && oi_md->ndims == io_md->ndims
+ && oi_md->format_kind == format_kind::blocked;
+ if (!sanity_check_ok) return status::invalid_arguments;
+
+ const blocking_desc_t &oi_blk = oi_md->format_desc.blocking;
+ blocking_desc_t io_blk = io_md->format_desc.blocking;
+
+ io_md->format_kind = format_kind::blocked;
+ io_blk = oi_blk;
+
+ const int ID_OC = 0 + with_groups;
+ const int ID_IC = 1 + with_groups;
+
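+    /* Swap the strides and the inner block indices of the OC and IC axes so
+     * that the blocked description now refers to the transposed weights. */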
+ nstl::swap(io_blk.strides[ID_OC], io_blk.strides[ID_IC]);
+ for (int i_blk = 0; i_blk < io_blk.inner_nblks; ++i_blk) {
+ if (utils::one_of(io_blk.inner_idxs[i_blk], ID_OC, ID_IC)) {
+ io_blk.inner_idxs[i_blk] =
+ (io_blk.inner_idxs[i_blk] == ID_OC ? ID_IC : ID_OC);
+ }
+ }
+
+ return memory_desc_init_by_blocking_desc(*io_md, io_blk);
+}
+
+static status_t conv_descr_create(const deconvolution_desc_t *dd,
+ convolution_desc_t *cd)
+{
+ using namespace prop_kind;
+ alg_kind_t alg_kind = dd->alg_kind == alg_kind::deconvolution_direct
+ ? alg_kind::convolution_direct : alg_kind::convolution_winograd;
+
+ const memory_desc_t *src_md, *dst_md, *d_weights_d;
+ prop_kind_t prop_kind;
+ memory_desc_t c_weights_d;
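+    /* Deconvolution is expressed through the complementary convolution:
+     * forward deconvolution runs convolution backward-by-data, backward-by-data
+     * runs convolution forward, and backward-by-weights keeps its propagation
+     * kind with the src and diff_dst descriptors swapped. */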
+ if (utils::one_of(dd->prop_kind, forward_training, forward_inference)) {
+ prop_kind = backward_data;
+ src_md = &dd->dst_desc;
+ dst_md = &dd->src_desc;
+ d_weights_d = &dd->weights_desc;
+ } else if (dd->prop_kind == backward_data) {
+ prop_kind = forward_training;
+ src_md = &dd->diff_dst_desc;
+ dst_md = &dd->diff_src_desc;
+ d_weights_d = &dd->weights_desc;
+ } else {
+ prop_kind = dd->prop_kind;
+ src_md = &dd->diff_dst_desc;
+ dst_md = &dd->src_desc;
+ d_weights_d = &dd->diff_weights_desc;
+ }
+
+ const bool with_groups = d_weights_d->ndims == src_md->ndims + 1;
+
+ /* create weights desc for convolution */
+ c_weights_d = *d_weights_d;
+
+ const int ID_OC = 0 + with_groups;
+ const int ID_IC = 1 + with_groups;
+
+ nstl::swap(c_weights_d.dims[ID_OC], c_weights_d.dims[ID_IC]);
+ nstl::swap(c_weights_d.padded_dims[ID_OC], c_weights_d.padded_dims[ID_IC]);
+ nstl::swap(c_weights_d.padded_offsets[ID_OC], c_weights_d.padded_offsets[ID_IC]);
+
+ if (c_weights_d.format_kind != format_kind::any)
+ CHECK(compute_blocked_format(with_groups, d_weights_d, &c_weights_d));
+
+ return conv_desc_init(cd, prop_kind, alg_kind, src_md, &c_weights_d,
+ prop_kind != backward_weights ? &dd->bias_desc : nullptr,
+ dst_md, dd->strides, dd->dilates,
+ dd->padding[0], dd->padding[1], dd->padding_kind);
+}
+
+struct ref_deconvolution_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_deconvolution_fwd_pd_t {
+ pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_deconvolution_fwd_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , conv_pd_(nullptr)
+ {}
+
+ pd_t(const pd_t &other)
+ : cpu_deconvolution_fwd_pd_t(other)
+ , conv_pd_(other.conv_pd_->clone())
+ , conv_supports_bias_(other.conv_supports_bias_)
+ , dst_tag_(other.dst_tag_)
+ {}
+
+ ~pd_t() { delete conv_pd_; }
+
+ DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_fwd_t);
+
+ status_t init_convolution() {
+ using namespace types;
+
+ convolution_desc_t cd;
+ CHECK(conv_descr_create(desc(), &cd));
+
+ mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
+ &attr_, nullptr);
+ while (++it != it.end()) {
+ conv_pd_ = *it;
+ conv_supports_bias_ =
+ static_cast<cpu_convolution_bwd_data_pd_t *>(conv_pd_)
+ ->support_bias();
+ bool output_f32 = utils::everyone_is(data_type::f32,
+ desc()->accum_data_type, desc()->dst_desc.data_type);
+
+ bool ok = true
+ && conv_pd_->weights_md()->extra.flags == 0
+ /* deconv reference code can process only f32 bias */
+ && IMPLICATION(with_bias(),
+ conv_supports_bias_ || output_f32);
+ if (ok) return status::success;
+
+ delete conv_pd_;
+ }
+ conv_pd_ = nullptr;
+ return status::unimplemented;
+ }
+
+ status_t init() {
+ using namespace format_tag;
+ bool ok = true
+ && is_fwd()
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::deconvolution_direct,
+ alg_kind::deconvolution_winograd)
+ && attr()->post_ops_.has_default_values();
+
+ if (ok) {
+ CHECK(init_convolution());
+ if (weights_md_.format_kind == format_kind::any) {
+ CHECK(compute_blocked_format(with_groups(),
+ conv_pd_->weights_md(), &desc_.weights_desc));
+ weights_md_ = desc_.weights_desc;
+ }
+ if (src_md_.format_kind == format_kind::any)
+ src_md_ = *conv_pd_->diff_dst_md();
+ if (dst_md_.format_kind == format_kind::any)
+ dst_md_ = *conv_pd_->diff_src_md();
+ if (bias_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(bias_md_, x));
+
+ dst_tag_ = memory_desc_matches_one_of_tag(dst_md_,
+ utils::pick(ndims() - 3, ncw, nchw, ncdhw),
+ utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
+ utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));
+
+ return status::success;
+ }
+
+ return status::unimplemented;
+ }
+
+ virtual void init_scratchpad_md() override {
+ scratchpad_md_ = *conv_pd_->scratchpad_md();
+ }
+
+ primitive_desc_t *conv_pd_;
+ bool conv_supports_bias_;
+ format_tag_t dst_tag_;
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ ref_deconvolution_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
+ { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
+ ~ref_deconvolution_fwd_t() { delete conv_p_; }
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ const auto &args = ctx.args();
+ exec_args_t conv_args;
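+        /* The underlying primitive is convolution backward-by-data, so the
+         * deconvolution src/dst map to the convolution diff_dst/diff_src. */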
+ conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC);
+ conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS);
+ if (pd()->with_bias() && pd()->conv_supports_bias_)
+ conv_args[MKLDNN_ARG_BIAS] = args.at(MKLDNN_ARG_BIAS);
+ conv_args[MKLDNN_ARG_DIFF_SRC] = args.at(MKLDNN_ARG_DST);
+ if (!types::is_zero_md(pd()->scratchpad_md()))
+ conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
+ const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));
+
+ conv_p_->execute(conv_ctx);
+
+ if (pd()->with_bias() && !pd()->conv_supports_bias_) {
+ using namespace format_tag;
+
+ auto bias = CTX_IN_MEM(const data_t *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ switch (pd()->dst_tag_) {
+ case ncdhw: case nchw: case ncw:
+ compute_fwd_bias_ncdhw(bias, dst);
+ break;
+ case nCdhw8c: case nChw8c: case nCw8c:
+ compute_fwd_bias_nCdhwXc<8>(bias, dst);
+ break;
+ case nCdhw16c: case nChw16c: case nCw16c:
+ compute_fwd_bias_nCdhwXc<16>(bias, dst);
+ break;
+ default:
+ compute_fwd_bias(bias, dst);
+ break;
+ }
+ }
+ return status::success;
+ }
+
+private:
+ void compute_fwd_bias(const data_t *bias, data_t *dst) const;
+ void compute_fwd_bias_ncdhw(const data_t *bias, data_t *dst) const;
+ template <int blksize> void compute_fwd_bias_nCdhwXc(const data_t *bias,
+ data_t *dst) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ primitive_t *conv_p_;
+};
+
+struct ref_deconvolution_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_deconvolution_bwd_data_pd_t {
+ pd_t(engine_t *engine, const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_deconvolution_bwd_data_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , conv_pd_(nullptr)
+ {}
+
+ pd_t(const pd_t &other)
+ : cpu_deconvolution_bwd_data_pd_t(other)
+ , conv_pd_(other.conv_pd_->clone()) {}
+
+ ~pd_t() { delete conv_pd_; }
+
+ DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_data_t);
+
+ status_t init_convolution() {
+ using namespace types;
+
+ convolution_desc_t cd;
+ status_t status = conv_descr_create(desc(), &cd);
+ if (status != status::success) return status;
+
+ mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
+ &attr_, nullptr);
+ while (++it != it.end()) {
+ conv_pd_ = *it;
+ if (conv_pd_->weights_md()->extra.flags == 0)
+ return status::success;
+ delete conv_pd_;
+ }
+
+ return status::unimplemented;
+ }
+
+ status_t init() {
+ using namespace data_type;
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_data
+ && utils::everyone_is(data_type::f32,
+ desc()->diff_src_desc.data_type,
+ desc()->weights_desc.data_type,
+ desc()->diff_dst_desc.data_type)
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::deconvolution_direct,
+ alg_kind::deconvolution_winograd);
+
+ if (ok) {
+ CHECK(init_convolution());
+ if (weights_md_.format_kind == format_kind::any) {
+ CHECK(compute_blocked_format(with_groups(),
+ conv_pd_->weights_md(), &desc_.weights_desc));
+ weights_md_ = desc_.weights_desc;
+ }
+ if (diff_src_md_.format_kind == format_kind::any)
+ diff_src_md_ = *conv_pd_->dst_md();
+ if (diff_dst_md_.format_kind == format_kind::any)
+ diff_dst_md_ = *conv_pd_->src_md();
+
+ return status::success;
+ }
+
+ return status::unimplemented;
+ }
+
+ virtual void init_scratchpad_md() override {
+ scratchpad_md_ = *conv_pd_->scratchpad_md();
+ }
+
+ primitive_desc_t *conv_pd_;
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ ref_deconvolution_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd)
+ { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
+ ~ref_deconvolution_bwd_data_t() { delete conv_p_; }
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ const auto &args = ctx.args();
+ exec_args_t conv_args;
+ conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST);
+ conv_args[MKLDNN_ARG_WEIGHTS] = args.at(MKLDNN_ARG_WEIGHTS);
+ conv_args[MKLDNN_ARG_DST] = args.at(MKLDNN_ARG_DIFF_SRC);
+ if (!types::is_zero_md(pd()->scratchpad_md()))
+ conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
+ const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));
+
+ conv_p_->execute(conv_ctx);
+ return status::success;
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ primitive_t *conv_p_;
+};
+
+struct ref_deconvolution_bwd_weights_t: public cpu_primitive_t {
+ struct pd_t: public cpu_deconvolution_bwd_weights_pd_t {
+ pd_t(engine_t *engine,
+ const deconvolution_desc_t *adesc,
+ const primitive_attr_t *attr,
+ const deconvolution_fwd_pd_t *hint_fwd_pd)
+ : cpu_deconvolution_bwd_weights_pd_t(engine, adesc, attr, hint_fwd_pd)
+ , conv_pd_(nullptr)
+ {}
+
+ pd_t(const pd_t &other)
+ : cpu_deconvolution_bwd_weights_pd_t(other)
+ , conv_pd_(other.conv_pd_->clone())
+ , dst_tag_(other.dst_tag_)
+ {}
+
+ ~pd_t() { delete conv_pd_; }
+
+ DECLARE_COMMON_PD_T(conv_pd_->name(), ref_deconvolution_bwd_weights_t);
+
+ status_t init_convolution() {
+ using namespace types;
+
+ convolution_desc_t cd;
+ status_t status = conv_descr_create(desc(), &cd);
+ if (status != status::success) return status;
+
+ mkldnn_primitive_desc_iterator it(engine_, (op_desc_t *)&cd,
+ &attr_, nullptr);
+ while (++it != it.end()) {
+ conv_pd_ = *it;
+ if (conv_pd_->diff_weights_md()->extra.flags == 0)
+ return status::success;
+ delete conv_pd_;
+ }
+ return status::unimplemented;
+ }
+
+ status_t init() {
+ using namespace format_tag;
+ bool ok = true
+ && desc()->prop_kind == prop_kind::backward_weights
+ && utils::everyone_is(data_type::f32,
+ desc()->src_desc.data_type,
+ desc()->diff_weights_desc.data_type,
+ desc()->diff_dst_desc.data_type)
+ && utils::one_of(desc()->alg_kind,
+ alg_kind::deconvolution_direct,
+ alg_kind::deconvolution_winograd)
+ && attr()->has_default_values();
+ if (ok) {
+ CHECK(init_convolution());
+ if (diff_weights_md_.format_kind == format_kind::any) {
+ CHECK(compute_blocked_format(with_groups(),
+ conv_pd_->diff_weights_md(),
+ &desc_.diff_weights_desc));
+ diff_weights_md_ = desc_.diff_weights_desc;
+ }
+ if (src_md_.format_kind == format_kind::any)
+ src_md_ = *conv_pd_->diff_dst_md();
+ if (diff_dst_md_.format_kind == format_kind::any)
+ diff_dst_md_ = *conv_pd_->src_md();
+ if (diff_bias_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_bias_md_, x));
+
+ dst_tag_ = memory_desc_matches_one_of_tag(diff_dst_md_,
+ utils::pick(ndims() - 3, ncw, nchw, ncdhw),
+ utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c),
+ utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c));
+
+ return status::success;
+ }
+
+ return status::unimplemented;
+ }
+
+ virtual void init_scratchpad_md() override {
+ scratchpad_md_ = *conv_pd_->scratchpad_md();
+ }
+
+ primitive_desc_t *conv_pd_;
+ format_tag_t dst_tag_;
+ };
+
+ typedef typename prec_traits<data_type::f32>::type data_t;
+
+ ref_deconvolution_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd)
+ { pd()->conv_pd_->create_primitive((primitive_t **)&conv_p_); }
+ ~ref_deconvolution_bwd_weights_t() { delete conv_p_; }
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ const auto &args = ctx.args();
+ exec_args_t conv_args;
+ conv_args[MKLDNN_ARG_DIFF_DST] = args.at(MKLDNN_ARG_SRC);
+ conv_args[MKLDNN_ARG_SRC] = args.at(MKLDNN_ARG_DIFF_DST);
+ conv_args[MKLDNN_ARG_DIFF_WEIGHTS] = args.at(MKLDNN_ARG_DIFF_WEIGHTS);
+ if (!types::is_zero_md(pd()->scratchpad_md()))
+ conv_args[MKLDNN_ARG_SCRATCHPAD] = args.at(MKLDNN_ARG_SCRATCHPAD);
+ const exec_ctx_t conv_ctx(ctx.stream(), std::move(conv_args));
+
+ status_t status = conv_p_->execute(conv_ctx);
+ if (status != status::success) return status;
+
+ if (pd()->with_bias()) {
+ using namespace format_tag;
+
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ switch (pd()->dst_tag_) {
+ case ncdhw: case nchw: case ncw:
+ compute_bwd_bias_ncdhw(diff_dst, diff_bias);
+ break;
+ case nCdhw8c: case nChw8c: case nCw8c:
+ compute_bwd_bias_nCdhwXc<8>(diff_dst, diff_bias);
+ break;
+ case nCdhw16c: case nChw16c: case nCw16c:
+ compute_bwd_bias_nCdhwXc<16>(diff_dst, diff_bias);
+ break;
+ default:
+ compute_bwd_bias(diff_dst, diff_bias);
+ break;
+ }
+ }
+ return status::success;
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ void compute_bwd_bias(const data_t *diff_dst, data_t *diff_bias) const;
+ void compute_bwd_bias_ncdhw(const data_t *diff_dst,
+ data_t *diff_bias) const;
+ template <int blksize> void compute_bwd_bias_nCdhwXc(
+ const data_t *diff_dst, data_t *diff_bias) const;
+
+ primitive_t *conv_p_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp
new file mode 100644
index 0000000000..7beee8d323
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.cpp
@@ -0,0 +1,297 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_eltwise.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace alg_kind;
+using namespace math;
+
+ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha,
+ float beta): alg_(alg), alpha_(alpha), beta_(beta) {
+ assert(utils::one_of(alg_, eltwise_relu, eltwise_tanh, eltwise_elu,
+ eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+ eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic));
+}
+
+ref_eltwise_scalar_fwd_t::ref_eltwise_scalar_fwd_t(
+ const post_ops_t::entry_t::eltwise_t &eltwise)
+ : ref_eltwise_scalar_fwd_t(eltwise.alg, eltwise.alpha, eltwise.beta) {}
+
+float ref_eltwise_scalar_fwd_t::compute_scalar(float s) {
+ switch (alg_) {
+ case eltwise_relu: return relu_fwd(s, alpha_);
+ case eltwise_tanh: return tanh_fwd(s);
+ case eltwise_elu: return elu_fwd(s, alpha_);
+ case eltwise_square: return square_fwd(s);
+ case eltwise_abs: return abs_fwd(s);
+ case eltwise_sqrt: return sqrt_fwd(s);
+ case eltwise_linear: return linear_fwd(s, alpha_, beta_);
+ case eltwise_bounded_relu: return bounded_relu_fwd(s, alpha_);
+ case eltwise_soft_relu: return soft_relu_fwd(s);
+ case eltwise_logistic: return logistic_fwd(s);
+ default: assert(!"unknown eltwise alg_kind");
+ }
+
+ return 0.f;
+}
+
+template <impl::data_type_t data_type>
+void ref_eltwise_fwd_t<data_type>::execute_forward_nCspBc_padded(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+ const blocking_desc_t &blk = data_d.blocking_desc();
+ const int block = blk.inner_blks[0];
+
+ const int MB = pd()->MB();
+ const int C = pd()->C() / block;
+ const int C_PADDED = data_d.padded_dims()[1] / block;
+ const int tail = pd()->C() % block;
+ const int SP = pd()->D() * pd()->H() * pd()->W();
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
+
+ auto ker = [=] (data_t &d, data_t s) {
+ switch (alg_kind) {
+ case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
+ case eltwise_bounded_relu:
+ d = bounded_relu_fwd(s, alpha); break;
+ case eltwise_soft_relu: d = soft_relu_fwd(s); break;
+ case eltwise_logistic: d = logistic_fwd(s); break;
+ default: assert(!"unknown eltwise alg_kind");
+ }
+ };
+
+ // FIXME: integer overflow?
+
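+    // Full blocks (c < C) are processed whole; the trailing padded block only
+    // touches its `tail` valid channels, so values in the channel padding are
+    // left untouched.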
+ parallel_nd(MB, C_PADDED, SP,
+ [&](int n, int c, int sp) {
+ auto d_off = (n*C_PADDED*SP + c*SP + sp) * block;
+ if (c < C) {
+ for (int v = 0; v < block; v++)
+ ker(dst[d_off + v], src[d_off + v]);
+ } else {
+ for (int v = 0; v < tail; v++)
+ ker(dst[d_off + v], src[d_off + v]);
+ }
+ });
+}
+
+template <impl::data_type_t data_type>
+void ref_eltwise_fwd_t<data_type>::execute_forward_generic(
+ const exec_ctx_t &ctx) const {
+ /* fast return */
+ if (pd()->has_zero_dim_memory()) return;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+
+ const int MB = pd()->MB();
+ const int C = pd()->C();
+ const int D = pd()->D();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
+ const bool is_3d = pd()->desc()->data_desc.ndims == 5;
+
+ parallel_nd(MB, C, D, H, W,
+ [&](int n, int c, int id, int h, int w) {
+ auto d_off = is_3d
+ ? data_d.off(n, c, id, h, w) : data_d.off(n, c, h, w);
+ data_t s = src[d_off];
+ data_t &d = dst[d_off];
+ switch (alg_kind) {
+ case eltwise_relu: d = relu_fwd(s, alpha); break;
+ case eltwise_tanh: d = tanh_fwd(s); break;
+ case eltwise_elu: d = elu_fwd(s, alpha); break;
+ case eltwise_square: d = square_fwd(s); break;
+ case eltwise_abs: d = abs_fwd(s); break;
+ case eltwise_sqrt: d = sqrt_fwd(s); break;
+ case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
+ case eltwise_bounded_relu:
+ d = bounded_relu_fwd(s, alpha); break;
+ case eltwise_soft_relu: d = soft_relu_fwd(s); break;
+ case eltwise_logistic: d = logistic_fwd(s); break;
+ default: assert(!"unknown eltwise alg_kind");
+ }
+ });
+}
+
+template <impl::data_type_t data_type>
+void ref_eltwise_fwd_t<data_type>::execute_forward_dense(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+
+ const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
+
+ src += data_d.offset0();
+ dst += data_d.offset0();
+
+ if (alg_kind == eltwise_relu) {
+ // a fast path for relu as the most popular activation
+ parallel_nd(nelems, [&](ptrdiff_t e) {
+ dst[e] = relu_fwd(src[e], alpha);
+ });
+ return;
+ }
+
+ parallel_nd(nelems, [&](ptrdiff_t e) {
+ const data_t s = src[e];
+ data_t &d = dst[e];
+
+ switch (alg_kind) {
+ case eltwise_tanh: d = tanh_fwd(s); break;
+ case eltwise_elu: d = elu_fwd(s, alpha); break;
+ case eltwise_square: d = square_fwd(s); break;
+ case eltwise_abs: d = abs_fwd(s); break;
+ case eltwise_sqrt: d = sqrt_fwd(s); break;
+ case eltwise_linear: d = linear_fwd(s, alpha, beta); break;
+ case eltwise_bounded_relu: d = bounded_relu_fwd(s, alpha); break;
+ case eltwise_soft_relu: d = soft_relu_fwd(s); break;
+ case eltwise_logistic: d = logistic_fwd(s); break;
+ default: assert(!"unknown eltwise alg_kind");
+ }
+ });
+}
+
+template <impl::data_type_t data_type>
+void ref_eltwise_bwd_t<data_type>::execute_backward_generic(
+ const exec_ctx_t &ctx) const {
+ /* fast return */
+ if (pd()->has_zero_dim_memory()) return;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+ const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
+
+ const int MB = pd()->MB();
+ const int C = pd()->C();
+ const int D = pd()->D();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
+ const bool is_3d = pd()->desc()->data_desc.ndims == 5;
+
+ parallel_nd(MB, C, D, H, W,
+ [&](int n, int c, int d, int h, int w) {
+ auto data_off = is_3d
+ ? data_d.off(n, c, d, h, w) : data_d.off(n, c, h, w);
+ auto diff_data_off = is_3d
+ ? diff_data_d.off(n, c, d, h, w)
+ : diff_data_d.off(n, c, h, w);
+ data_t s = src[data_off];
+ data_t dd = diff_dst[diff_data_off];
+ data_t &ds = diff_src[diff_data_off];
+ switch (alg_kind) {
+ case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
+ case eltwise_tanh: ds = tanh_bwd(dd, s); break;
+ case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
+ case eltwise_square: ds = square_bwd(dd, s); break;
+ case eltwise_abs: ds = abs_bwd(dd, s); break;
+ case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
+ case eltwise_linear:
+ ds = linear_bwd(dd, s, alpha, beta); break;
+ case eltwise_bounded_relu:
+ ds = bounded_relu_bwd(dd, s, alpha); break;
+ case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
+ case eltwise_logistic: ds = logistic_bwd(dd, s); break;
+ default: assert(!"unknown eltwise alg_kind");
+ }
+ });
+}
+
+template <impl::data_type_t data_type>
+void ref_eltwise_bwd_t<data_type>::execute_backward_dense(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+ const memory_desc_wrapper diff_data_d(pd()->diff_src_md());
+
+ const ptrdiff_t nelems = static_cast<ptrdiff_t>(data_d.nelems(true));
+ const auto alg_kind = pd()->desc()->alg_kind;
+ const float alpha = pd()->desc()->alpha;
+ const float beta = pd()->desc()->beta;
+
+ src += data_d.offset0();
+ diff_dst += diff_data_d.offset0();
+ diff_src += diff_data_d.offset0();
+
+ parallel_nd(nelems, [&](ptrdiff_t e) {
+ const data_t dd = diff_dst[e];
+ const data_t s = src[e];
+ data_t &ds = diff_src[e];
+
+ switch (alg_kind) {
+ case eltwise_relu: ds = relu_bwd(dd, s, alpha); break;
+ case eltwise_tanh: ds = tanh_bwd(dd, s); break;
+ case eltwise_elu: ds = elu_bwd(dd, s, alpha); break;
+ case eltwise_square: ds = square_bwd(dd, s); break;
+ case eltwise_abs: ds = abs_bwd(dd, s); break;
+ case eltwise_sqrt: ds = sqrt_bwd(dd, s); break;
+ case eltwise_linear: ds = linear_bwd(dd, s, alpha, beta); break;
+ case eltwise_bounded_relu: ds = bounded_relu_bwd(dd, s, alpha); break;
+ case eltwise_soft_relu: ds = soft_relu_bwd(dd, s); break;
+ case eltwise_logistic: ds = logistic_bwd(dd, s); break;
+ default: assert(!"unknown eltwise alg_kind");
+ }
+ });
+}
+
+template struct ref_eltwise_fwd_t<data_type::f32>;
+template struct ref_eltwise_fwd_t<data_type::s32>;
+template struct ref_eltwise_fwd_t<data_type::s8>;
+template struct ref_eltwise_fwd_t<data_type::u8>;
+
+template struct ref_eltwise_bwd_t<data_type::f32>;
+template struct ref_eltwise_bwd_t<data_type::s32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp
new file mode 100644
index 0000000000..8f4ab35413
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_eltwise.hpp
@@ -0,0 +1,168 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_ELTWISE_HPP
+#define CPU_REF_ELTWISE_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_eltwise_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct ref_eltwise_scalar_fwd_t {
+public:
+ ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, float beta);
+
+ // note that eltwise.scale is ignored
+ ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise);
+
+ float compute_scalar(float s);
+
+ const alg_kind_t alg_;
+ const float alpha_;
+ const float beta_;
+};
+
+template <impl::data_type_t data_type>
+struct ref_eltwise_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_eltwise_fwd_pd_t {
+ using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_eltwise_fwd_t);
+
+ status_t init() {
+ using namespace utils;
+
+ auto src_d = memory_desc_wrapper(src_md());
+
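+            // Kernel selection: the dense kernel handles contiguous memory
+            // (including padded layouts when the algorithm preserves zeros),
+            // nCspBc_padded handles blocked layouts whose only padded dim is
+            // the channel, and anything else falls back to the generic kernel.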
+ use_dense_ = false
+ || src_d.is_dense()
+ || (src_d.is_dense(true) && is_zero_preserved());
+
+ use_nCspBc_padded_ = !use_dense_
+ && src_d.blocking_desc().inner_nblks == 1
+ && one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
+ && src_d.blocking_desc().inner_idxs[0] == 1
+ && src_d.only_padded_dim(1)
+ && src_d.is_dense(true);
+
+ if (has_zero_dim_memory())
+ use_dense_ = use_nCspBc_padded_ = false;
+
+ const bool use_generic = !use_dense_ && !use_nCspBc_padded_;
+
+ bool ok = true
+ && is_fwd()
+ && everyone_is(data_type, desc()->data_desc.data_type)
+ && IMPLICATION(use_generic, one_of(src_d.ndims(), 4, 5))
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ return status::success;
+ }
+
+ bool use_dense_, use_nCspBc_padded_;
+ };
+
+ ref_eltwise_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ if (pd()->use_dense_)
+ execute_forward_dense(ctx);
+ else if (pd()->use_nCspBc_padded_)
+ execute_forward_nCspBc_padded(ctx);
+ else
+ execute_forward_generic(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward_nCspBc_padded(const exec_ctx_t &ctx) const;
+ void execute_forward_dense(const exec_ctx_t &ctx) const;
+ void execute_forward_generic(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t data_type>
+struct ref_eltwise_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_eltwise_bwd_pd_t {
+ using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_eltwise_bwd_t);
+
+ status_t init() {
+ using namespace utils;
+
+ bool ok = true
+ && !is_fwd()
+ && everyone_is(data_type,
+ desc()->data_desc.data_type,
+ desc()->diff_data_desc.data_type)
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ auto diff_dst_d = memory_desc_wrapper(diff_dst_md());
+ const bool same_fmt_ = diff_dst_d == memory_desc_wrapper(src_md());
+
+ use_dense_ = true
+ && same_fmt_
+ && diff_dst_d.is_dense(true)
+ && is_zero_preserved()
+ && !has_zero_dim_memory();
+ const bool use_generic = !use_dense_;
+
+ if (use_generic && !one_of(diff_dst_d.ndims(), 4, 5))
+ return status::unimplemented;
+
+ return status::success;
+ }
+
+ bool use_dense_;
+ };
+
+ ref_eltwise_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ if (pd()->use_dense_)
+ execute_backward_dense(ctx);
+ else
+ execute_backward_generic(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_dense(const exec_ctx_t &ctx) const;
+ void execute_backward_generic(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp
new file mode 100644
index 0000000000..c807a9ffd0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.cpp
@@ -0,0 +1,285 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+#include "mkldnn_traits.hpp"
+#include "math_utils.hpp"
+
+#include "ref_inner_product.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using math::saturate;
+using math::get_bias;
+
+template <data_type_t src_type, data_type_t wei_type, data_type_t dst_type,
+ data_type_t acc_type>
+void ref_inner_product_fwd_t<src_type, wei_type, dst_type, acc_type>::
+execute_forward(const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto bias = CTX_IN_MEM(const char *, MKLDNN_ARG_BIAS);
+ auto dst = CTX_OUT_MEM(dst_data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper bias_d(pd()->weights_md(1));
+
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC();
+
+ const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5);
+ const int ndims = src_d.ndims() - 2;
+
+ const auto &post_ops = pd()->attr()->post_ops_;
+ const bool do_relu = post_ops.len_ == 1;
+ const float nslope = do_relu ? post_ops.entry_[0].eltwise.alpha : 0.f;
+
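+    // With spatial dims the weights act as OC x (IC*KD*KH*KW): the kernel
+    // reduces over every input channel and kernel position for one (mb, oc).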
+ auto ker_has_spatial = [=](int mb, int oc) {
+ acc_data_t d = 0;
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ for (int ic = 0; ic < IC; ++ic) {
+ for (int kd = 0; kd < KD; ++kd) {
+ for (int kh = 0; kh < KH; ++kh) {
+ for (int kw = 0; kw < KW; ++kw) {
+ switch (ndims) {
+ case 3:
+ d += (acc_data_t)src[src_d.off(mb, ic, kd, kh, kw)]
+ * weights[weights_d.off(
+ oc, ic, kd, kh, kw)];
+ break;
+ case 2:
+ d += (acc_data_t)src[src_d.off(mb, ic, kh, kw)]
+ * weights[weights_d.off(oc, ic, kh, kw)];
+ break;
+ case 1:
+ d += (acc_data_t)src[src_d.off(mb, ic, kw)]
+ * weights[weights_d.off(oc, ic, kw)];
+ break;
+ default: assert(!"unsupported ndims size");
+ }
+ }
+ }
+ }
+ }
+ return d;
+ };
+
+ auto ker_no_spatial = [=](int mb, int oc) {
+ acc_data_t d = 0;
+ for (int ic = 0; ic < IC; ++ic) {
+ d += (acc_data_t)src[src_d.off(mb, ic)]
+ * weights[weights_d.off(oc, ic)];
+ }
+ return d;
+ };
+
+ parallel_nd(MB, OC, [&](int mb, int oc) {
+ float a = bias
+ ? get_bias(bias, bias_d.off(oc), pd()->desc()->bias_desc.data_type)
+ : 0;
+ if (src_has_spatial)
+ a += ker_has_spatial(mb, oc);
+ else
+ a += ker_no_spatial(mb, oc);
+ if (do_relu && a < (acc_data_t)0)
+ a *= nslope;
+ dst[dst_d.off(mb, oc)] = saturate<dst_data_t>(a);
+ });
+}
+
+using namespace data_type;
+template struct ref_inner_product_fwd_t<f32>;
+template struct ref_inner_product_fwd_t<u8, s8, f32, s32>;
+template struct ref_inner_product_fwd_t<u8, s8, s32, s32>;
+template struct ref_inner_product_fwd_t<u8, s8, s8, s32>;
+template struct ref_inner_product_fwd_t<u8, s8, u8, s32>;
+
+template <data_type_t diff_src_type, data_type_t wei_type,
+ data_type_t diff_dst_type, data_type_t acc_type>
+void ref_inner_product_bwd_data_t<diff_src_type, wei_type, diff_dst_type,
+ acc_type>::execute_backward_data(const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const diff_dst_data_t *, MKLDNN_ARG_DIFF_DST);
+ auto weights = CTX_IN_MEM(const wei_data_t *, MKLDNN_ARG_WEIGHTS);
+ auto diff_src = CTX_OUT_MEM(diff_src_data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper weights_d(pd()->weights_md(0));
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC();
+
+ const bool diff_src_has_spatial
+ = utils::one_of(diff_src_d.ndims(), 3, 4, 5);
+ const int ndims = diff_src_d.ndims() - 2;
+
+ parallel_nd(MB, IC, [&](int mb, int ic) {
+ if (diff_src_has_spatial) {
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ for (int kd = 0; kd < KD; ++kd)
+ for (int kh = 0; kh < KH; ++kh)
+ for (int kw = 0; kw < KW; ++kw) {
+ acc_data_t ds = acc_data_t(0);
+ for (int oc = 0; oc < OC; ++oc) {
+ switch (ndims) {
+ case 3:
+ ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
+ * weights[weights_d.off(oc, ic, kd, kh, kw)]);
+ break;
+ case 2:
+ ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
+ * weights[weights_d.off(oc, ic, kh, kw)]);
+ break;
+ case 1:
+ ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)]
+ * weights[weights_d.off(oc, ic, kw)]);
+ break;
+ default: assert(!"unsupported ndims size");
+ }
+ }
+ switch (ndims) {
+ case 3:
+ diff_src[diff_src_d.off(mb, ic, kd, kh, kw)]
+ = (diff_src_data_t)ds;
+ break;
+ case 2:
+ diff_src[diff_src_d.off(mb, ic, kh, kw)]
+ = (diff_src_data_t)ds;
+ break;
+ case 1:
+ diff_src[diff_src_d.off(mb, ic, kw)] = (diff_src_data_t)ds;
+ break;
+ default: assert(!"unsupported ndims size");
+ }
+ }
+ } else {
+ acc_data_t ds = acc_data_t(0);
+ for (int oc = 0; oc < OC; ++oc) {
+ ds += (acc_data_t)(diff_dst[diff_dst_d.off(mb, oc)] *
+ weights[weights_d.off(oc, ic)]);
+ }
+ diff_src[diff_src_d.off(mb, ic)] = (diff_src_data_t)ds;
+ }
+ });
+}
+
+template struct ref_inner_product_bwd_data_t<f32, f32, f32, f32>;
+
+template <impl::data_type_t data_type>
+void ref_inner_product_bwd_weights_t<data_type>::execute_backward_weights(
+ const exec_ctx_t &ctx) const {
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_weights = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_WEIGHTS);
+ auto diff_bias = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_BIAS);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
+ const memory_desc_wrapper diff_bias_d(pd()->diff_weights_md(1));
+
+ const int MB = pd()->MB();
+ const int OC = pd()->OC();
+ const int IC = pd()->IC();
+
+    const bool src_has_spatial = utils::one_of(src_d.ndims(), 3, 4, 5);
+ const int ndims = src_d.ndims() - 2;
+
+ parallel_nd(OC, IC, [&](int oc, int ic) {
+ if (src_has_spatial) {
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ for (int kd = 0; kd < KD; ++kd) {
+ for (int kh = 0; kh < KH; ++kh) {
+ for (int kw = 0; kw < KW; ++kw) {
+ data_t *dw(nullptr);
+ switch (ndims) {
+ case 3:
+ dw = &diff_weights[diff_weights_d.off(
+ oc, ic, kd, kh, kw)];
+ break;
+ case 2:
+ dw = &diff_weights[diff_weights_d.off(
+ oc, ic, kh, kw)];
+ break;
+ case 1:
+ dw = &diff_weights[diff_weights_d.off(oc, ic, kw)];
+ break;
+ default: assert(!"unsupported ndims size");
+ }
+ *dw = data_t(0);
+ for (int mb = 0; mb < MB; ++mb) {
+ switch (ndims) {
+ case 3:
+ *dw += diff_dst[diff_dst_d.off(mb, oc)]
+ * src[src_d.off(mb, ic, kd, kh, kw)];
+ break;
+ case 2:
+ *dw += diff_dst[diff_dst_d.off(mb, oc)]
+ * src[src_d.off(mb, ic, kh, kw)];
+ break;
+ case 1:
+ *dw += diff_dst[diff_dst_d.off(mb, oc)]
+ * src[src_d.off(mb, ic, kw)];
+ break;
+ default: assert(!"unsupported ndims size");
+ }
+ }
+ }
+ }
+ }
+ } else {
+ data_t *dw = &diff_weights[diff_weights_d.off(oc, ic)];
+ *dw = data_t(0);
+ for (int mb = 0; mb < MB; ++mb) {
+ *dw += diff_dst[diff_dst_d.off(mb, oc)] *
+ src[src_d.off(mb, ic)];
+ }
+ }
+ });
+
+ if (diff_bias) {
+ diff_bias += diff_bias_d.offset0();
+
+ parallel_nd(OC, [&](int oc) {
+ data_t *db = &diff_bias[oc];
+ *db = data_t(0);
+ for (int mb = 0; mb < MB; ++mb)
+ *db += diff_dst[diff_dst_d.off(mb, oc)];
+ });
+ }
+}
+
+template struct ref_inner_product_bwd_weights_t<data_type::f32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp
new file mode 100644
index 0000000000..bf87dbd514
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_inner_product.hpp
@@ -0,0 +1,159 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_INNER_PRODUCT_HPP
+#define CPU_REF_INNER_PRODUCT_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_inner_product_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t src_type, impl::data_type_t wei_type = src_type,
+ impl::data_type_t dst_type = src_type,
+ impl::data_type_t acc_type = dst_type>
+struct ref_inner_product_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_inner_product_fwd_pd_t {
+ using cpu_inner_product_fwd_pd_t::cpu_inner_product_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_inner_product_fwd_t);
+
+ status_t init() {
+ using namespace data_type;
+
+ bool ok = true
+ && set_default_params() == status::success
+ && is_fwd()
+ && src_md()->data_type == src_type
+ && weights_md()->data_type == wei_type
+ && desc()->accum_data_type == acc_type
+ && dst_md()->data_type == dst_type
+ && IMPLICATION(with_bias(), utils::one_of(
+ weights_md(1)->data_type, f32, s32, s8, u8))
+ && attr()->output_scales_.has_default_values()
+ && attr()->post_ops_.len_ <= 1
+ && IMPLICATION(attr()->post_ops_.len_ == 1,
+ attr()->post_ops_.entry_[0].is_relu(true, false));
+ return ok ? status::success : status::unimplemented;
+ }
+ };
+
+ ref_inner_product_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<wei_type>::type wei_data_t;
+ typedef typename prec_traits<dst_type>::type dst_data_t;
+ typedef typename prec_traits<acc_type>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t diff_src_type, impl::data_type_t wei_type,
+ impl::data_type_t diff_dst_type,
+ impl::data_type_t acc_type = diff_src_type>
+struct ref_inner_product_bwd_data_t: public cpu_primitive_t {
+ struct pd_t: public cpu_inner_product_bwd_data_pd_t {
+ using cpu_inner_product_bwd_data_pd_t::cpu_inner_product_bwd_data_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_data_t);
+
+ status_t init() {
+ bool ok = true
+ && set_default_params() == status::success
+ && desc()->prop_kind == prop_kind::backward_data
+ && diff_src_md()->data_type == diff_src_type
+ && weights_md()->data_type == wei_type
+ && desc()->accum_data_type == acc_type
+ && diff_dst_md()->data_type == diff_dst_type
+ && attr()->has_default_values();
+ return ok ? status::success : status::unimplemented;
+ }
+ };
+
+ ref_inner_product_bwd_data_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ typedef typename prec_traits<diff_src_type>::type diff_src_data_t;
+ typedef typename prec_traits<wei_type>::type wei_data_t;
+ typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
+ typedef typename prec_traits<acc_type>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_data(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_data(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t data_type>
+struct ref_inner_product_bwd_weights_t: public cpu_primitive_t {
+ struct pd_t: public cpu_inner_product_bwd_weights_pd_t {
+ using cpu_inner_product_bwd_weights_pd_t::cpu_inner_product_bwd_weights_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_inner_product_bwd_weights_t);
+
+ status_t init() {
+ bool ok = true
+ && set_default_params() == status::success
+ && desc()->prop_kind == prop_kind::backward_weights
+ && utils::everyone_is(data_type,
+ src_md()->data_type,
+ diff_dst_md()->data_type,
+ diff_weights_md()->data_type)
+ && IMPLICATION(with_bias(),
+ data_type == diff_weights_md(1)->data_type)
+ && attr()->has_default_values();
+ return ok ? status::success : status::unimplemented;
+ }
+ };
+
+ ref_inner_product_bwd_weights_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward_weights(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_weights(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp
new file mode 100644
index 0000000000..325e97963b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.cpp
@@ -0,0 +1,252 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+
+#include "ref_lrn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+static inline float fast_negative_powf(float omega, float beta) {
+ float Y;
+/*
+ * Y = omega^(-3/4) =
+ * = 1.0f / sqrtf(omega) * sqrtf(1.0f / sqrtf(omega))
+ * = sqrtf(1.0f / sqrtf(omega)) * 1.0f / sqrtf(omega)
+ * = sqrtf(1.0f / sqrtf(omega)) / sqrtf(omega)
+ * = sqrtf(1.0f / sqrtf(omega) / omega)
+ * = sqrtf(1.0f / (sqrtf(omega) * omega))
+ */
+ if (beta == 0.75f) {
+ Y = sqrtf(1.0f / (sqrtf(omega) * omega));
+ } else {
+ Y = 1.0f / powf(omega, beta);
+ }
+ return Y;
+}
+
+template <impl::data_type_t data_type>
+template <impl::format_tag_t tag>
+void ref_lrn_fwd_t<data_type>::execute_forward(const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+ using namespace format_tag;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const size_t stride_mb = data_d.blocking_desc().strides[0];
+ const bool across_channels = pd()->desc()->alg_kind == lrn_across_channels;
+ constexpr int blksize = tag == nChw16c ? 16 : 8;
+
+ auto data_off = [&](int mb, int c, int h, int w) -> size_t {
+ switch (tag) {
+ case nChw16c:
+ case nChw8c: return mb * stride_mb + c / blksize * H * W * blksize
+ + h * W * blksize + w * blksize + c % blksize;
+ case nchw: return mb * stride_mb + c * H * W + h * W + w;
+ case nhwc: return mb * stride_mb + h * W * C + w * C + c;
+ default: return data_d.off(mb, c, h, w);
+ }
+ };
+
+ auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
+ const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
+ const float beta = static_cast<float>(pd()->desc()->lrn_beta);
+ const float k = static_cast<float>(pd()->desc()->lrn_k);
+
+ const int size = pd()->desc()->local_size;
+ const int half_size = (size - 1) / 2;
+
+ float sum = 0;
+ if (across_channels) {
+ const int c_st = nstl::max(oc - half_size + 0, 0);
+ const int c_en = nstl::min(oc + half_size + 1, C);
+
+ for (int c = c_st; c < c_en; ++c) {
+ const float s = src[data_off(mb, c, oh, ow)];
+ sum += s * s;
+ }
+ } else {
+ int h_st = nstl::max(oh - half_size + 0, 0);
+ int h_en = nstl::min(oh + half_size + 1, H);
+ int w_st = nstl::max(ow - half_size + 0, 0);
+ int w_en = nstl::min(ow + half_size + 1, W);
+ for (int h = h_st; h < h_en; ++h) {
+ for (int w = w_st; w < w_en; ++w) {
+ const float s = src[data_off(mb, oc, h, w)];
+ sum += s * s;
+ }
+ }
+ }
+ const int summands = across_channels ? size : size * size;
+ sum = k + alpha * sum / summands;
+ size_t off = data_off(mb, oc, oh, ow);
+ d[0] = static_cast<data_t>(src[off] * fast_negative_powf(sum, beta));
+ };
+
+ const int MB = pd()->MB();
+ if (tag == nChw16c || tag == nChw8c) {
+ parallel_nd(MB, utils::div_up(C, blksize), H, W,
+ [&](int mb, int c_blk, int h, int w) {
+ int c = c_blk * blksize;
+ const size_t off = mb * stride_mb + c * H * W
+ + (h * W + w) * blksize;
+ PRAGMA_OMP_SIMD()
+ for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc)
+ ker(&dst[off + cc], mb, c + cc, h, w);
+ });
+ } else if (tag == nhwc) {
+ parallel_nd(MB, H, W, C,
+ [&](int mb, int h, int w, int c) {
+ const size_t off = mb * stride_mb + h * W * C + w * C + c;
+ ker(&dst[off], mb, c, h, w);
+ });
+ } else {
+ parallel_nd(MB, C, H, W,
+ [&](int mb, int c, int h, int w) {
+ const size_t off = data_off(mb, c, h, w);
+ ker(&dst[off], mb, c, h, w);
+ });
+ }
+}
+
+template <impl::data_type_t data_type>
+template <mkldnn_format_tag_t tag>
+void ref_lrn_bwd_t<data_type>::execute_backward(const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+ using namespace format_tag;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+
+ const int MB = pd()->MB();
+ const int C = pd()->C();
+ const int H = pd()->H();
+ const int W = pd()->W();
+ const size_t stride_mb = data_d.blocking_desc().strides[0];
+ constexpr int blksize = tag == nChw16c ? 16 : 8;
+
+ const float alpha = static_cast<float>(pd()->desc()->lrn_alpha);
+ const float beta = static_cast<float>(pd()->desc()->lrn_beta);
+ const float k = static_cast<float>(pd()->desc()->lrn_k);
+ const int kernel_size = pd()->desc()->local_size;
+ const int half_ksize = (kernel_size - 1) / 2;
+
+ auto data_off = [&](int mb, int c, int h, int w) -> size_t {
+ switch (tag) {
+ case nChw16c:
+ case nChw8c: return mb * stride_mb + c/blksize * H * W * blksize
+ + h * W * blksize + w * blksize + c%blksize;
+ case nchw: return mb * stride_mb + c * H * W + h * W + w;
+ case nhwc: return mb * stride_mb + h * W * C + w * C + c;
+ default: return data_d.off(mb, c, h, w);
+ }
+ };
+
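+    // Per-element LRN gradient computed as A - B: A is the direct term
+    // diff_dst * omega^(-beta) at the centre channel, while B accumulates how
+    // this element contributes to the normalization sums of its neighbours.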
+ auto ker = [=](data_t *d, int mb, int oc, int oh, int ow) {
+ const int c_st = nstl::max(oc - half_ksize + 0, 0);
+ const int c_en = nstl::min(oc + half_ksize + 1, C);
+
+ float A = 0, B = 0, omega_mid = 0;
+ for (int c = c_st; c < c_en; c++) {
+ float sum = 0.0;
+ const int i_st = nstl::max(c - half_ksize, 0);
+ const int i_en = nstl::min(c + kernel_size - half_ksize, C);
+
+ for (int i = i_st; i < i_en; ++i) {
+ const float value = src[data_off(mb, i, oh, ow)];
+ sum += value * value;
+ }
+ const float omega = static_cast<float>(k + sum * alpha / kernel_size);
+ if (c == oc) omega_mid = omega;
+ float t = src[data_off(mb, c, oh, ow)]
+ * fast_negative_powf(omega, beta);
+ B += 1.0f / omega * t * diff_dst[data_off(mb, c, oh, ow)];
+ }
+
+ const size_t off = data_off(mb, oc, oh, ow);
+ A = fast_negative_powf(omega_mid, beta) * diff_dst[off];
+ B *= src[off];
+ B *= (2.0f * alpha * beta) / kernel_size;
+ *d = static_cast<data_t>(A - B); // final cast down to data_t
+ };
+
+ if (tag == nChw16c || tag == nChw8c) {
+ parallel_nd(MB, utils::div_up(C, blksize), H, W,
+ [&](int mb, int c_blk, int h, int w) {
+ int c = c_blk * blksize;
+ const size_t off = mb * stride_mb + c * H * W +
+ (h * W + w) * blksize;
+ PRAGMA_OMP_SIMD()
+ for (int cc = 0; cc < nstl::min(blksize, C - c); ++cc)
+ ker(&diff_src[off + cc], mb, c + cc, h, w);
+ });
+ } else if (tag == nhwc) {
+ parallel_nd(MB, H, W, C,
+ [&](int mb, int h, int w, int c) {
+ const size_t off = mb * stride_mb + h * W * C + w * C + c;
+ ker(&diff_src[off], mb, c, h, w);
+ });
+ } else {
+ parallel_nd(MB, C, H, W,
+ [&](int mb, int c, int h, int w) {
+ const size_t off = data_off(mb, c, h, w);
+ ker(&diff_src[off], mb, c, h, w);
+ });
+ }
+}
+
+template void ref_lrn_fwd_t<data_type::f32>::
+execute_forward<format_tag::nChw16c>(const exec_ctx_t &ctx) const;
+template void ref_lrn_fwd_t<data_type::f32>::
+execute_forward<format_tag::nChw8c>(const exec_ctx_t &ctx) const;
+template void ref_lrn_fwd_t<data_type::f32>::
+execute_forward<format_tag::nchw>(const exec_ctx_t &ctx) const;
+template void ref_lrn_fwd_t<data_type::f32>::
+execute_forward<format_tag::nhwc>(const exec_ctx_t &ctx) const;
+template void ref_lrn_fwd_t<data_type::f32>::
+execute_forward<format_tag::any>(const exec_ctx_t &ctx) const;
+template void ref_lrn_bwd_t<data_type::f32>::
+execute_backward<format_tag::nChw16c>(const exec_ctx_t &ctx) const;
+template void ref_lrn_bwd_t<data_type::f32>::
+execute_backward<format_tag::nChw8c>(const exec_ctx_t &ctx) const;
+template void ref_lrn_bwd_t<data_type::f32>::
+execute_backward<format_tag::nchw>(const exec_ctx_t &ctx) const;
+template void ref_lrn_bwd_t<data_type::f32>::
+execute_backward<format_tag::nhwc>(const exec_ctx_t &ctx) const;
+template void ref_lrn_bwd_t<data_type::f32>::
+execute_backward<format_tag::any>(const exec_ctx_t &ctx) const;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp
new file mode 100644
index 0000000000..f25cfb7fae
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_lrn.hpp
@@ -0,0 +1,136 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_LRN_HPP
+#define CPU_REF_LRN_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_lrn_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type>
+struct ref_lrn_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_lrn_fwd_pd_t {
+ using cpu_lrn_fwd_pd_t::cpu_lrn_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_lrn_fwd_t);
+
+ status_t init() {
+ using namespace format_tag;
+
+ bool ok = true
+ && is_fwd()
+ && src_md()->data_type == data_type
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ dat_tag_ = memory_desc_matches_one_of_tag(
+ *src_md(), nChw16c, nChw8c, nchw, nhwc);
+
+ return status::success;
+ }
+
+ format_tag_t dat_tag_;
+ };
+
+ ref_lrn_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ using namespace format_tag;
+ switch (pd()->dat_tag_) {
+ case nChw16c: execute_forward<nChw16c>(ctx); break;
+ case nChw8c: execute_forward<nChw8c>(ctx); break;
+ case nchw: execute_forward<nchw>(ctx); break;
+ case nhwc: execute_forward<nhwc>(ctx); break;
+ default: execute_forward<any>(ctx);
+ }
+ return status::success;
+ }
+
+private:
+ template<format_tag_t tag>
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t data_type>
+struct ref_lrn_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_lrn_bwd_pd_t {
+ using cpu_lrn_bwd_pd_t::cpu_lrn_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_lrn_bwd_t);
+
+ status_t init() {
+ using namespace format_tag;
+ using namespace alg_kind;
+
+ bool ok = true
+ && !is_fwd()
+ && utils::one_of(desc()->alg_kind, lrn_across_channels
+ /*, lrn_within_channel */) // not supported yet
+ && utils::everyone_is(data_type,
+ src_md()->data_type,
+ diff_src_md()->data_type)
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ dat_tag_ = memory_desc_matches_one_of_tag(
+ *src_md(), nChw16c, nChw8c, nchw, nhwc);
+
+ return status::success;
+ }
+
+ format_tag_t dat_tag_;
+ };
+
+ ref_lrn_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ using namespace format_tag;
+ switch (pd()->dat_tag_) {
+ case nChw16c: execute_backward<nChw16c>(ctx); break;
+ case nChw8c: execute_backward<nChw8c>(ctx); break;
+ case nchw: execute_backward<nchw>(ctx); break;
+ case nhwc: execute_backward<nhwc>(ctx); break;
+ default: execute_backward<any>(ctx);
+ }
+ return status::success;
+ }
+
+private:
+ template<format_tag_t tag>
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp
new file mode 100644
index 0000000000..65b934e123
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.cpp
@@ -0,0 +1,381 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+
+#include "ref_pooling.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t data_type, data_type_t acc_type>
+void ref_pooling_fwd_t<data_type, acc_type>::execute_forward(
+ const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+ using namespace prop_kind;
+
+ auto alg = pd()->desc()->alg_kind;
+
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+ auto ws = CTX_OUT_MEM(unsigned char *, MKLDNN_ARG_WORKSPACE);
+
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+ const memory_desc_wrapper ws_d(pd()->workspace_md());
+ const data_type_t ws_dt = ws ? ws_d.data_type() : data_type::undef;
+
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ const int SD = pd()->KSD();
+ const int SH = pd()->KSH();
+ const int SW = pd()->KSW();
+ const int padF = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+
+ const bool is_3d = pd()->desc()->src_desc.ndims == 5;
+
+ auto apply_offset = [=](int index, int offset) {
+ return (index > offset) ? index - offset : 0;
+ };
+
+ auto set_ws = [=](int mb, int oc, int od, int oh, int ow, int value) {
+ if (ws) {
+ assert(ws_dt == data_type::u8 || ws_dt == data_type::s32);
+ size_t offset = is_3d
+                ? ws_d.off(mb, oc, od, oh, ow) : ws_d.off(mb, oc, oh, ow);
+ if (ws_dt == data_type::u8) {
+ assert(0 <= value && value <= 255);
+ ws[offset] = value;
+ } else
+ reinterpret_cast<int *>(ws)[offset] = value;
+ }
+ };
+
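+    // Max pooling: when a workspace is present, remember the linear index of
+    // the winning position inside the kernel window for the backward pass.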
+ auto ker_max = [=](data_t *d, int mb, int oc, int oh, int ow) {
+ for (int kh = 0; kh < KH; ++kh) {
+ for (int kw = 0; kw < KW; ++kw) {
+ const int ih = oh * SH - padT + kh;
+ const int iw = ow * SW - padL + kw;
+
+ if (ih < 0 || ih >= IH) continue;
+ if (iw < 0 || iw >= IW) continue;
+
+ auto s = src[src_d.off(mb, oc, ih, iw)];
+ if (s > d[0]) {
+ d[0] = s;
+ set_ws(mb, oc, 1, oh, ow, kh*KW + kw);
+ }
+ }
+ }
+ };
+
+ auto ker_avg = [=](data_t *d, int mb, int oc, int oh, int ow) {
+ auto ih_start = apply_offset(oh*SH, padT);
+ auto iw_start = apply_offset(ow*SW, padL);
+ auto ih_end = nstl::min(oh*SH - padT + KH, IH);
+ auto iw_end = nstl::min(ow*SW - padL + KW, IW);
+
+ auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH
+ : (ih_end - ih_start)*(iw_end - iw_start);
+
+ acc_data_t dst = 0;
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+ for (int iw = iw_start; iw < iw_end; ++iw) {
+ dst += src[src_d.off(mb, oc, ih, iw)];
+ }
+ }
+
+ d[0] = math::out_round<data_t>((float)dst / num_summands);
+ };
+
+ auto ker_max_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) {
+ for (int kd = 0; kd < KD; ++kd) {
+ for (int kh = 0; kh < KH; ++kh) {
+ for (int kw = 0; kw < KW; ++kw) {
+ const int id = od * SD - padF + kd;
+ const int ih = oh * SH - padT + kh;
+ const int iw = ow * SW - padL + kw;
+
+ if (id < 0 || id >= ID) continue;
+ if (ih < 0 || ih >= IH) continue;
+ if (iw < 0 || iw >= IW) continue;
+
+ auto s = src[src_d.off(mb, oc, id, ih, iw)];
+ if (s > d[0]) {
+ d[0] = s;
+ set_ws(mb, oc, od, oh, ow, kd * KH * KW + kh*KW + kw);
+ }
+ }
+ }
+ }
+ };
+
+ auto ker_avg_3d = [=](data_t *d, int mb, int oc, int od, int oh, int ow) {
+ auto id_start = apply_offset(od*SD, padF);
+ auto ih_start = apply_offset(oh*SH, padT);
+ auto iw_start = apply_offset(ow*SW, padL);
+ auto id_end = nstl::min(od*SD - padF + KD, ID);
+ auto ih_end = nstl::min(oh*SH - padT + KH, IH);
+ auto iw_end = nstl::min(ow*SW - padL + KW, IW);
+
+ auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD
+ : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
+
+ acc_data_t dst = 0;
+ for (int id = id_start; id < id_end; ++id) {
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+ for (int iw = iw_start; iw < iw_end; ++iw) {
+ dst += src[src_d.off(mb, oc, id, ih, iw)];
+ }
+ }
+ }
+
+ d[0] = math::out_round<data_t>((float)dst / num_summands);
+ };
+
+ const int MB = pd()->MB();
+ const int OC = pd()->C();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+
+ if (alg == pooling_max) {
+ parallel_nd(MB, OC, OD, OH, OW,
+ [&](int mb, int oc, int od, int oh, int ow) {
+ data_t *d = is_3d
+ ? &dst[dst_d.off(mb, oc, od, oh, ow)]
+ : &dst[dst_d.off(mb, oc, oh, ow)];
+ d[0] = nstl::numeric_limits<data_t>::lowest();
+ set_ws(mb, oc, od, oh, ow, 0);
+ if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow);
+ else ker_max(d, mb, oc, oh, ow);
+ });
+ } else {
+ parallel_nd(MB, OC, OD, OH, OW,
+ [&](int mb, int oc, int od, int oh, int ow) {
+ data_t *d = is_3d
+ ? &dst[dst_d.off(mb, oc, od, oh, ow)]
+ : &dst[dst_d.off(mb, oc, oh, ow)];
+ d[0] = 0;
+ if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow);
+ else ker_avg(d, mb, oc, oh, ow);
+ });
+ }
+}
+
+template <data_type_t data_type, data_type_t acc_type>
+void ref_pooling_bwd_t<data_type, acc_type>::execute_backward(
+ const exec_ctx_t &ctx) const {
+ using namespace alg_kind;
+
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto ws = CTX_IN_MEM(const unsigned char *, MKLDNN_ARG_WORKSPACE);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
+ const memory_desc_wrapper ws_d(pd()->workspace_md());
+
+ const int ID = pd()->ID();
+ const int IH = pd()->IH();
+ const int IW = pd()->IW();
+ const int KD = pd()->KD();
+ const int KH = pd()->KH();
+ const int KW = pd()->KW();
+ const int SD = pd()->KSD();
+ const int SH = pd()->KSH();
+ const int SW = pd()->KSW();
+ const int padF = pd()->padFront();
+ const int padT = pd()->padT();
+ const int padL = pd()->padL();
+
+ const bool is_3d = pd()->desc()->diff_src_desc.ndims == 5;
+
+ auto alg = pd()->desc()->alg_kind;
+
+ auto apply_offset = [=](int index, int offset) {
+ return (index > offset) ? index - offset : 0;
+ };
+
+ auto ker_zero = [=](int _mb, int _oc) {
+ for (int ih = 0; ih < IH; ++ih) {
+ for (int iw = 0; iw < IW; ++iw) {
+                diff_src[diff_src_d.off(_mb, _oc, ih, iw)] = data_t(0);
+ }
+ }
+ };
+
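+    // Max pooling backward: decode the in-window index stored in the
+    // workspace and route the gradient only to the winning input element.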
+ auto ker_max = [=](const data_t *d, int mb, int oc, int oh, int ow) {
+ const size_t ws_off = ws_d.off(mb, oc, oh, ow);
+ const int index = ws_d.data_type() == data_type::u8
+ ? (int)ws[ws_off] : ((int *)ws)[ws_off];
+ const int kw = index % KW;
+ const int kh = index / KW;
+ const int ih = oh * SH - padT + kh;
+ const int iw = ow * SW - padL + kw;
+
+ // If padding area could fit the kernel,
+ // then input displacement would be out of bounds.
+ // No need to back propagate there as padding is
+ // virtual in pooling_max case.
+ if (ih < 0 || ih >= IH)
+ return;
+ if (iw < 0 || iw >= IW)
+ return;
+
+ diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0];
+ };
+
+ auto ker_avg = [=](const data_t *d, int mb, int oc, int oh, int ow) {
+ auto ih_start = apply_offset(oh*SH, padT);
+ auto iw_start = apply_offset(ow*SW, padL);
+ auto ih_end = nstl::min(oh*SH - padT + KH, IH);
+ auto iw_end = nstl::min(ow*SW - padL + KW, IW);
+
+ auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH
+ : (ih_end - ih_start)*(iw_end - iw_start);
+
+ for (int ih = ih_start; ih < ih_end; ++ih) {
+ for (int iw = iw_start; iw < iw_end; ++iw) {
+ diff_src[diff_src_d.off(mb, oc, ih, iw)] += d[0] / num_summands;
+ }
+ }
+ };
+
+ auto ker_zero_3d = [=](int _mb, int _oc) {
+ for (int id = 0; id < ID; ++id) {
+ for (int ih = 0; ih < IH; ++ih) {
+ for (int iw = 0; iw < IW; ++iw) {
+ diff_src[diff_src_d.off(_mb, _oc, id, ih, iw)] =
+                        data_t(0);
+ }
+ }
+ }
+ };
+
+ auto ker_max_3d = [=](const data_t *d, int mb, int oc, int od, int oh,
+ int ow) {
+ const size_t ws_off = ws_d.off(mb, oc, od, oh, ow);
+ const int index = ws_d.data_type() == data_type::u8
+ ? (int)ws[ws_off] : ((int *)ws)[ws_off];
+ const int kw = index % KW;
+ const int kh = (index / KW) % KH;
+ const int kd = (index / KW) / KH;
+ const int id = od * SD - padF + kd;
+ const int ih = oh * SH - padT + kh;
+ const int iw = ow * SW - padL + kw;
+
+ // If padding area could fit the kernel,
+ // then input displacement would be out of bounds.
+ // No need to back propagate there as padding is
+ // virtual in pooling_max case.
+ if (id < 0 || id >= ID)
+ return;
+ if (ih < 0 || ih >= IH)
+ return;
+ if (iw < 0 || iw >= IW)
+ return;
+
+ diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0];
+ };
+
+ auto ker_avg_3d = [=](const data_t *d, int mb, int oc, int od, int oh,
+ int ow) {
+ auto id_start = apply_offset(od*SD, padF);
+ auto ih_start = apply_offset(oh*SH, padT);
+ auto iw_start = apply_offset(ow*SW, padL);
+ auto id_end = nstl::min(od*SD - padF + KD, ID);
+ auto ih_end = nstl::min(oh*SH - padT + KH, IH);
+ auto iw_end = nstl::min(ow*SW - padL + KW, IW);
+
+ auto num_summands = (alg == pooling_avg_include_padding) ? KW*KH*KD
+ : (ih_end - ih_start)*(iw_end - iw_start)*(id_end - id_start);
+
+ for (int id = id_start; id < id_end; ++id)
+ for (int ih = ih_start; ih < ih_end; ++ih)
+ for (int iw = iw_start; iw < iw_end; ++iw) {
+ diff_src[diff_src_d.off(mb, oc, id, ih, iw)] += d[0] / num_summands;
+ }
+ };
+
+ const int MB = pd()->MB();
+ const int OC = pd()->C();
+ const int OD = pd()->OD();
+ const int OH = pd()->OH();
+ const int OW = pd()->OW();
+
+ if (pd()->desc()->alg_kind == alg_kind::pooling_max) {
+ parallel_nd(MB, OC, [&](int mb, int oc) {
+ if (is_3d) ker_zero_3d(mb, oc);
+ else ker_zero(mb, oc);
+ for (int od = 0; od < OD; ++od) {
+ for (int oh = 0; oh < OH; ++oh) {
+ for (int ow = 0; ow < OW; ++ow) {
+ const data_t *d = is_3d
+ ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)]
+ : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)];
+ if (is_3d) ker_max_3d(d, mb, oc, od, oh, ow);
+ else ker_max(d, mb, oc, oh, ow);
+ }
+ }
+ }
+ });
+ } else {
+ parallel_nd(MB, OC, [&](int mb, int oc) {
+ if (is_3d) ker_zero_3d(mb, oc);
+ else ker_zero(mb, oc);
+ for (int od = 0; od < OD; ++od) {
+ for (int oh = 0; oh < OH; ++oh) {
+ for (int ow = 0; ow < OW; ++ow) {
+ const data_t *d = is_3d
+ ? &diff_dst[diff_dst_d.off(mb, oc, od, oh, ow)]
+ : &diff_dst[diff_dst_d.off(mb, oc, oh, ow)];
+ if (is_3d) ker_avg_3d(d, mb, oc, od, oh, ow);
+ else ker_avg(d, mb, oc, oh, ow);
+ }
+ }
+ }
+ });
+ }
+}
+
+template struct ref_pooling_fwd_t<data_type::f32>;
+template struct ref_pooling_fwd_t<data_type::s32>;
+template struct ref_pooling_fwd_t<data_type::s8, data_type::s32>;
+template struct ref_pooling_fwd_t<data_type::u8, data_type::s32>;
+
+template struct ref_pooling_bwd_t<data_type::f32>;
+template struct ref_pooling_bwd_t<data_type::s32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp
new file mode 100644
index 0000000000..e43ceaa82b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_pooling.hpp
@@ -0,0 +1,119 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_POOLING_HPP
+#define CPU_REF_POOLING_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_pooling_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type, impl::data_type_t acc_type = data_type>
+struct ref_pooling_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_pooling_fwd_pd_t {
+ using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_pooling_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && set_default_params() == status::success
+ && is_fwd()
+ && utils::everyone_is(data_type, src_md()->data_type,
+ dst_md()->data_type)
+ && desc()->accum_data_type == acc_type
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ bool is_training = desc_.prop_kind == prop_kind::forward_training;
+ if (desc()->alg_kind == alg_kind::pooling_max && is_training)
+ init_default_ws();
+
+ return status::success;
+ }
+ };
+
+ ref_pooling_fwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ typedef typename prec_traits<data_type>::type data_t;
+ typedef typename prec_traits<acc_type>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_forward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <impl::data_type_t data_type, impl::data_type_t acc_type = data_type>
+struct ref_pooling_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_pooling_bwd_pd_t {
+ using cpu_pooling_bwd_pd_t::cpu_pooling_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_pooling_bwd_t);
+
+ status_t init() {
+ bool ok = true
+ && set_default_params() == status::success
+ && !is_fwd()
+ && utils::everyone_is(data_type, diff_dst_md()->data_type,
+ diff_src_md()->data_type)
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ if (desc()->alg_kind == alg_kind::pooling_max) {
+ init_default_ws();
+ if (!compare_ws(hint_fwd_pd_))
+ return status::unimplemented;
+ }
+
+ return status::success;
+ }
+ };
+
+ ref_pooling_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {}
+ typedef typename prec_traits<data_type>::type data_t;
+ typedef typename prec_traits<acc_type>::type acc_data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_backward(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp
new file mode 100644
index 0000000000..af27743110
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.cpp
@@ -0,0 +1,153 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+
+#include "ref_shuffle.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace format_tag;
+
+template <int data_type_size>
+template <mkldnn_format_tag_t tag>
+void ref_shuffle_t<data_type_size>::execute_(const exec_ctx_t &ctx) const {
+ using namespace prop_kind;
+ using namespace utils;
+
+ const memory_desc_wrapper data_d(pd()->data_md());
+
+ auto i_arg = pd()->is_fwd() ? MKLDNN_ARG_SRC : MKLDNN_ARG_DIFF_DST;
+ auto o_arg = pd()->is_fwd() ? MKLDNN_ARG_DST : MKLDNN_ARG_DIFF_SRC;
+ auto input = CTX_IN_MEM(const data_t *, i_arg);
+ auto output = CTX_OUT_MEM(data_t *, o_arg);
+
+ const int axis = pd()->axis();
+ const int axis_size = pd()->axis_size();
+
+ const int MB = pd()->MB();
+ const int C = pd()->C();
+ int H = 1, W = 1, D = 1, HW = 1, SP = 1;
+ const bool has_spatial = utils::one_of(data_d.ndims(), 3, 4, 5);
+ if (has_spatial)
+ {
+ D = pd()->D();
+ H = pd()->H();
+ W = pd()->W();
+ HW = H * W;
+ SP = D * HW;
+ }
+ const size_t stride_mb = data_d.blocking_desc().strides[0];
+ constexpr int blksize = one_of(tag, nChw16c, nCdhw16c) ? 16 : 8;
+
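+ // the fast paths below only cover shuffling along the channel axis
+ // (axis == 1); other axes or unmatched layouts fall through to the
+ // generic logical-offset loop at the end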
+ if (axis == 1 && one_of(tag, nChw16c, nChw8c, nCdhw16c, nCdhw8c)) {
+#if MKLDNN_THR == MKLDNN_THR_OMP
+# pragma omp parallel for collapse(3) schedule(static)
+ for (int mb = 0; mb < MB; ++mb)
+ for (int cb = 0; cb < C; cb += blksize)
+ for (int sp = 0; sp < SP; ++sp) {
+ const size_t off = mb * stride_mb + sp * blksize;
+ const size_t output_off = off + cb * SP;
+ PRAGMA_OMP_SIMD()
+ for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc)
+ {
+ int input_c = rev_transposed_[cb + cc];
+ const size_t input_off = off + input_c / blksize * SP * blksize
+ + input_c % blksize;
+ output[output_off + cc] = input[input_off];
+ }
+ }
+#else
+ parallel_nd(MB, utils::div_up(C, blksize), SP, [&](int mb, int c,
+ int sp) {
+ const size_t off = mb * stride_mb + sp * blksize;
+ const int cb = c * blksize;
+ const size_t output_off = off + cb * SP;
+ for (int cc = 0; cc < nstl::min(blksize, C - cb); ++cc)
+ {
+ int input_c = rev_transposed_[cb + cc];
+ const size_t input_off = off + input_c / blksize * SP * blksize
+ + input_c % blksize;
+ output[output_off + cc] = input[input_off];
+ }
+ });
+#endif
+ } else if (axis == 1 && one_of(tag, nhwc, ndhwc)) {
+ parallel_nd(MB, SP, [&](int mb, int sp) {
+ const size_t off = mb * stride_mb + sp * C;
+ PRAGMA_OMP_SIMD()
+ for (int c = 0; c < C; ++c)
+ output[off + c] = input[off + rev_transposed_[c]];
+ });
+ } else if (axis == 1 && one_of(tag, nchw, ncdhw)) {
+ parallel_nd(MB, C, [&](int mb, int c) {
+ const size_t output_off = mb * stride_mb + c * SP;
+ const size_t input_off = mb * stride_mb + rev_transposed_[c] * SP;
+ PRAGMA_OMP_SIMD()
+ for (int sp = 0; sp < SP; ++sp) {
+ output[output_off + sp] = input[input_off + sp];
+ }
+ });
+ } else {
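+ // generic path: decompose the tensor into outer x axis x inner chunks and
+ // address elements through logical offsets, which works for any layout at
+ // the cost of an off_l() call per element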
+ auto dims = pd()->desc()->data_desc.dims;
+ auto ndims = pd()->desc()->data_desc.ndims;
+ const size_t outer_size = utils::array_product(dims, axis);
+ const size_t inner_size = utils::array_product(dims + axis + 1,
+ ndims - axis - 1);
+ const size_t dim = axis_size * inner_size;
+
+ parallel_nd(outer_size, axis_size, inner_size, [&](size_t ou, int a,
+ size_t in)
+ {
+ const size_t off = ou * dim + in;
+ auto &o = output[data_d.off_l(off + a * inner_size)];
+ o = input[data_d.off_l(off + rev_transposed_[a] * inner_size)];
+ });
+ }
+}
+
+template void ref_shuffle_t<4>::execute_<nCdhw16c>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<4>::execute_<nChw16c>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<4>::execute_<nCdhw8c>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<4>::execute_<nChw8c>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<4>::execute_<ncdhw>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<4>::execute_<nchw>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<4>::execute_<ndhwc>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<4>::execute_<nhwc>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<4>::execute_<any>(const exec_ctx_t &ctx) const;
+
+template void ref_shuffle_t<1>::execute_<nCdhw16c>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<1>::execute_<nChw16c>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<1>::execute_<nCdhw8c>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<1>::execute_<nChw8c>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<1>::execute_<ncdhw>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<1>::execute_<nchw>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<1>::execute_<ndhwc>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<1>::execute_<nhwc>(const exec_ctx_t &ctx) const;
+template void ref_shuffle_t<1>::execute_<any>(const exec_ctx_t &ctx) const;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp
new file mode 100644
index 0000000000..5e09a1a69b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_shuffle.hpp
@@ -0,0 +1,111 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_SHUFFLE_HPP
+#define CPU_REF_SHUFFLE_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_shuffle_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template<int data_type_size>
+struct ref_shuffle_t : public cpu_primitive_t {
+ using shuffle_class = ref_shuffle_t<data_type_size>;
+
+ struct pd_t: public cpu_shuffle_pd_t {
+ using cpu_shuffle_pd_t::cpu_shuffle_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", shuffle_class);
+
+ status_t init() {
+ using namespace format_tag;
+
+ bool ok = true
+ && data_type_size
+ == types::data_type_size(data_md()->data_type);
+ if (!ok) return status::unimplemented;
+
+ if (ndims() == 5) {
+ dat_tag_ = memory_desc_matches_one_of_tag(
+ *data_md(), nCdhw16c, nCdhw8c, ncdhw, ndhwc);
+ } else if (ndims() == 4) {
+ dat_tag_ = memory_desc_matches_one_of_tag(
+ *data_md(), nChw16c, nChw8c, nchw, nhwc);
+ } else
+ dat_tag_ = any;
+
+ return status::success;
+ }
+
+ format_tag_t dat_tag_;
+ };
+
+ ref_shuffle_t(const pd_t *apd): cpu_primitive_t(apd) {
+ const int axis_size = pd()->axis_size();
+ const int group_size = pd()->group_size();
+ const int transpose_row = pd()->is_fwd() ? group_size
+ : axis_size / group_size;
+ const int transpose_col = pd()->is_fwd() ? axis_size / group_size
+ : group_size;
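+ // rev_transposed_ is the inverse channel permutation of the group
+ // transposition: for each destination channel it stores the source
+ // channel to read from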
+ rev_transposed_ = (int *)malloc(axis_size * sizeof(int), 64);
+ parallel_nd(transpose_col, transpose_row, [&](int i, int j) {
+ rev_transposed_[j * transpose_col + i] = i * transpose_row + j;
+ });
+ }
+
+ ~ref_shuffle_t() { free(rev_transposed_); }
+
+ typedef typename typesize_traits<data_type_size>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ using namespace format_tag;
+ switch (pd()->dat_tag_) {
+ case nCdhw16c: execute_<nCdhw16c>(ctx); break;
+ case nChw16c: execute_<nChw16c>(ctx); break;
+ case nCdhw8c: execute_<nCdhw8c>(ctx); break;
+ case nChw8c: execute_<nChw8c>(ctx); break;
+ case ncdhw: execute_<ncdhw>(ctx); break;
+ case nchw: execute_<nchw>(ctx); break;
+ case ndhwc: execute_<ndhwc>(ctx); break;
+ case nhwc: execute_<nhwc>(ctx); break;
+ default: execute_<any>(ctx); break;
+ }
+ return status::success;
+ }
+
+private:
+ template<format_tag_t tag>
+ void execute_(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ int *rev_transposed_;
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp
new file mode 100644
index 0000000000..36d5237f56
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.cpp
@@ -0,0 +1,264 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <float.h>
+#include <math.h>
+
+#include "c_types_map.hpp"
+#include "mkldnn_thread.hpp"
+#include "type_helpers.hpp"
+
+#include "ref_softmax.hpp"
+#include "gemm/os_blas.hpp"
+
+#ifdef USE_MKL
+#include "mkl_vml_functions.h"
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type>
+void ref_softmax_fwd_t<data_type>::execute_forward_dense(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
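+ // numerically stable softmax over each row of channels_ elements:
+ // subtract the row max, exponentiate, then normalize by the sum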
+ parallel_nd(outer_size_, [&](int ou) {
+ const data_t *src_data = src + ou * channels_;
+ data_t *dst_data = dst + ou * channels_;
+ data_t scalar = 0;
+
+ _max(channels_, src_data, &scalar);
+ _sub(channels_, scalar, src_data, dst_data);
+ _exp(channels_, dst_data, dst_data);
+ _sum(channels_, dst_data, &scalar);
+ _scal(channels_, data_t(1)/scalar, dst_data);
+ });
+}
+
+template <impl::data_type_t data_type>
+void ref_softmax_fwd_t<data_type>::execute_forward_generic(
+ const exec_ctx_t &ctx) const {
+ auto src = CTX_IN_MEM(const data_t *, MKLDNN_ARG_SRC);
+ auto dst = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ data_t space_max_val = 0, space_denom_val = 0;
+ data_t *space_max = &space_max_val, *space_denom = &space_denom_val;
+ if (inner_size_ > 1) {
+ using namespace memory_tracking::names;
+ space_max = scratchpad(ctx).template get<data_t>(key_softmax_reduction);
+ space_denom = space_max + inner_size_;
+ }
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+ const size_t dim = channels_ * inner_size_;
+
+ for (int ou = 0; ou < outer_size_; ou++) {
+ utils::array_set(space_max, -FLT_MAX, inner_size_);
+ utils::array_set(space_denom, 0, inner_size_);
+
+ for (int c = 0; c < channels_; c++) {
+ for (int in = 0; in < inner_size_; in++) {
+ size_t off = data_d.off_l(ou * dim + c * inner_size_ + in);
+ space_max[in] = nstl::max(space_max[in], src[off]);
+ }
+ }
+
+ for (int c = 0; c < channels_; c++) {
+ for (int in = 0; in < inner_size_; in++) {
+ size_t off = data_d.off_l(ou * dim + c * inner_size_ + in);
+ space_denom[in] += dst[off] = exp(src[off] - space_max[in]);
+ }
+ }
+
+ for (int c = 0; c < channels_; c++) {
+ for (int in = 0; in < inner_size_; in++) {
+ size_t off = data_d.off_l(ou * dim + c * inner_size_ + in);
+ dst[off] /= space_denom[in];
+ }
+ }
+ }
+}
+
+template <impl::data_type_t data_type>
+void ref_softmax_fwd_t<data_type>::_max(int n, const data_t *x,
+ data_t *max_data) const {
+// The Intel(R) C++ Compiler already generates the maxps + shuffle pattern
+// for the max search, which is faster
+#if !defined(__INTEL_COMPILER)
+ // The code below makes the compiler generate the maxps instruction
+ // rather than maxss, which is what the 'else' code path produces
+ auto max_wrapper = [](data_t a, data_t b) { return nstl::max(a, b); };
+ auto min_wrapper = [](int a, int b) { return nstl::min(a, b); };
+
+ constexpr int unroll_factor = 32;
+ data_t max_values[unroll_factor];
+
+ if (n < unroll_factor) {
+ data_t max_val = x[0];
+ for (int i = 1; i < n; i++) {
+ max_val = max_wrapper(max_val, x[i]);
+ }
+ max_data[0] = max_val;
+ return;
+ }
+ for (int i = 0; i < unroll_factor; i++) {
+ max_values[i] = x[i];
+ }
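+ // process the remainder in chunks of unroll_factor; the last chunk is
+ // shifted back to n - unroll_factor, so a few elements may be revisited,
+ // which is harmless for a max reduction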
+ for (int i = unroll_factor; i < n; i += unroll_factor) {
+ int offset = min_wrapper(i, n - unroll_factor);
+ for (int j = 0; j < unroll_factor; j++) {
+ max_values[j] = max_wrapper(max_values[j], x[offset + j]);
+ }
+ }
+ data_t max_val = max_values[0];
+ for (int i = 1; i < unroll_factor; i++) {
+ max_val = max_wrapper(max_val, max_values[i]);
+ }
+ max_data[0] = max_val;
+#else
+ max_data[0] = x[0];
+ for (int c = 1; c < n; ++c)
+ max_data[0] = nstl::max(max_data[0], x[c]);
+#endif
+}
+
+template <impl::data_type_t data_type>
+void ref_softmax_fwd_t<data_type>::_sub(int n, data_t alpha, const data_t *x,
+ data_t *y) const {
+ constexpr int unroll_factor = 32;
+ int tail = n % unroll_factor;
+ for (int i = 0; i < n - tail; i += unroll_factor) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < unroll_factor; j++) {
+ y[i + j] = x[i + j] - alpha;
+ }
+ }
+ PRAGMA_OMP_SIMD()
+ for (int i = n - tail; i < n; i++) {
+ y[i] = x[i] - alpha;
+ }
+}
+
+template <impl::data_type_t data_type>
+void ref_softmax_fwd_t<data_type>::_exp(int n, const data_t *a,
+ data_t *r) const {
+#ifdef USE_MKL
+ if (data_type == data_type::f32) {
+ vsExp(n, a, r);
+ return;
+ }
+#endif
+ parallel_nd(n, [&](int c) { r[c] = expf(a[c]); });
+}
+
+template <impl::data_type_t data_type>
+void ref_softmax_fwd_t<data_type>::_sum(int n, const data_t *x,
+ data_t *sum_data) const {
+#ifdef USE_CBLAS
+ // Here we are summing the values of x (e.g. e^z), which are positive,
+ // so we can use BLAS ASUM
+ if (data_type == data_type::f32) {
+ sum_data[0] = cblas_sasum(n, x, 1);
+ return;
+ }
+#endif
+ data_t tsum = static_cast<data_t>(0);
+ PRAGMA_OMP_SIMD(reduction(+ : tsum))
+ for (int c = 0; c < n; ++c)
+ tsum += x[c];
+ sum_data[0] = tsum;
+}
+
+template <impl::data_type_t data_type>
+void ref_softmax_fwd_t<data_type>::_scal(int n, data_t alpha, data_t *x) const {
+#ifdef USE_CBLAS
+ if (data_type == data_type::f32) {
+ cblas_sscal(n, alpha, x, 1);
+ return;
+ }
+#endif
+ parallel_nd(n, [&](int c) { x[c] *= alpha; });
+}
+
+template struct ref_softmax_fwd_t<data_type::f32>;
+
+
+// NC/NCHW softmax along the final axis (axis 1 for NC, axis 3 for NCHW)
+template <impl::data_type_t data_type>
+void ref_softmax_bwd_t<data_type>::execute_backward_dense(
+ const exec_ctx_t &ctx) const {
+ auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
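+ // dense softmax backward:
+ // diff_src_c = dst_c * (diff_dst_c - sum_k(diff_dst_k * dst_k))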
+ parallel_nd(outer_size_, [&](int ou) {
+ data_t sbr = 0;
+ size_t off = channels_*ou;
+ for (int c = 0; c < channels_; c++) {
+ size_t loff = off + c;
+ data_t ldata = dst[loff];
+ sbr += diff_dst[loff]*ldata;
+ diff_src[loff] = ldata;
+ }
+
+ for (int c = 0; c < channels_; ++c) {
+ size_t loff = off + c;
+ diff_src[loff] *= (diff_dst[loff] - sbr);
+ }
+ });
+}
+
+template <impl::data_type_t data_type>
+void ref_softmax_bwd_t<data_type>::execute_backward_generic(
+ const exec_ctx_t &ctx) const {
+ auto dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DST);
+ auto diff_dst = CTX_IN_MEM(const data_t *, MKLDNN_ARG_DIFF_DST);
+ auto diff_src = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DIFF_SRC);
+
+ const memory_desc_wrapper diff_d(pd()->diff_src_md());
+ const memory_desc_wrapper data_d(pd()->dst_md());
+
+ const size_t dim = channels_ * inner_size_;
+
+ parallel_nd(outer_size_, [&](int ou) {
+ for (int in = 0; in < inner_size_; in++) {
+ data_t sbr = 0;
+ for (int c = 0; c < channels_; c++) {
+ size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in);
+ size_t off_data = data_d.off_l(ou * dim + c * inner_size_ + in);
+ sbr += diff_dst[off_diff] * dst[off_data];
+ }
+
+ for (int c = 0; c < channels_; ++c) {
+ size_t off_diff = diff_d.off_l(ou * dim + c * inner_size_ + in);
+ size_t off_data = data_d.off_l(ou * dim + c * inner_size_ + in);
+ diff_src[off_diff] = dst[off_data] * (diff_dst[off_diff] - sbr);
+ }
+ }
+ });
+}
+
+template struct ref_softmax_bwd_t<data_type::f32>;
+
+}
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp
new file mode 100644
index 0000000000..5cb74d8007
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_softmax.hpp
@@ -0,0 +1,186 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_SOFTMAX_HPP
+#define CPU_REF_SOFTMAX_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "cpu_softmax_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <impl::data_type_t data_type>
+struct ref_softmax_fwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_softmax_fwd_pd_t {
+ using cpu_softmax_fwd_pd_t::cpu_softmax_fwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_softmax_fwd_t);
+
+ status_t init() {
+ bool ok = true
+ && is_fwd()
+ && src_md()->data_type == data_type
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ private:
+ void init_scratchpad() {
+ const int inner_size = utils::array_product(
+ desc()->data_desc.dims + desc()->softmax_axis + 1,
+ desc()->data_desc.ndims - desc()->softmax_axis - 1);
+
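+ // the generic (strided) path keeps a running max and a running
+ // denominator per inner position, hence 2 * inner_size scratchpad elements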
+ if (inner_size > 1) {
+ auto scratchpad = scratchpad_registry().registrar();
+ scratchpad.book(memory_tracking::names::key_softmax_reduction,
+ sizeof(data_t) * 2 * inner_size);
+ }
+ }
+ };
+
+ ref_softmax_fwd_t(const pd_t *apd): cpu_primitive_t(apd)
+ {
+ auto ndims = pd()->desc()->data_desc.ndims;
+ auto dims = pd()->desc()->data_desc.dims;
+ auto axis = pd()->desc()->softmax_axis;
+
+ outer_size_ = utils::array_product(dims, axis);
+ channels_ = dims[axis];
+ inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1);
+
+ const memory_desc_wrapper data_d(pd()->src_md());
+
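+ // the dense path requires the softmax axis to be the innermost,
+ // unit-stride dimension of a dense layout with no blocking over it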
+ bool no_axis_blocking = true;
+ for (int iblk = 0; iblk < data_d.blocking_desc().inner_nblks; ++iblk)
+ if (data_d.blocking_desc().inner_idxs[iblk] == axis)
+ no_axis_blocking = false;
+
+ use_dense_ = inner_size_ == 1 && data_d.is_dense()
+ && no_axis_blocking
+ && data_d.blocking_desc().strides[axis] == 1;
+ }
+
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ if (use_dense_)
+ execute_forward_dense(ctx);
+ else
+ execute_forward_generic(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_forward_dense(const exec_ctx_t &ctx) const;
+ void execute_forward_generic(const exec_ctx_t &ctx) const;
+
+ void _max(int n, const data_t *x, data_t *max_data) const;
+ void _sub(int n, data_t alpha, const data_t *x, data_t *y) const;
+ void _exp(int n, const data_t *a, data_t *r) const;
+ void _sum(int n, const data_t *x, data_t *sum_data) const;
+ void _scal(int n, data_t alpha, data_t *x) const;
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ bool use_dense_;
+ int outer_size_, channels_, inner_size_;
+};
+
+template <impl::data_type_t data_type>
+struct ref_softmax_bwd_t: public cpu_primitive_t {
+ struct pd_t: public cpu_softmax_bwd_pd_t {
+ using cpu_softmax_bwd_pd_t::cpu_softmax_bwd_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", ref_softmax_bwd_t);
+
+ status_t init() {
+ bool ok = true
+ && !is_fwd()
+ && utils::everyone_is(data_type,
+ dst_md()->data_type,
+ diff_src_md()->data_type)
+ && attr()->has_default_values();
+ if (!ok) return status::unimplemented;
+
+ return status::success;
+ }
+ };
+
+ ref_softmax_bwd_t(const pd_t *apd): cpu_primitive_t(apd) {
+ auto dims = pd()->desc()->diff_desc.dims;
+ auto axis = pd()->desc()->softmax_axis;
+ auto ndims = pd()->desc()->diff_desc.ndims;
+
+ outer_size_ = utils::array_product(dims, axis);
+ channels_ = dims[axis];
+ inner_size_ = utils::array_product(dims + axis + 1, ndims - axis - 1);
+
+ const memory_desc_wrapper data_d(pd()->dst_md());
+ const memory_desc_wrapper diff_d(pd()->diff_dst_md());
+
+ bool no_axis_blocking = true;
+ for (int iblk = 0; iblk < diff_d.blocking_desc().inner_nblks; ++iblk)
+ if (diff_d.blocking_desc().inner_idxs[iblk] == axis)
+ no_axis_blocking = false;
+
+ use_dense_ = true
+ && inner_size_ == 1
+ && diff_d == data_d
+ && diff_d.is_dense()
+ && no_axis_blocking
+ && diff_d.blocking_desc().strides[axis] == 1;
+ }
+
+ typedef typename prec_traits<data_type>::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ if (use_dense_)
+ execute_backward_dense(ctx);
+ else
+ execute_backward_generic(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_backward_dense(const exec_ctx_t &ctx) const;
+ void execute_backward_generic(const exec_ctx_t &ctx) const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ bool use_dense_;
+ int outer_size_, channels_, inner_size_;
+};
+
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp
new file mode 100644
index 0000000000..3b2a75d99b
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/ref_sum.hpp
@@ -0,0 +1,101 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef REF_SUM_HPP
+#define REF_SUM_HPP
+
+#include "reorder_pd.hpp"
+
+#include "cpu_sum_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct ref_sum_t: public cpu_primitive_t {
+ struct pd_t: public cpu_sum_pd_t {
+ using cpu_sum_pd_t::cpu_sum_pd_t;
+
+ pd_t(const pd_t &rhs): cpu_sum_pd_t(rhs) {
+ for (size_t i = 0; i < rhs.reorder_pds_.size(); ++i)
+ reorder_pds_.push_back(
+ (const reorder_pd_t *)rhs.reorder_pds_[i]->clone());
+ }
+
+ ~pd_t() { for (auto &rpd: reorder_pds_) delete rpd; }
+
+ DECLARE_SUM_PD_T("ref:any", ref_sum_t);
+
+ status_t init() {
+ bool ok = cpu_sum_pd_t::init() == status::success;
+ if (!ok) return status::unimplemented;
+
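+ // build one reorder per input: input i is reordered into dst scaled by
+ // scales_[i]; every input after the first carries a sum post-op so it
+ // accumulates into dst instead of overwriting it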
+ for (int i = 0; i < n_; ++i) {
+ auto r_impls = engine_->get_reorder_implementation_list();
+ for (auto r = r_impls; *r; ++r) {
+ primitive_attr_t attr;
+ attr.output_scales_.set(scales_[i]);
+ if (i != 0) attr.post_ops_.append_sum(1.0);
+
+ reorder_pd_t *r_pd;
+ if ((*r)(&r_pd, engine_, &attr, engine_, src_md(i),
+ engine_, dst_md()) == status::success) {
+ r_pd->init_info();
+ reorder_pds_.push_back(r_pd);
+ break;
+ }
+ }
+ }
+
+ ok = reorder_pds_.size() == (size_t)n_;
+ return ok ? status::success : status::unimplemented;
+ }
+
+ nstl::vector<const reorder_pd_t *> reorder_pds_;
+ };
+
+ ref_sum_t(const pd_t *apd): cpu_primitive_t(apd) {
+ const int n = pd()->n_inputs();
+ reorders_.resize(n);
+ for (int i = 0; i < n; ++i)
+ pd()->reorder_pds_[i]->create_primitive(&reorders_[i]);
+ }
+
+ ~ref_sum_t() { for (auto &r: reorders_) delete r; }
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ const auto n = pd()->n_inputs();
+ for (int i = 0; i < n; ++i) {
+ exec_args_t r_args;
+ r_args[MKLDNN_ARG_SRC] = ctx.args().at(MKLDNN_ARG_MULTIPLE_SRC + i);
+ r_args[MKLDNN_ARG_DST] = ctx.args().at(MKLDNN_ARG_DST);
+ exec_ctx_t r_ctx(ctx.stream(), std::move(r_args));
+ reorders_[i]->execute(r_ctx);
+ }
+ return status::success;
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ nstl::vector<primitive_t *> reorders_;
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
new file mode 100644
index 0000000000..537084db91
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_common.cpp
@@ -0,0 +1,90 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Common for RNN and LSTM cell execution
+ */
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+using namespace rnn_utils;
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_cell_execution_sig(
+ (_ref_rnn_common_t<aprop, src_type, weights_type>::cell_execution)) {
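+ // ws_gates = W_layer * x_t (skipped when rnn.merge_gemm_layer is set),
+ // then accumulated with W_iter * h_{t-1} (beta = 1.0); the postgemm /
+ // elementwise step applies the bias and the cell non-linearities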
+ if (!rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
+ rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
+ states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
+ rnn.gates_ws_ld);
+ }
+ (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic,
+ 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
+ rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
+
+ if (rnn_postgemm_ != nullptr)
+ rnn_postgemm_->execute<src_data_t, acc_data_t>(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+ else
+ (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+}
+template rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
+template rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
+
+template <>
+rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution) {
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+
+ /// bwd by data on the cell
+ (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic,
+ 1.0, w_iter_[0], rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld,
+ 0.0, diff_states_t_l_, rnn.states_ws_ld);
+
+ if (!rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
+ rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
+ rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
+ &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld);
+
+ /// bwd by weights on the cell
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
+ diff_w_layer_, rnn.diff_weights_layer_ld);
+ }
+
+ if (!rnn.merge_gemm_iter)
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0,
+ diff_w_iter_, rnn.diff_weights_iter_ld);
+
+ /// bwd by bias we just accumulate diffs from the gates
+ gates_reduction(rnn, ws_gates_, diff_bias_);
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp
new file mode 100644
index 0000000000..e1a61d4c62
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru.cpp
@@ -0,0 +1,180 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Cell execution GRU
+ */
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace rnn_utils;
+
+#define AOC array_offset_calculator
+template <>
+rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_[0]);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+
+ // 1. gemm Wx[0-2],x
+ if (!rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
+ rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
+ states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
+ rnn.gates_ws_ld);
+ }
+
+ // 2. gemm Wh[0-1],h
+ (this->*gemm_iter_func)('N', 'N', (rnn.n_gates - 1) * rnn.dic, rnn.mb,
+ rnn.sic, 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
+ rnn.states_ws_ld, 1.0, ws_gates_, rnn.gates_ws_ld);
+
+ // 3. activation zt and rt + elemwise multiplication rt,ht-1
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
+ ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
+ states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 1, j);
+ }
+ });
+
+ // 4. gemm Wh[2],h~t
+ (this->*gemm_iter_func)('N', 'N', rnn.dic, rnn.mb, rnn.sic, 1.0, w_iter_[1],
+ rnn.weights_iter_ld, states_t_l_, rnn.states_ws_ld, 1.0,
+ &(ws_gates(0, 2, 0)), rnn.gates_ws_ld);
+
+ // 5. activation h~t + calculate ht
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
+ states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j)
+ + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
+ }
+ });
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru) {
+ assert(!"GRU int8 is not supported");
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+ ws_diff_w_iter_aoc_t diff_w_iter(rnn, diff_w_iter_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
+ ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
+
+ // use state memory for intermediate computations
+ // TODO: use cell ws for that
+ float *dhG1_ = &(diff_states_t_l(rnn.n_states, 0, 0));
+ float *hG1_ = dhG1_;
+ AOC<float, 2> dhG1(dhG1_, rnn.states_nld, rnn.states_ws_ld);
+ AOC<float, 2> hG1(hG1_, rnn.states_nld, rnn.states_ws_ld);
+
+ // 1. calculate dG2, dG0, and part of dht-1
+ // dG2^ = dh * (1 - G0) * (1 - G2^2)
+ // dG0^ = dh * (ht-1 - G2) * u * (1 - G0)
+ // dht-1 (part) = dh * G0
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float h = states_tm1_l(i, j);
+ float dHt = diff_states_tp1_l(0, i, j)
+ + diff_states_t_lp1(rnn.n_states, i, j);
+ float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt
+ * one_m_square(ws_gates(i, 2, j));
+ float dG0 = (h - ws_gates(i, 2, j)) * dHt
+ * x_m_square(ws_gates(i, 0, j));
+
+ diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j);
+ ws_gates(i, 0, j) = dG0;
+ ws_gates(i, 2, j) = dG2;
+ }
+ });
+
+ // 2. calculate intermediate d(hG1)
+ // d(hG1) = dG2 * W2h^t
+ (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.dic, 1.0, w_iter_[1],
+ rnn.weights_iter_ld, &(ws_gates(0, 2, 0)), rnn.gates_ws_ld, 0.0,
+ dhG1_, rnn.states_ws_ld);
+
+ // 3. calculate dG1^ and part of dht-1
+ // dG1^ = d(hG1) * h * G1 * (1 - G1)
+ // dht-1 (part) += d(hG1) * G1
+ // h * G1 (required for dWh)
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float h = states_tm1_l(i, j);
+ float G1 = ws_gates(i, 1, j);
+ diff_states_t_l(0, i, j) += dhG1(i, j) * G1;
+ ws_gates(i, 1, j) = dhG1(i, j) * h * x_m_square(G1);
+ hG1(i, j) = G1 * h;
+ }
+ });
+
+ // 4. calculate diff weights
+ // dWh1 += dG1 * h, dWh2 += dG2 * h, dWh3 += dG3 * (G1(*)h)
+ gemm('N', 'T', (rnn.n_gates - 1) * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_,
+ rnn.diff_weights_iter_ld);
+ gemm('N', 'T', rnn.dic, rnn.sic, rnn.mb, 1.0, &(ws_gates(0, 2, 0)),
+ rnn.gates_ws_ld, hG1_, rnn.states_ws_ld, 1.0,
+ &(diff_w_iter(0, 2, 0)), rnn.diff_weights_iter_ld);
+
+ // 5. calculate diff states
+ // dht-1 += dG1 * W1h + dG0 * W0h
+ (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb,
+ (rnn.n_gates - 1) * rnn.dic, 1.0, w_iter_[0],
+ rnn.weights_iter_ld, ws_gates_, rnn.gates_ws_ld, 1.0,
+ diff_states_t_l_, rnn.states_ws_ld);
+
+ if (!rnn.merge_gemm_layer) {
+ // dWx += [dG0 dG1 dG2] * [x]
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
+ diff_w_layer_, rnn.diff_weights_layer_ld);
+ // dx = dG2 * W2x + dG1 * W1x + dG0 * W0x
+ (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
+ rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
+ rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
+ &(diff_states_t_l(rnn.n_states, 0, 0)), rnn.states_ws_ld);
+ }
+
+ // 6. calculate diff bias
+ gates_reduction(rnn, ws_gates_, diff_bias_);
+}
+#undef AOC
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp
new file mode 100644
index 0000000000..8dea8c90a4
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_gru_lbr.cpp
@@ -0,0 +1,170 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Cell execution GRU with linear before reset
+ */
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace rnn_utils;
+#define AOC array_offset_calculator
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+ ws_gates_aoc_t ws_gemm_state(rnn, ws_cell_);
+ AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic);
+
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float Wh_b = ws_gemm_state(i, 2, j) + bias(3, j);
+ ws_gates(i, 0, j) = logistic_fwd(
+ ws_gates(i, 0, j) + ws_gemm_state(i, 0, j) + bias(0, j));
+ ws_gates(i, 1, j) = logistic_fwd(
+ ws_gates(i, 1, j) + ws_gemm_state(i, 1, j) + bias(1, j));
+ ws_gates(i, 2, j) = tanh_fwd(
+ ws_gates(i, 2, j) + ws_gates(i, 1, j) * Wh_b + bias(2, j));
+ states_t_l(i, j) = states_tm1_l(i, j) * ws_gates(i, 0, j)
+ + (1.0f - ws_gates(i, 0, j)) * ws_gates(i, 2, j);
+ if (rnn.is_training)
+ ws_Wh_b(i, j) = Wh_b;
+ }
+ });
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise) {
+ assert(!"GRU LBR int8 is not supported");
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr) {
+ if (!rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb,
+ rnn.slc, 1.0, w_layer_[0], rnn.weights_layer_ld,
+ states_t_lm1_, rnn.states_ws_ld, 0.0, ws_gates_,
+ rnn.gates_ws_ld);
+ }
+ (this->*gemm_iter_func)('N', 'N', rnn.n_gates * rnn.dic, rnn.mb, rnn.sic,
+ 1.0, w_iter_[0], rnn.weights_iter_ld, states_tm1_l_,
+ rnn.states_ws_ld, 0.0, ws_cell_, rnn.gates_ws_ld);
+ (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr) {
+ assert(!"GRU LBR int8 is not supported");
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
+ ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
+ ws_gates_aoc_t ws_gates_r(rnn, ws_cell_);
+ AOC<float, 2> ws_Wh_b(ws_grid_, rnn.mb, rnn.dic);
+
+ // 1. calculate dG0, dG1, dG2
+ // dG0 = (h - G2) * dht * (1 - G0) * G0
+ // dG1 = (W*h + b) * dG2 * (1 - G1) * G1
+ // dG2 = (1 - G0) * dht * (1 - G2*G2)
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float h = states_tm1_l(i, j);
+ float dHt = diff_states_tp1_l(0, i, j)
+ + diff_states_t_lp1(rnn.n_states, i, j);
+ float dG0 = (h - ws_gates(i, 2, j)) * dHt
+ * x_m_square(ws_gates(i, 0, j));
+ float dG2 = (1.0f - ws_gates(i, 0, j))
+ * one_m_square(ws_gates(i, 2, j)) * dHt;
+ float dG1 = ws_Wh_b(i, j) * dG2 * x_m_square(ws_gates(i, 1, j));
+
+ diff_states_t_l(0, i, j) = dHt * ws_gates(i, 0, j);
+ ws_gates(i, 2, j) = dG2;
+ ws_gates_r(i, 2, j) = dG2 * ws_gates(i, 1, j);
+ ws_gates(i, 0, j) = ws_gates_r(i, 0, j) = dG0;
+ ws_gates(i, 1, j) = ws_gates_r(i, 1, j) = dG1;
+ }
+ });
+}
+
+template <>
+rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr) {
+ ws_gates_aoc_t ws_gates_r(rnn, ws_cell_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+
+ (this->*elemwise_func)(rnn, ws_gates_, states_t_l_, c_states_t_l_,
+ states_tm1_l_, c_states_tm1_l_, diff_states_t_l_,
+ diff_states_t_lp1_, diff_states_tp1_l_, bias_[0], ws_grid_,
+ ws_cell_);
+
+ if (!rnn.merge_gemm_layer) {
+ // dx = dG * Wx^t
+ (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb,
+ rnn.n_gates * rnn.dic, 1.0, w_layer_[0],
+ rnn.weights_layer_ld, ws_gates_, rnn.gates_ws_ld, 0.0,
+ &diff_states_t_l(rnn.n_states, 0, 0), rnn.states_ws_ld);
+ // dWx += dG^t * x
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc, rnn.mb, 1.0, ws_gates_,
+ rnn.gates_ws_ld, states_t_lm1_, rnn.states_ws_ld, 1.0,
+ diff_w_layer_, rnn.diff_weights_layer_ld);
+ }
+ // dh += dGr * Wh^t
+ (this->*gemm_iter_func)('N', 'N', rnn.sic, rnn.mb, rnn.n_gates * rnn.dic,
+ 1.0, w_iter_[0], rnn.weights_iter_ld, ws_cell_, rnn.gates_ws_ld,
+ 1.0, diff_states_t_l_, rnn.states_ws_ld);
+
+ // dWh += dGr^t * h
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic, rnn.mb, 1.0, ws_cell_,
+ rnn.gates_ws_ld, states_tm1_l_, rnn.states_ws_ld, 1.0, diff_w_iter_,
+ rnn.diff_weights_iter_ld);
+
+ // db1-3 += e * dG
+ // db4 += e * (r * dG2)
+ gates_reduction(rnn, ws_gates_, diff_bias_);
+
+ parallel_nd(rnn.dic, [&](int j) {
+ for (int i = 0; i < rnn.mb; i++) {
+ diff_bias_[3 * rnn.dic + j] += ws_gates_r(i, 2, j);
+ }
+ });
+}
+
+#undef AOC
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp
new file mode 100644
index 0000000000..a15ba00d4c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_lstm.cpp
@@ -0,0 +1,143 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Cell execution LSTM
+ */
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "../simple_q10n.hpp"
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace rnn_utils;
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
+ ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
+
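+ // gate 0 is the input gate, 1 the forget gate, 2 the candidate and 3 the
+ // output gate: c_t = G1 * c_{t-1} + G0 * G2 and h_t = G3 * tanh(c_t)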
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ ws_gates(i, 0, j) = logistic_fwd(ws_gates(i, 0, j) + bias(0, j));
+ ws_gates(i, 1, j) = logistic_fwd(ws_gates(i, 1, j) + bias(1, j));
+ ws_gates(i, 2, j) = tanh_fwd(ws_gates(i, 2, j) + bias(2, j));
+ ws_gates(i, 3, j) = logistic_fwd(ws_gates(i, 3, j) + bias(3, j));
+
+ float tmp = ws_gates(i, 1, j) * c_states_tm1_l(i, j)
+ + ws_gates(i, 0, j) * ws_gates(i, 2, j);
+ states_t_l(i, j) = ws_gates(i, 3, j) * tanh_fwd(tmp);
+ c_states_t_l(i, j) = tmp;
+ }
+ });
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise) {
+ ws_gates_aoc_s32_t ws_gates_s32(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_u8_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
+ ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
+
+ float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
+ float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
+ float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
+
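+ // q_d re-quantizes an f32 state back to the u8 data domain using the data
+ // scale and shift; deq_w converts an s32 accumulator to f32 by dividing
+ // out the weights and data scales (per-channel when the mask is non-zero)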
+ auto q_d = [&](float f) {
+ float qf = f * data_scale + data_shift;
+ return qz_a1b0<float, src_data_t>()(qf);
+ };
+
+ auto deq_w = [&](acc_data_t s, int gate, int j) {
+ return pd()->attr()->rnn_weights_qparams_.mask_ == 0 ?
+ saturate<float>(s) * (1.f / (weights_scales[0] * data_scale)) :
+ saturate<float>(s) * (1.f / (weights_scales[gate * rnn.dic + j]
+ * data_scale));
+ };
+
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float G0 = logistic_fwd<float>(
+ deq_w(ws_gates_s32(i, 0, j), 0, j) + bias(0, j));
+ float G1 = logistic_fwd<float>(
+ deq_w(ws_gates_s32(i, 1, j), 1, j) + bias(1, j));
+ float G2 = tanh_fwd<float>(
+ deq_w(ws_gates_s32(i, 2, j), 2, j) + bias(2, j));
+ float G3 = logistic_fwd<float>(
+ deq_w(ws_gates_s32(i, 3, j), 3, j) + bias(3, j));
+ float tmp = G1 * c_states_tm1_l(i, j) + G0 * G2;
+ states_t_l(i, j) = q_d(G3 * tanh_fwd(tmp));
+ c_states_t_l(i, j) = tmp;
+ }
+ });
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
+ ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
+ ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
+
+ parallel_nd(rnn.mb, [&](int i) {
+ PRAGMA_OMP_SIMD()
+ for (int j = 0; j < rnn.dic; j++) {
+ float Ct = c_states_t_l(i, j);
+ /// @todo save it in the workspace during the fwd pass, or recompute it
+ /// here, to save bandwidth
+ float tanhCt = tanh_fwd(Ct);
+ // we have 2 incoming diffs on Ht
+ float dHt = diff_states_tp1_l(0, i, j)
+ + diff_states_t_lp1(rnn.n_states, i, j);
+ float dCt = diff_states_tp1_l(1, i, j)
+ + one_m_square(tanhCt) * ws_gates(i, 3, j) * dHt;
+
+ float dG1 = c_states_tm1_l(i, j) * dCt
+ * x_m_square(ws_gates(i, 1, j));
+ float dG0 = ws_gates(i, 2, j) * dCt * x_m_square(ws_gates(i, 0, j));
+ float dG3 = tanhCt * dHt * x_m_square(ws_gates(i, 3, j));
+ float dG2
+ = ws_gates(i, 0, j) * dCt * one_m_square(ws_gates(i, 2, j));
+
+ diff_states_t_l(1, i, j) = dCt * ws_gates(i, 1, j);
+
+ ws_gates(i, 0, j) = dG0;
+ ws_gates(i, 1, j) = dG1;
+ ws_gates(i, 2, j) = dG2;
+ ws_gates(i, 3, j) = dG3;
+ }
+ });
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp
new file mode 100644
index 0000000000..4536e8dfad
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cell_rnn.cpp
@@ -0,0 +1,113 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * Cell execution of Vanilla RNN
+ */
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::math;
+using namespace rnn_utils;
+
+template <>
+float activation<alg_kind::eltwise_relu, prop_kind::forward>(
+ float dd, float s, float alpha, float cliping) {
+ return relu_fwd<float>(s, alpha);
+}
+
+template <>
+float activation<alg_kind::eltwise_relu, prop_kind::backward>(
+ float dd, float s, float alpha, float cliping) {
+ return relu_bwd<float>(dd, s, alpha);
+}
+
+template <>
+float activation<alg_kind::eltwise_tanh, prop_kind::forward>(
+ float dd, float s, float alpha, float cliping) {
+ return tanh_fwd<float>(s);
+}
+
+template <>
+float activation<alg_kind::eltwise_tanh, prop_kind::backward>(
+ float dd, float s, float alpha, float cliping) {
+ return dd * one_m_square<float>(s);
+}
+
+template <>
+float activation<alg_kind::eltwise_logistic, prop_kind::forward>(
+ float dd, float s, float alpha, float cliping) {
+ return logistic_fwd<float>(s);
+}
+
+template <>
+float activation<alg_kind::eltwise_logistic, prop_kind::backward>(
+ float dd, float s, float alpha, float cliping) {
+ return dd * x_m_square<float>(s);
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+
+ parallel_nd(rnn.mb, [&](int i) {
+ for (int j = 0; j < rnn.dic; j++) {
+ const float h
+ = activation_func(0, ws_gates(i, 0, j) + bias(0, j), 0, 0);
+ ws_gates(i, 0, j) = states_t_l(i, j) = h;
+ }
+ });
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise) {
+ assert(!"VANILLA RNN int8 is not supported");
+}
+
+template <>
+rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise) {
+ ws_gates_aoc_t ws_gates(rnn, ws_gates_);
+ bias_aoc_t bias(rnn, bias_);
+ ws_states_aoc_t states_t_l(rnn, states_t_l_);
+ ws_states_aoc_t states_tm1_l(rnn, states_tm1_l_);
+ ws_diff_states_aoc_t diff_states_t_l(rnn, diff_states_t_l_);
+ ws_diff_states_aoc_t diff_states_tp1_l(rnn, diff_states_tp1_l_);
+ ws_diff_states_aoc_t diff_states_t_lp1(rnn, diff_states_t_lp1_);
+
+ parallel_nd(rnn.mb, [&](int i) {
+ for (int j = 0; j < rnn.dic; ++j) {
+ const float dH = diff_states_t_lp1(rnn.n_states, i, j)
+ + diff_states_tp1_l(0, i, j);
+ auto g = ws_gates(i, 0, j);
+ ws_gates(i, 0, j) = activation_func(dH, g, 0, 0);
+ }
+ });
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp
new file mode 100644
index 0000000000..b39427caf9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/cpu_rnn_pd.hpp
@@ -0,0 +1,191 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_RNN_PD_HPP
+#define CPU_RNN_PD_HPP
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "rnn_pd.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "rnn_utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct cpu_rnn_fwd_pd_t : public rnn_fwd_pd_t {
+ using rnn_fwd_pd_t::rnn_fwd_pd_t;
+
+protected:
+ status_t set_default_params() {
+ using namespace format_tag;
+ if (src_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(src_layer_md_, tnc));
+ if (dst_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc));
+
+ // Optional parameters
+ if (with_src_iter() && src_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc));
+ if (with_bias() && bias_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(bias_md_, ldgo));
+ if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc));
+
+ return status::success;
+ }
+
+ status_t check_layout_consistency() {
+ using namespace format_tag;
+ using namespace data_type;
+ using namespace types;
+
+ auto is_blocked = [&](memory_desc_t md, int ndims) {
+ return md.format_kind == format_kind::blocked && md.ndims == ndims;
+ };
+
+ bool ok = true;
+ ok = ok && is_blocked(src_layer_md_, 3)
+ && is_blocked(dst_layer_md_, 3);
+ ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_),
+ is_blocked(src_iter_md_, 5))
+ && IMPLICATION(!is_zero_md(&dst_iter_md_),
+ is_blocked(dst_iter_md_, 5));
+
+ if (weights_layer_md_.format_kind == format_kind::rnn_packed)
+ ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format
+ == mkldnn_ldigo_p);
+ else
+ ok = ok && rnn_utils::is_ldigo(&weights_layer_md_);
+
+ if (weights_iter_md_.format_kind == format_kind::rnn_packed)
+ ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format
+ == mkldnn_ldigo_p);
+ else
+ ok = ok && rnn_utils::is_ldigo(&weights_iter_md_);
+
+ ok = ok && IMPLICATION(!is_zero_md(&bias_md_),
+ memory_desc_matches_tag(bias_md_, ldgo));
+
+ /* Int8 is supported only for packed weights */
+ data_type_t weights_iter_dt = weights_iter_md_.data_type;
+ data_type_t weights_layer_dt = weights_layer_md_.data_type;
+ ok = ok && IMPLICATION(
+ weights_iter_dt == s8, weights_iter_md_.format_kind
+ == format_kind::rnn_packed);
+ ok = ok && IMPLICATION(
+ weights_layer_dt == s8, weights_layer_md_.format_kind
+ == format_kind::rnn_packed);
+
+ return ok ? status::success : status::unimplemented;
+ }
+};
+
+struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
+ using rnn_bwd_pd_t::rnn_bwd_pd_t;
+
+protected:
+ status_t set_default_params() {
+ using namespace format_tag;
+ if (src_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(src_layer_md_, tnc));
+ if (dst_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_layer_md_, tnc));
+
+ if (diff_src_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_src_layer_md_, tnc));
+ if (diff_weights_layer_md_.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(diff_weights_layer_md_, ldigo));
+ CHECK(rnn_utils::set_good_strides(diff_weights_layer_md_, ldigo));
+ }
+ if (diff_weights_iter_md_.format_kind == format_kind::any) {
+ CHECK(memory_desc_init_by_tag(diff_weights_iter_md_, ldigo));
+ CHECK(rnn_utils::set_good_strides(diff_weights_iter_md_, ldigo));
+ }
+ if (diff_dst_layer_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_dst_layer_md_, tnc));
+
+ // Optional parameters
+ if (with_src_iter() && src_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(src_iter_md_, ldsnc));
+ if (with_bias() && bias_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(bias_md_, ldgo));
+ if (with_dst_iter() && dst_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(dst_iter_md_, ldsnc));
+
+ if (with_src_iter() && diff_src_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_src_iter_md_, ldsnc));
+ if (with_bias() && diff_bias_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_bias_md_, ldgo));
+ if (with_dst_iter() && diff_dst_iter_md_.format_kind == format_kind::any)
+ CHECK(memory_desc_init_by_tag(diff_dst_iter_md_, ldsnc));
+
+ return status::success;
+ }
+
+ status_t check_layout_consistency() {
+ using namespace format_tag;
+ using namespace types;
+
+ auto is_blocked = [&](memory_desc_t md, int ndims) {
+ return md.format_kind == format_kind::blocked && md.ndims == ndims;
+ };
+
+ bool ok = true;
+ ok = ok && is_blocked(src_layer_md_, 3)
+ && is_blocked(dst_layer_md_, 3);
+ ok = ok && IMPLICATION(!is_zero_md(&src_iter_md_),
+ is_blocked(src_iter_md_, 5))
+ && IMPLICATION(!is_zero_md(&dst_iter_md_),
+ is_blocked(dst_iter_md_, 5));
+
+ if (weights_layer_md_.format_kind == format_kind::rnn_packed)
+ ok = ok && (weights_layer_md_.format_desc.rnn_packed_desc.format
+ == mkldnn_ldgoi_p);
+ else
+ ok = ok && rnn_utils::is_ldgoi(&weights_layer_md_);
+
+ if (weights_iter_md_.format_kind == format_kind::rnn_packed)
+ ok = ok && (weights_iter_md_.format_desc.rnn_packed_desc.format
+ == mkldnn_ldgoi_p);
+ else
+ ok = ok && rnn_utils::is_ldgoi(&weights_iter_md_);
+
+ ok = ok && IMPLICATION(!is_zero_md(&bias_md_),
+ memory_desc_matches_tag(bias_md_, ldgo));
+
+ ok = ok && is_blocked(diff_src_layer_md_, 3)
+ && is_blocked(diff_dst_layer_md_, 3);
+ ok = ok && IMPLICATION(!is_zero_md(&diff_src_iter_md_),
+ is_blocked(diff_src_iter_md_, 5))
+ && IMPLICATION(!is_zero_md(&diff_dst_iter_md_),
+ is_blocked(diff_dst_iter_md_, 5));
+
+ ok = ok && rnn_utils::is_ldigo(&diff_weights_layer_md_)
+ && rnn_utils::is_ldigo(&diff_weights_iter_md_);
+ ok = ok && IMPLICATION(!is_zero_md(&diff_bias_md_),
+ memory_desc_matches_tag(diff_bias_md_, ldgo));
+
+ return ok ? status::success : status::unimplemented;
+ }
+};
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp
new file mode 100644
index 0000000000..09445648aa
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/jit_uni_rnn_postgemm.hpp
@@ -0,0 +1,401 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ * LSTM cell execution
+ */
+
+#include "rnn_utils.hpp"
+#include "../jit_generator.hpp"
+#include "../jit_uni_eltwise.hpp"
+#include "c_types_map.hpp"
+#include "utils.hpp"
+
+#include "mkldnn_thread.hpp"
+
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+struct jit_uni_rnn_postgemm_kernel : public jit_generator {
+
+ typedef void (*kernel_t)(void *gates_, const void *bias, void *states_t_l_,
+ void *c_states_t_l_, void *c_states_tm1_l_);
+
+ jit_uni_rnn_postgemm_kernel(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr): rnn_(rnn), attr_(attr){}
+
+ virtual void init() = 0;
+
+template <typename src_data_t, typename acc_data_t>
+ rnn_elemwise_sig(execute) {
+ rnn_utils::ws_gates_aoc<acc_data_t> ws_gates(rnn, ws_gates_);
+ rnn_utils::bias_aoc_t bias(rnn, bias_);
+ rnn_utils::ws_states_aoc<src_data_t> states_t_l(rnn, states_t_l_);
+ rnn_utils::ws_states_aoc_t c_states_t_l(rnn, c_states_t_l_);
+ rnn_utils::ws_states_aoc_t c_states_tm1_l(rnn, c_states_tm1_l_);
+
+ // Todo: add parallelization on dic for the batch 1 case
+ // Assumption: the kernel runs a loop on dic elements
+ parallel_nd(rnn.mb, [&](int i) {
+ auto b_ = &bias(0, 0);
+ auto g_ = &ws_gates(i, 0, 0);
+ auto s_tl_ = &states_t_l(i, 0);
+ auto c_tl_ = &c_states_t_l(i, 0);
+ auto c_tm1l_ = &c_states_tm1_l(i, 0);
+ kernel_(g_, b_, s_tl_, c_tm1l_, c_tl_);
+ });
+ }
+
+protected:
+ kernel_t kernel_;
+ const rnn_utils::rnn_conf_t &rnn_;
+ const primitive_attr_t *attr_;
+};
+
+template <cpu_isa_t isa, impl::data_type_t src_data_t>
+struct jit_uni_lstm_postgemm_kernel_fwd: public jit_uni_rnn_postgemm_kernel
+{
+ DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_lstm_postgemm_kernel_fwd)
+
+ typedef typename utils::conditional<src_data_t == data_type::u8, int32_t,
+ float>::type acc_data_t;
+ typedef typename utils::conditional<isa == avx512_core,
+ jit_uni_eltwise_injector_f32<avx512_common>,
+ jit_uni_eltwise_injector_f32<isa>>::type injector_t;
+
+ jit_uni_lstm_postgemm_kernel_fwd(const rnn_utils::rnn_conf_t &rnn, const primitive_attr_t *attr)
+ : jit_uni_rnn_postgemm_kernel(rnn, attr){}
+
+ void init() override {
+ // we use rax for both injectors since they share the same constant table
+ sigmoid_injector_ = new injector_t(this,
+ alg_kind::eltwise_logistic, 0.0f, 0.0f, true, rax);
+ tanh_injector_ = new injector_t(this,
+ alg_kind::eltwise_tanh, 0.0f, 0.0f, true, rax);
+ generate();
+ kernel_ = (kernel_t) this->getCode();
+ }
+
+protected:
+ injector_t *sigmoid_injector_;
+ injector_t *tanh_injector_;
+
+ // sizes in bytes: vlen is the vector register width; vlen_dst is the number
+ // of bytes written to dst per register (u8 outputs are 4x narrower than f32)
+ using Vmm = typename jit_uni_eltwise_injector_f32<isa>::Vmm;
+ size_t vlen = cpu_isa_traits<isa>::vlen;
+ size_t vlen_dst = (src_data_t == data_type::u8) ? vlen/4 : vlen;
+ size_t cstate_dt_size = sizeof(float);
+ size_t hstate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint8_t) : sizeof(float);
+ size_t gate_dt_size = (src_data_t == data_type::u8) ? sizeof(uint32_t) : sizeof(float);
+ size_t qscale_dt_size = sizeof(float);
+ size_t bias_dt_size = sizeof(float);
+
+ void generate() {
+ using namespace Xbyak;
+
+ int mask = attr_->rnn_weights_qparams_.mask_;
+ float *weights_scales = attr_->rnn_weights_qparams_.scales_;
+ float data_scale = attr_->rnn_data_qparams_.scale_;
+ float data_shift = attr_->rnn_data_qparams_.shift_;
+
+ // Labels declaration
+ Label vector_loop_start_label, vector_loop_end_label;
+ Label rem_loop_start_label, rem_loop_end_label;
+ Label table_label;
+
+ // Register map
+ Reg64 loop_cnt(r11); // loop counter
+ Reg64 table_reg(rbx); // table is used for data scale and shifts
+ Reg64 weights_scales_reg(r13);
+ // We skip vmm0 as it can be used by the injector for masks on sse4.2
+ Vmm G0(1), G1(2), G2(3), G3(4), tmp1_vmm(5), tmp2_vmm(6), zero_vmm(7);
+
+ // constant table map
+ Address dscale_off_addr = ptr[table_reg];
+ Address dshift_off_addr = ptr[table_reg + vlen];
+ Address ymm_perm_mask_addr = ptr[table_reg + 2*vlen];
+ Address zmm_perm_mask_addr = ptr[table_reg + 2*vlen + cpu_isa_traits<avx>::vlen];
+
+ // quantize from float to u8
+ auto q_d = [&](Vmm f, Vmm tmp_vmm) {
+ uni_vpxor(tmp_vmm, tmp_vmm, tmp_vmm);
+ uni_vmulps(f, f, dscale_off_addr); // apply scale
+ uni_vaddps(f, f, dshift_off_addr); // apply shift
+ uni_vcvtps2dq(f, f); // convert to int32
+ uni_vpackssdw(f, f, tmp_vmm); // convert from s32 to s16
+ uni_vpackuswb(f, f, tmp_vmm); // convert from s16 to u8 with saturation
+ // Note that the results are interleaved in 128-bit chunks, so we need to merge them back together
+ switch (vlen) {
+ case 64: { //avx512
+ Zmm fz(f.getIdx()), tmpz(tmp_vmm.getIdx());
+ uni_vmovups(tmpz, zmm_perm_mask_addr);
+ vpermd(fz, tmpz, fz);
+ break; }
+ case 32: { //avx
+ Ymm fy(f.getIdx()), tmpy(tmp_vmm.getIdx());
+ uni_vmovups(tmpy, ymm_perm_mask_addr);
+ vpermd(fy, tmpy, fy);
+ break; }
+ case 16: // sse: nothing to do
+ break;
+ default: assert(!"Unsupported case");
+ };
+ };
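+ // Illustrative scalar sketch of what q_d computes per element, assuming the
+ // same scale/shift parameters (the helper name and nearbyintf below are
+ // placeholders, not part of this kernel):
+ //
+ //   uint8_t quantize_u8(float f, float scale, float shift) {
+ //       float q = f * scale + shift;                        // affine quantization
+ //       int i = (int)nearbyintf(q);                         // round to nearest
+ //       return (uint8_t)(i < 0 ? 0 : (i > 255 ? 255 : i));  // saturate to u8
+ //   }
+ //
+ // The vector version additionally has to undo the 128-bit lane interleaving
+ // introduced by vpackssdw/vpackuswb, which is what the vpermd above does.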
+
+ auto fast_recip =[&](Vmm s, Vmm tmp, bool packed) {
+ if (packed)
+ uni_vrcpps(tmp, s);
+ else
+ uni_vrcpss(tmp, s); // prevent divide by zero
+ // we add one Newton iteration
+ uni_vmulps(s, s, tmp);
+ uni_vmulps(s, s, tmp); // s <- s * tmp^2
+ uni_vaddps(tmp, tmp, tmp);
+ uni_vsubps(tmp, tmp, s);
+ uni_vmovups(s, tmp); // s <- 2 * tmp - s * tmp^2
+ };
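+ // For reference: fast_recip refines the ~12-bit accurate rcpps/rcpss
+ // estimate x0 ~= 1/s with one Newton-Raphson step,
+ //   x1 = x0 * (2 - s * x0) = 2 * x0 - s * x0^2,
+ // which is exactly the mul/add/sub sequence above and roughly doubles the
+ // number of accurate bits.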
+
+ // dequantize from s32 to float
+ auto deq_w = [&](Vmm s, Vmm tmp1, Vmm tmp2, int gate, bool packed) {
+ // TODO: if mask is 0 precompute mul and inverse
+ if (mask == 0)
+ uni_vbroadcastss(tmp1, ptr[weights_scales_reg]);
+ else
+ uni_vmovups(tmp1, ptr[weights_scales_reg + gate * rnn_.dic * qscale_dt_size]);
+ uni_vcvtdq2ps(s, s);
+ uni_vmulps(tmp1, tmp1, dscale_off_addr);
+ fast_recip(tmp1, tmp2, packed);
+ uni_vmulps(s, s, tmp1);
+ };
+
+ // We start code generation here
+ preamble();
+
+ // extract addresses passed as parameter
+#ifdef _WIN32
+ auto addr_ws_gates_reg = abi_param1;
+ auto addr_bias_reg = abi_param2;
+ auto addr_states_t_l_reg = abi_param3;
+ auto addr_c_states_tm1_l_reg = abi_param4;
+ auto addr_c_states_t_l_reg = r10;
+ // Here we cannot rely on rbp holding the initial stack pointer, so we
+ // use rsp and offset it by the size of the registers pushed in the
+ // preamble
+ mov(addr_c_states_t_l_reg, ptr[rsp + get_size_of_abi_save_regs() + 40]);
+#else
+ auto addr_ws_gates_reg = abi_param1;
+ auto addr_bias_reg = abi_param2;
+ auto addr_states_t_l_reg = abi_param3;
+ auto addr_c_states_tm1_l_reg = abi_param4;
+ auto addr_c_states_t_l_reg = abi_param5;
+#endif
+
+ // initialize registers with addresses and constants
+ mov(table_reg, table_label);
+ mov(weights_scales_reg, size_t(weights_scales));
+ // both sigmoid and tanh use the same table, so we load its address just once into rax
+ sigmoid_injector_->load_table_addr();
+
+ mov(loop_cnt, rnn_.dic * gate_dt_size);
+ cmp(loop_cnt, vlen);
+ jl(vector_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
+
+ L(vector_loop_start_label);
+ {
+ // load G0 G1 G2 G3
+ uni_vmovups(G0, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
+ uni_vmovups(G1, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
+ uni_vmovups(G2, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
+ uni_vmovups(G3, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);
+
+ // dequantize the gates from s32 to f32 if needed
+ if (src_data_t == data_type::u8){
+ deq_w(G0, tmp1_vmm, tmp2_vmm, 0, true);
+ deq_w(G1, tmp1_vmm, tmp2_vmm, 1, true);
+ deq_w(G2, tmp1_vmm, tmp2_vmm, 2, true);
+ deq_w(G3, tmp1_vmm, tmp2_vmm, 3, true);
+ }
+
+ // add biases
+ uni_vaddps(G0, G0, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G1, G1, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G2, G2, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G3, G3, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
+
+ // inject eltwise code
+ sigmoid_injector_->compute_vector(G0.getIdx());
+ sigmoid_injector_->compute_vector(G1.getIdx());
+ tanh_injector_->compute_vector(G2.getIdx());
+ sigmoid_injector_->compute_vector(G3.getIdx());
+
+ // compute c_states_t_l = G1 * c_tm1_l + G0 * G2
+ uni_vmovups(tmp1_vmm, ptr[addr_c_states_tm1_l_reg]);
+ uni_vmulps(tmp1_vmm, tmp1_vmm, G1);
+ uni_vfmadd231ps(tmp1_vmm, G0, G2);
+ uni_vmovups(ptr[addr_c_states_t_l_reg], tmp1_vmm);
+
+ // states_t_l = G3 * tanh(c_states_t_l)
+ tanh_injector_->compute_vector(tmp1_vmm.getIdx());
+ uni_vmulps(tmp1_vmm, tmp1_vmm, G3);
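+ // For reference, with the usual LSTM naming (gate order as implied by the
+ // injectors above):
+ //   i = sigmoid(G0), f = sigmoid(G1), g = tanh(G2), o = sigmoid(G3)
+ //   c_t = f * c_{t-1} + i * g
+ //   h_t = o * tanh(c_t)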
+
+ // if int8, we quantize the resulting state
+ if (src_data_t == data_type::u8)
+ q_d(tmp1_vmm, tmp2_vmm);
+
+ // write back the result
+ if(vlen_dst == vlen)
+ uni_vmovups(ptr[addr_states_t_l_reg], tmp1_vmm);
+ else
+ // we write only 1/4 of the register
+ switch(vlen_dst){
+ case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
+ case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
+ case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1_vmm.getIdx())); break;
+ default:
+ assert(!"Unsuported vector length for quantization");
+ }
+
+ // increment address pointers
+ add(addr_ws_gates_reg, vlen);
+ add(addr_bias_reg, vlen);
+ add(addr_states_t_l_reg, vlen_dst);
+ add(addr_c_states_tm1_l_reg, vlen);
+ add(addr_c_states_t_l_reg, vlen);
+ if (mask != 0)
+ add(weights_scales_reg, vlen);
+
+ // increment loop counter
+ sub(loop_cnt, vlen);
+ cmp(loop_cnt, vlen);
+ jge(vector_loop_start_label);
+ }
+ L(vector_loop_end_label);
+
+ cmp(loop_cnt, 0);
+ je(rem_loop_end_label, Xbyak::CodeGenerator::T_NEAR);
+ // Same code as above; we just use movss for accessing inputs
+ // TODO: smarter handling of tails with Zmm -> Ymm -> Xmm -> scalar
+ L(rem_loop_start_label);
+ {
+ // remapping registers to Xmms
+ Xmm G0s(G0.getIdx()), G1s(G1.getIdx()), G2s(G2.getIdx()), G3s(G3.getIdx());
+ Xmm tmp1s_vmm(tmp1_vmm.getIdx());
+
+ // load G0 G1 G2 G3
+ uni_vmovss(G0s, ptr[addr_ws_gates_reg + 0 * rnn_.dic * gate_dt_size]);
+ uni_vmovss(G1s, ptr[addr_ws_gates_reg + 1 * rnn_.dic * gate_dt_size]);
+ uni_vmovss(G2s, ptr[addr_ws_gates_reg + 2 * rnn_.dic * gate_dt_size]);
+ uni_vmovss(G3s, ptr[addr_ws_gates_reg + 3 * rnn_.dic * gate_dt_size]);
+
+ // dequantize the gates from s32 to f32 if needed
+ if (src_data_t == data_type::u8){
+ deq_w(G0, tmp1_vmm, tmp2_vmm, 0, false);
+ deq_w(G1, tmp1_vmm, tmp2_vmm, 1, false);
+ deq_w(G2, tmp1_vmm, tmp2_vmm, 2, false);
+ deq_w(G3, tmp1_vmm, tmp2_vmm, 3, false);
+ }
+
+ // add biases
+ uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 0 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G0s, G0s, tmp1s_vmm);
+ uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 1 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G1s, G1s, tmp1s_vmm);
+ uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 2 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G2s, G2s, tmp1s_vmm);
+ uni_vmovss(tmp1s_vmm, ptr[addr_bias_reg + 3 * rnn_.dic * bias_dt_size]);
+ uni_vaddps(G3s, G3s, tmp1s_vmm);
+
+ // inject eltwise code
+ sigmoid_injector_->compute_vector(G0s.getIdx());
+ sigmoid_injector_->compute_vector(G1s.getIdx());
+ tanh_injector_->compute_vector(G2s.getIdx());
+ sigmoid_injector_->compute_vector(G3s.getIdx());
+
+ // compute c_states_t_l = G1s * c_tm1_l + G0s * G2s
+ uni_vmovups(tmp1s_vmm, ptr[addr_c_states_tm1_l_reg]);
+ uni_vmulps(tmp1s_vmm, tmp1s_vmm, G1s);
+ uni_vfmadd231ps(tmp1s_vmm, G0s, G2s);
+ uni_vmovss(ptr[addr_c_states_t_l_reg], tmp1s_vmm);
+
+ // states_t_l = G3 * tanh(c_states_t_l)
+ tanh_injector_->compute_vector(tmp1s_vmm.getIdx());
+ uni_vmulps(tmp1s_vmm, tmp1s_vmm, G3s);
+
+ // if int8, we quantize the resulting state
+ if (src_data_t == data_type::u8)
+ q_d(tmp1_vmm, tmp2_vmm);
+
+ // write back the result
+ if(vlen_dst == vlen)
+ uni_vmovups(ptr[addr_states_t_l_reg], tmp1s_vmm);
+ else
+ // we write only 1/4 of the register
+ switch(vlen_dst){
+ case 16: uni_vmovups(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
+ case 8: uni_vmovsd(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
+ case 4: uni_vmovss(ptr[addr_states_t_l_reg], Xmm(tmp1s_vmm.getIdx())); break;
+ default:
+ assert(!"Unsuported vector length for quantization");
+ }
+
+ // increment address pointers
+ add(addr_ws_gates_reg, gate_dt_size);
+ add(addr_bias_reg, bias_dt_size);
+ add(addr_states_t_l_reg, hstate_dt_size);
+ add(addr_c_states_tm1_l_reg, cstate_dt_size);
+ add(addr_c_states_t_l_reg, cstate_dt_size);
+ if (mask != 0)
+ add(weights_scales_reg, qscale_dt_size);
+
+ // increment loop counter
+ sub(loop_cnt, gate_dt_size);
+ cmp(loop_cnt, 0);
+ jg(rem_loop_start_label);
+
+ }
+ L(rem_loop_end_label);
+
+ postamble();
+
+ // Again, only one table is needed and shared between sigmoid and tanh
+ sigmoid_injector_->prepare_table(false);
+ tanh_injector_->prepare_table(true);
+
+ L(table_label);
+ {
+ for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_scale));
+ for (size_t i = 0; i < vlen / sizeof(float); i++) dd(float2int(data_shift));
+ // perm mask for ymm
+ dd(0); dd(4); dd(2); dd(3); dd(1); dd(5); dd(6); dd(7);
+ // perm mask for zmm
+ dd(0); dd(4); dd(8); dd(12); dd(1); dd(5); dd(6); dd(7);
+ dd(2); dd(9); dd(10); dd(11); dd(3); dd(12); dd(13); dd(14);
+ }
+ }
+
+};
+
+template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::f32>;
+template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::f32>;
+template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::f32>;
+
+template struct jit_uni_lstm_postgemm_kernel_fwd<sse42, data_type::u8>;
+template struct jit_uni_lstm_postgemm_kernel_fwd<avx2, data_type::u8>;
+template struct jit_uni_lstm_postgemm_kernel_fwd<avx512_core, data_type::u8>;
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
new file mode 100644
index 0000000000..ead536816c
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.cpp
@@ -0,0 +1,788 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+ General architecture
+
+ For diff states, we have n_states + 1 slots, as we have n_states diffs
+ to propagate to the previous iteration and 1 state to propagate
+ to the previous layer:
+ index 0 is dh for cell(t-1, l) to consume
+ index 1 is dc for cell(t-1, l) to consume
+ index 2 is dh for cell(t, l-1) to consume
+ This indexing lets the elementwise functions use the same indexing for
+ all states; only the cell execution function should be affected.
+
+ */
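+/*
+ For reference: for an LSTM (n_states = 2) the diff-states workspace therefore
+ holds, per (layer, direction, iteration):
+   slot 0 -> dh consumed by cell(t-1, l)  (recurrent path)
+   slot 1 -> dc consumed by cell(t-1, l)  (recurrent path)
+   slot 2 -> dh consumed by cell(t, l-1)  (layer path)
+ */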
+
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+#include "../gemm/gemm.hpp"
+#include "../simple_q10n.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::memory_tracking::names;
+using namespace rnn_utils;
+#define AOC array_offset_calculator
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::gates_reduction(
+ const rnn_conf_t &rnn, const acc_data_t *ws_gates_,
+ float *diff_bias_) const {
+ auto body = [&](int i, int k) {
+ for (int j = 0; j < rnn.mb; j++)
+ diff_bias_[i * rnn.dic + k]
+ += ws_gates_[j * rnn.gates_ws_ld + i * rnn.dic + k];
+ };
+
+ // @todo block k on simd-width
+#if MKLDNN_THR == MKLDNN_THR_OMP && _OPENMP >= 201307 \
+ /* icc 17.0 has a problem with simd collapse */ \
+ && !((defined __INTEL_COMPILER) && (__INTEL_COMPILER == 1700))
+#pragma omp parallel for simd collapse(2)
+ for (int i = 0; i < rnn.n_gates; i++)
+ for (int k = 0; k < rnn.dic; k++)
+ body(i, k);
+#else
+ parallel_nd(rnn.n_gates, rnn.dic, body);
+#endif
+}
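+/* For reference: since the same bias is added to the gates of every minibatch
+ * element, its gradient is the sum of the gate gradients over the minibatch,
+ *   diff_bias[g][k] = sum_{j = 0 .. mb-1} ws_gates[j][g][k],
+ * which is exactly what the body() lambda above accumulates. */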
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::gemm)) {
+ assert(ldA * ldB * ldC != 0);
+ extended_sgemm(&transA, &transB, &m, &n, &k, &alpha, a_, &ldA, b_, &ldB,
+ &beta, c_, &ldC, nullptr, pd()->rnn_.use_jit_gemm);
+}
+
+template <>
+rnn_gemm_sig((ref_rnn_fwd_u8s8_t::gemm)) {
+ assert(!"non packed gemm is disabled for int8");
+}
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_gemm_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::packed_gemm)) {
+#if (USE_MKL_PACKED_GEMM)
+ assert(transA == 'N');
+ cblas_sgemm_compute(CblasColMajor, CblasPacked,
+ (transB == 'T') ? CblasTrans : CblasNoTrans, m, n, k, a_, ldA, b_,
+ ldB, beta, c_, ldC);
+#else
+ UNUSED(transA);
+ UNUSED(transB);
+ UNUSED(m);
+ UNUSED(n);
+ UNUSED(k);
+ UNUSED(alpha);
+ UNUSED(ldA);
+ UNUSED(b_);
+ UNUSED(ldB);
+ UNUSED(beta);
+ UNUSED(c_);
+ UNUSED(ldC);
+ assert(!"packed gemm is disabled");
+#endif
+}
+
+template <>
+rnn_gemm_sig((ref_rnn_fwd_u8s8_t::packed_gemm)) {
+#if (USE_MKL_PACKED_GEMM)
+ int8_t offseta = 0, offsetb = 0;
+ int32_t offsetc = 0;
+ cblas_gemm_s8u8s32_compute(CblasColMajor, (CBLAS_TRANSPOSE)CblasPacked,
+ CblasNoTrans, CblasFixOffset, m, n, k, alpha, a_, ldA, offseta, b_,
+ ldB, offsetb, beta, c_, ldC, &offsetc);
+#else
+ UNUSED(transA);
+ UNUSED(transB);
+ UNUSED(m);
+ UNUSED(n);
+ UNUSED(k);
+ UNUSED(alpha);
+ UNUSED(ldA);
+ UNUSED(b_);
+ UNUSED(ldB);
+ UNUSED(beta);
+ UNUSED(c_);
+ UNUSED(ldC);
+ assert(!"packed gemm is disabled");
+#endif
+}
+
+//*************** Grid computations strategy: linear ***************//
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_grid_execution_sig(
+ (_ref_rnn_common_t<aprop, src_type, weights_type>::linear_execution)) {
+ AOC<src_data_t, 4> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
+ AOC<float, 4> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.states_nld * rnn.states_ws_ld);
+ AOC<float, 5> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
+ (rnn.n_states + 1), rnn.n_iter + 1,
+ rnn.states_nld * rnn.states_ws_ld);
+ AOC<acc_data_t, 4> ws_gates(ws_gates_, rnn.n_layer, rnn.n_dir, rnn.n_iter,
+ rnn.gates_nld * rnn.gates_ws_ld);
+ AOC<weights_data_t *, 3> weights_input(
+ weights_layer_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_layer);
+ AOC<weights_data_t *, 3> weights_states(
+ weights_states_, rnn.n_layer, rnn.n_dir, rnn.n_parts_weights_iter);
+ AOC<float*, 3> bias(
+ bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
+ AOC<float, 3> diff_weights_layer(diff_weights_layer_, rnn.n_layer,
+ rnn.n_dir,
+ rnn.diff_weights_layer_nld * rnn.diff_weights_layer_ld);
+ AOC<float, 3> diff_weights_iter(diff_weights_iter_, rnn.n_layer, rnn.n_dir,
+ rnn.diff_weights_iter_nld * rnn.diff_weights_iter_ld);
+ AOC<float, 3> diff_bias(
+ diff_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
+ AOC<float, 4> ws_grid(
+ ws_grid_, rnn.n_layer, rnn.n_dir, rnn.n_iter, (int)rnn.ws_per_cell);
+
+ // We run the computation grid
+ for (int dir = 0; dir < rnn.n_dir; dir++) {
+ for (int j = 0; j < rnn.n_layer; j++) {
+ int lay = (aprop == prop_kind::forward) ? j : rnn.n_layer - j - 1;
+
+ if ((aprop == prop_kind::forward) && rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.n_gates * rnn.dic,
+ rnn.mb * rnn.n_iter, rnn.slc, 1.0,
+ weights_input(lay, dir, 0), rnn.weights_iter_ld,
+ &(ws_states(lay, dir, 1, 0)), rnn.states_ws_ld, 0.0,
+ &(ws_gates(lay, dir, 0, 0)), rnn.gates_ws_ld);
+ }
+
+ for (int i = 0; i < rnn.n_iter; i++) {
+ int iter = (aprop == prop_kind::forward) ? i : rnn.n_iter - i - 1;
+ (this->*cell_func)(rnn,
+ &(ws_states(lay + 1, dir, iter + 1, 0)),
+ &(ws_c_states(lay + 1, dir, iter + 1, 0)),
+ &(ws_diff_states(lay, dir, 0, iter, 0)),
+ &(weights_input(lay, dir, 0)),
+ &(weights_states(lay, dir, 0)),
+ &(bias(lay, dir, 0)),
+ &(ws_states(lay, dir, iter + 1, 0)),
+ &(ws_states(lay + 1, dir, iter, 0)),
+ &(ws_c_states(lay + 1, dir, iter, 0)),
+ &(ws_diff_states(lay + 1, dir, 0, iter, 0)),
+ &(ws_diff_states(lay, dir, 0, iter + 1, 0)),
+ &(diff_weights_layer(lay, dir, 0)),
+ &(diff_weights_iter(lay, dir, 0)),
+ &(diff_bias(lay, dir, 0)),
+ &(ws_gates(lay, dir, iter, 0)),
+ &(ws_grid(lay, dir, iter, 0)),
+ ws_cell_);
+ }
+
+ if ((aprop == prop_kind::backward) && rnn.merge_gemm_layer) {
+ (this->*gemm_layer_func)('N', 'N', rnn.slc, rnn.mb * rnn.n_iter,
+ rnn.n_gates * rnn.dic, 1.0, weights_input(lay, dir, 0),
+ rnn.weights_layer_ld,
+ (src_data_t *)(&(ws_gates(lay, dir, 0, 0))),
+ rnn.gates_ws_ld, 0.0,
+ (acc_data_t *)(&(ws_diff_states(
+ lay, dir, rnn.n_states, 0, 0))),
+ rnn.states_ws_ld);
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.slc,
+ rnn.mb * rnn.n_iter, 1.0,
+ (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
+ rnn.gates_ws_ld,
+ (src_data_t *)(&(ws_states(lay, dir, 1, 0))),
+ rnn.states_ws_ld, 1.0,
+ (acc_data_t *)(&(diff_weights_layer(lay, dir, 0))),
+ rnn.diff_weights_layer_ld);
+ }
+ if ((aprop == prop_kind::backward) && rnn.merge_gemm_iter) {
+ gemm('N', 'T', rnn.n_gates * rnn.dic, rnn.sic,
+ rnn.mb * rnn.n_iter, 1.0,
+ (weights_data_t *)(&(ws_gates(lay, dir, 0, 0))),
+ rnn.gates_ws_ld,
+ (src_data_t *)(&(ws_states(lay + 1, dir, 0, 0))),
+ rnn.states_ws_ld, 1.0,
+ (acc_data_t *)(&(diff_weights_iter(lay, dir, 0))),
+ rnn.diff_weights_iter_ld);
+ }
+ }
+ }
+}
+
+//********* GRID computations strategy: utility functions **********//
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_layer(
+ const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
+ float *__restrict ws_diff_states_, const src_data_t *__restrict xt_,
+ const float *__restrict diff_dst_layer_) const {
+
+ AOC<src_data_t, 4> ws_states(
+ ws_states_, rnn.n_dir, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ auto xt_d = memory_desc_wrapper(pd()->src_md(0));
+
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto xxt = xt_ + xt_d.blk_off(it, b);
+ src_data_t *ws_l2r_ptr = &(ws_states(0, it + 1, b, 0));
+ src_data_t *ws_r2l_ptr = &(ws_states(rnn.n_dir - 1, rnn.n_iter - it, b, 0));
+ if (rnn.exec_dir != r2l)
+ for (int c = 0; c < rnn.slc; c++)
+ ws_l2r_ptr[c] = xxt[c];
+ if (rnn.exec_dir != l2r)
+ for (int c = 0; c < rnn.slc; c++)
+ ws_r2l_ptr[c] = xxt[c];
+ });
+}
+
+template <>
+void ref_rnn_bwd_f32_t::copy_init_layer(const rnn_conf_t &rnn,
+ src_data_t *ws_states_, float *ws_diff_states_, const src_data_t *xt_,
+ const float *diff_dst_layer_) const {
+ AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
+ (rnn.n_states + 1), rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ auto diff_dst_layer_d = memory_desc_wrapper(pd()->diff_dst_md(0));
+
+ switch (rnn.exec_dir) {
+ case bi_concat:
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto diff_dst_layer_x
+ = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
+ for (int s = 0; s < rnn.dic; s++) {
+ ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
+ = diff_dst_layer_x[s];
+ ws_diff_states(
+ rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
+ = diff_dst_layer_x[rnn.dic + s];
+ }
+ });
+ break;
+ case bi_sum:
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto diff_dst_layer_x
+ = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
+ for (int s = 0; s < rnn.dic; s++) {
+ ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
+ = diff_dst_layer_x[s];
+ ws_diff_states(
+ rnn.n_layer, 1, rnn.n_states, rnn.n_iter - it - 1, b, s)
+ = diff_dst_layer_x[s];
+ }
+ });
+ break;
+ case l2r:
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto diff_dst_layer_x
+ = diff_dst_layer_ + diff_dst_layer_d.blk_off(it, b);
+ for (int s = 0; s < rnn.dic; s++) {
+ ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
+ = diff_dst_layer_x[s];
+ }
+ });
+ break;
+ case r2l:
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ auto diff_dst_layer_x = diff_dst_layer_
+ + diff_dst_layer_d.blk_off(rnn.n_iter - it - 1, b);
+ for (int s = 0; s < rnn.dic; s++) {
+ ws_diff_states(rnn.n_layer, 0, rnn.n_states, it, b, s)
+ = diff_dst_layer_x[s];
+ }
+ });
+ break;
+ default: assert(!"Unsupported direction"); break;
+ }
+}
+
+/* For the int8 configuration, the input iteration states may be of type f32 or u8.
+ * Internally, h_state is always stored in u8 and c_state is always stored in f32.
+ * If the input states are u8, the h state is copied and the c state is dequantized.
+ * If the input states are f32, the h state is quantized and the c state is copied.
+ */
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+template <typename input_data_t>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_init_iter(
+ const rnn_conf_t &rnn, src_data_t *__restrict ws_states_,
+ float *__restrict ws_c_states_, float *__restrict ws_diff_states_,
+ const input_data_t *__restrict firstit_states_,
+ const float *__restrict diff_dst_iter_) const {
+ AOC<src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ AOC<float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
+ float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
+
+ const bool quantize = pd()->with_src_iter()
+ && pd()->src_md(1)->data_type == data_type::f32
+ && rnn.dt_conf != all_f32;
+ auto maybe_q = [&](input_data_t f) {
+ if (quantize) {
+ float qf = f * data_scale + data_shift;
+ return qz_a1b0<float, src_data_t>()(qf);
+ } else
+ return (src_data_t)f;
+ };
+
+ const bool dequantize = pd()->with_src_iter()
+ && pd()->src_md(1)->data_type == data_type::u8;
+ auto maybe_deq = [&](input_data_t s) {
+ if (dequantize)
+ return (((float)s - data_shift) / data_scale);
+ else
+ return (float)s;
+ };
+ auto firstit_states_d = memory_desc_wrapper(pd()->src_md(1));
+ if (firstit_states_) {
+ parallel_nd(
+ rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
+ for (int s = 0; s < rnn.sic; s++)
+ ws_states(lay + 1, dir, 0, b, s) = maybe_q(
+ firstit_states_[firstit_states_d.blk_off(
+ lay, dir, 0, b, s)]);
+ if (pd()->cell_kind() == alg_kind::vanilla_lstm)
+ for (int s = 0; s < rnn.sic; s++)
+ ws_c_states(lay + 1, dir, 0, b, s) = maybe_deq(
+ firstit_states_[firstit_states_d.blk_off(
+ lay, dir, 1, b, s)]);
+ });
+ } else {
+ parallel_nd(
+ rnn.n_layer, rnn.n_dir, rnn.mb, [&](int lay, int dir, int b) {
+ for (int j = 0; j < rnn.sic; j++) {
+ ws_states(lay + 1, dir, 0, b, j) = (src_data_t)0;
+ ws_c_states(lay + 1, dir, 0, b, j) = 0.0f;
+ }
+ });
+ }
+}
+
+template <>
+template <typename input_data_t>
+void ref_rnn_bwd_f32_t::copy_init_iter(const rnn_conf_t &rnn,
+ src_data_t *ws_states_, float *ws_c_states_, float *ws_diff_states_,
+ const input_data_t *firstit_states_,
+ const float *diff_dst_iter_) const {
+ AOC<float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_states + 1, rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ auto diff_dst_iter_d = memory_desc_wrapper(pd()->diff_dst_md(1));
+ if (diff_dst_iter_) {
+ parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
+ [&](int lay, int dir, int state, int b) {
+ array_copy(&(ws_diff_states(
+ lay, dir, state, rnn.n_iter, b, 0)),
+ diff_dst_iter_
+ + diff_dst_iter_d.blk_off(
+ lay, dir, state, b),
+ rnn.dic);
+ });
+ } else {
+ parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
+ [&](int lay, int dir, int state, int i) {
+ for (int j = 0; j < rnn.dic; j++)
+ ws_diff_states(lay, dir, state, rnn.n_iter, i, j)
+ = 0.0f;
+ });
+ }
+}
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+template <typename dst_data_t>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_layer(
+ const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer,
+ const src_data_t *ws_states_, const float *ws_diff_states_) const {
+
+ auto dst_layer_d = memory_desc_wrapper(pd()->dst_md(0));
+ AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ float shift = (pd()->attr()->rnn_data_qparams_.shift_);
+ float scale = (pd()->attr()->rnn_data_qparams_.scale_);
+
+ const bool dequantize = pd()->dst_md(0)->data_type == data_type::f32
+ && rnn.dt_conf != all_f32;
+ auto maybe_deq = [&](src_data_t s) {
+ if (dequantize)
+ return (dst_data_t)(((float)s - shift) / scale);
+ else
+ return (dst_data_t)s;
+ };
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ int dir = 0;
+ if (rnn.exec_dir != r2l) {
+ for (int s = 0; s < rnn.dic; s++) {
+ dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
+ = maybe_deq(ws_states(rnn.n_layer, dir, it + 1, b, s));
+ }
+ dir = 1;
+ }
+ if (rnn.exec_dir != l2r) {
+ for (int s = 0; s < rnn.dic; s++)
+ switch (rnn.exec_dir) {
+ case bi_sum:
+ dst_layer_[dst_layer_d.blk_off(it, b, s)]
+ += maybe_deq(ws_states(
+ rnn.n_layer, dir, rnn.n_iter - it, b, s));
+ break;
+ default:
+ dst_layer_[dst_layer_d.blk_off(it, b, dir * rnn.dic + s)]
+ = maybe_deq(ws_states(
+ rnn.n_layer, dir, rnn.n_iter - it, b, s));
+ }
+ }
+ });
+}
+
+template <>
+template <typename dst_data_t>
+void ref_rnn_bwd_f32_t::copy_res_layer(
+ const rnn_conf_t &rnn, dst_data_t *dst_layer_, float *diff_src_layer_,
+ const src_data_t *ws_states_, const float *ws_diff_states_) const {
+ auto diff_src_layer_d = memory_desc_wrapper(pd()->diff_src_md(0));
+ AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
+ rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
+ rnn.states_ws_ld);
+
+ parallel_nd(rnn.n_iter, rnn.mb, [&](int it, int b) {
+ int dir = 0;
+ for (int s = 0; s < rnn.slc; s++) {
+ float *dst_addr = diff_src_layer_
+ + diff_src_layer_d.blk_off(
+ (rnn.exec_dir == r2l) ? rnn.n_iter - 1 - it : it,
+ b, dir * rnn.slc + s);
+ float res = ws_diff_states(0, 0, rnn.n_states, it, b, s);
+ if (rnn.n_dir - 1)
+ res += ws_diff_states(
+ 0, 1, rnn.n_states, rnn.n_iter - 1 - it, b, s);
+ dst_addr[0] = res;
+ }
+ });
+}
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+template <typename output_data_t>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::copy_res_iter(
+ const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
+ const src_data_t *ws_states_, float *ws_c_states_,
+ const float *ws_diff_states_) const {
+ auto dst_iter_d = memory_desc_wrapper(pd()->dst_md(1));
+ AOC<const src_data_t, 5> ws_states(ws_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ AOC<const float, 5> ws_c_states(ws_c_states_, rnn.n_layer + 1, rnn.n_dir,
+ rnn.n_iter + 1, rnn.mb, rnn.states_ws_ld);
+ float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
+ float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
+
+ const bool quantize = pd()->with_dst_iter()
+ && pd()->dst_md(1)->data_type == data_type::u8
+ && rnn.dt_conf != all_f32;
+ auto maybe_q = [&](float f) {
+ if (quantize) {
+ float qf = f * data_scale + data_shift;
+ return qz_a1b0<float, output_data_t>()(qf);
+ } else
+ return (output_data_t)f;
+ };
+
+ const bool dequantize = pd()->with_dst_iter()
+ && pd()->dst_md(1)->data_type == data_type::f32
+ && rnn.dt_conf != all_f32;
+ auto maybe_deq = [&](src_data_t s) {
+ if (dequantize)
+ return (output_data_t)(((float)s - data_shift) / data_scale);
+ else
+ return (output_data_t)s;
+ };
+ if (dst_iter_) {
+ parallel_nd(rnn.n_layer, rnn.n_dir, rnn.mb,
+ [&](int lay, int dir, int b) {
+ for (int s = 0; s < rnn.dic; s++) {
+ dst_iter_[dst_iter_d.blk_off(lay, dir, 0, b, s)]
+ = maybe_deq(ws_states(lay + 1, dir, rnn.n_iter, b, s));
+ }
+ if (pd()->cell_kind() == alg_kind::vanilla_lstm)
+ for (int s = 0; s < rnn.dic; s++) {
+ dst_iter_[dst_iter_d.blk_off(lay, dir, 1, b, s)]
+ = maybe_q(ws_c_states(
+ lay + 1, dir, rnn.n_iter, b, s));
+ }
+ });
+ }
+}
+
+template <>
+template <typename output_data_t>
+void ref_rnn_bwd_f32_t::copy_res_iter(
+ const rnn_conf_t &rnn, output_data_t *dst_iter_, float *diff_src_iter_,
+ const src_data_t *ws_states_, float *ws_c_states_,
+ const float *ws_diff_states_) const {
+ auto diff_src_iter_d = memory_desc_wrapper(pd()->diff_src_md(1));
+ AOC<const float, 6> ws_diff_states(ws_diff_states_, rnn.n_layer + 1,
+ rnn.n_dir, rnn.n_states + 1, rnn.n_iter + 1, rnn.mb,
+ rnn.states_ws_ld);
+ if (diff_src_iter_) {
+ parallel_nd(rnn.n_layer, rnn.n_dir, rnn.n_states, rnn.mb,
+ [&](int lay, int dir, int state, int b) {
+ for (int s = 0; s < rnn.sic; s++) {
+ diff_src_iter_[diff_src_iter_d.blk_off(
+ lay, dir, state, b, s)]
+ = ws_diff_states(lay, dir, state, 0, b, s);
+ }
+ });
+ }
+}
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_bias_prepare_sig((_ref_rnn_common_t<aprop, src_type, weights_type>::bias_prepare)) {
+ /* Original set of biases provided by the user */
+ AOC<const float, 5> b(
+ b_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
+ /* Array of pointers initialized in packing */
+ AOC<float *, 3> bias(bias_, rnn.n_layer, rnn.n_dir, rnn.n_parts_bias);
+ AOC<float, 3> scratch_bias(
+ scratch_bias_, rnn.n_layer, rnn.n_dir, rnn.n_bias * rnn.dic);
+
+ if (rnn.copy_bias) {
+ parallel_nd(rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic,
+ [&](size_t i) { scratch_bias_[i] = b_[i]; });
+ }
+
+ for (int i = 0; i < rnn.n_layer; i++) {
+ for (int d = 0; d < rnn.n_dir; d++) {
+ int offset_bias = 0;
+ for (int p = 0; p < rnn.n_parts_bias; p++) {
+ bias(i, d, p) = rnn.copy_bias
+ ? (float *) &scratch_bias(i, d, offset_bias)
+ : (float *) &b(i, d, offset_bias);
+ offset_bias += rnn.parts_bias[p] * rnn.dic;
+ }
+ }
+ }
+
+}
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_bias_finalize_sig(
+ (_ref_rnn_common_t<aprop, src_type, weights_type>::bias_finalize)) {
+ if (rnn.dt_conf != all_f32) {
+ float data_shift = pd()->attr()->rnn_data_qparams_.shift_;
+ float data_scale = pd()->attr()->rnn_data_qparams_.scale_;
+ float *weights_scales = pd()->attr()->rnn_weights_qparams_.scales_;
+ bool scale_per_oc = pd()->attr()->rnn_weights_qparams_.mask_ != 0;
+ for (int i = 0; i < rnn.n_layer * rnn.n_dir; i++)
+ for (int j = 0; j < rnn.n_bias * rnn.dic; j++) {
+ size_t off = i * rnn.n_bias * rnn.dic + j;
+ float weights_scale
+ = scale_per_oc ? weights_scales[j] : weights_scales[0];
+ scratch_bias_[off] -= (w_iter_comp[off] + w_layer_comp[off])
+ * data_shift / (weights_scale * data_scale);
+ }
+ }
+}
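+/* For reference (assuming w_layer_comp and w_iter_comp hold the
+ * per-output-channel sums of the quantized weights, computed during the
+ * weights reorder): with u8 activations a_q = data_scale * a + data_shift and
+ * s8 weights w_q = weights_scale * w, the integer gemm produces
+ *   sum_k w_q[k] * a_q[k] = weights_scale * data_scale * sum_k w[k] * a[k]
+ *                         + data_shift * sum_k w_q[k].
+ * After rescaling the output by 1 / (weights_scale * data_scale), the second
+ * term becomes comp * data_shift / (weights_scale * data_scale), which is the
+ * correction bias_finalize subtracts from the f32 bias above. */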
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_weights_assign_sig((_ref_rnn_common_t<aprop, src_type,
+ weights_type>::assign_packed_weights)) {
+ assert(md->format_kind == format_kind::rnn_packed);
+ const auto packed_desc = md->format_desc.rnn_packed_desc;
+ AOC<weights_data_t *, 3> weights(weights_,
+ rnn.n_layer, rnn.n_dir, packed_desc.n_parts);
+
+ size_t offset_packed = 0;
+ for (int l = 0; l < rnn.n_layer; l++)
+ for (int d = 0; d < rnn.n_dir; d++) {
+ for (int p = 0; p < packed_desc.n_parts; p++) {
+ weights(l, d, p) = (weights_data_t *)&w_[offset_packed];
+ offset_packed
+ += packed_desc.part_pack_size[p] / sizeof(weights_data_t);
+ }
+ }
+}
+
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+rnn_weights_assign_sig(
+ (_ref_rnn_common_t<aprop, src_type, weights_type>::assign_weights)) {
+ assert(md->format_kind == format_kind::blocked);
+ const auto &blk = md->format_desc.blocking;
+ /* Original set of weights provided by the user */
+ AOC<const weights_data_t, 3> w(w_,
+ rnn.n_layer, rnn.n_dir, (int)blk.strides[1]);
+ /* Array of pointers for each part of weights */
+ AOC<weights_data_t *, 3> weights(weights_, rnn.n_layer, rnn.n_dir, n_parts);
+
+ for (int i = 0; i < rnn.n_layer; i++)
+ for (int d = 0; d < rnn.n_dir; d++) {
+ size_t offset_weights = 0;
+ for (int p = 0; p < n_parts; p++) {
+ weights(i, d, p) = (weights_data_t *)&w(i, d, offset_weights);
+ offset_weights += gates_per_part[p] * blk.strides[3];
+ }
+ }
+}
+
+//********************* Execution function *********************//
+template <prop_kind_t aprop, data_type_t src_type, data_type_t weights_type>
+void _ref_rnn_common_t<aprop, src_type, weights_type>::execute_(
+ const exec_ctx_t &ctx) const {
+ const rnn_conf_t &rnn = this->pd()->rnn_;
+ auto input = CTX_IN_MEM(const src_data_t *, MKLDNN_ARG_SRC_LAYER);
+ auto states = CTX_IN_MEM(const char *, MKLDNN_ARG_SRC_ITER);
+ auto layer_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_LAYER);
+ auto iter_weights_n_comp = CTX_IN_MEM(const char *, MKLDNN_ARG_WEIGHTS_ITER);
+ auto bias = CTX_IN_MEM(const float *, MKLDNN_ARG_BIAS);
+
+ auto dst_last_layer = rnn.is_fwd
+ ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_LAYER)
+ : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_LAYER));
+ auto dst_last_iter = rnn.is_fwd
+ ? CTX_OUT_MEM(char *, MKLDNN_ARG_DST_ITER)
+ : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_DST_ITER));
+
+ auto diff_dst_layer = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_LAYER);
+ auto diff_dst_iter = CTX_IN_MEM(const float *, MKLDNN_ARG_DIFF_DST_ITER);
+
+ auto w_layer = reinterpret_cast<const weights_data_t *>(layer_weights_n_comp);
+ auto w_iter = reinterpret_cast<const weights_data_t *>(iter_weights_n_comp);
+ auto w_iter_comp = reinterpret_cast<const float *>(
+ iter_weights_n_comp + rnn.weights_iter_comp_offset);
+ auto w_layer_comp = reinterpret_cast<const float *>(
+ layer_weights_n_comp + rnn.weights_layer_comp_offset);
+
+ auto scratchpad = this->scratchpad(ctx);
+
+ auto ptr_wei_layer
+ = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_layer);
+ auto ptr_wei_iter
+ = scratchpad.template get<weights_data_t *>(key_rnn_ptrs_wei_iter);
+ auto ptr_bias =
+ scratchpad.template get<float *>(key_rnn_ptrs_bia);
+
+ // fetching buffers from the workspace;
+ // if no workspace was provided we use the scratchpad
+ char *scratch_ptr = scratchpad.template get<char>(key_rnn_space);
+ char *ws_ptr = nullptr;
+ if (rnn.use_workspace)
+ ws_ptr = rnn.is_fwd
+ ? CTX_OUT_MEM(char *, MKLDNN_ARG_WORKSPACE)
+ : const_cast<char *>(CTX_IN_MEM(const char *, MKLDNN_ARG_WORKSPACE));
+
+ char *base_ptr = rnn.use_workspace ? ws_ptr : scratch_ptr;
+ acc_data_t *ws_gates = (acc_data_t *)(base_ptr + ws_gates_offset_);
+ src_data_t *ws_states = (src_data_t *)(base_ptr + ws_states_offset_);
+ float *ws_c_states = (float *)(base_ptr + ws_c_states_offset_);
+ float *ws_diff_states = (float *)(base_ptr + ws_diff_states_offset_);
+ float *ws_grid = (float *)(base_ptr + ws_grid_comp_offset_);
+ float *ws_cell = (float *)(base_ptr + ws_cell_comp_offset_);
+
+ auto diff_src_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_LAYER);
+ auto diff_src_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_SRC_ITER);
+
+ auto diff_weights_layer = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_LAYER);
+ auto diff_weights_iter = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_WEIGHTS_ITER);
+ auto diff_bias = CTX_OUT_MEM(float *, MKLDNN_ARG_DIFF_BIAS);
+
+ // Fetching extra buffers from scratchpad
+ float *ws_bias = (float *)(scratch_ptr + ws_bias_offset_);
+
+ // initialize diff_states to 0
+ if (aprop == prop_kind::backward)
+ array_set(ws_diff_states, 0.0f, rnn.ws_diff_states_size / sizeof(float));
+
+ /* Pack (if using the packed gemm API) or copy (if the input arrays have a
+ * bad leading dimension) */
+ (this->*bias_preparation_func)(rnn, ptr_bias, bias, ws_bias);
+
+ (this->*weights_iter_assign_func)(rnn, pd()->weights_md(1),
+ rnn.weights_iter_nld, rnn.weights_iter_ld, rnn.dic,
+ rnn.sic, rnn.n_parts_weights_iter, rnn.parts_weights_iter,
+ rnn.part_weights_iter_pack_size, ptr_wei_iter, w_iter,
+ ptr_bias, bias, ws_bias);
+ (this->*weights_layer_assign_func)(rnn, pd()->weights_md(0),
+ rnn.weights_layer_nld, rnn.weights_layer_ld, rnn.dic, rnn.slc,
+ rnn.n_parts_weights_layer, rnn.parts_weights_layer,
+ rnn.part_weights_layer_pack_size, ptr_wei_layer, w_layer, ptr_bias,
+ bias, ws_bias);
+
+ (this->*bias_finalization_func)(rnn, ws_bias, w_iter_comp, w_layer_comp);
+
+ // we first need to copy the initial states and input into ws
+ copy_init_layer(rnn, ws_states, ws_diff_states, input, diff_dst_layer);
+ if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
+ || rnn.dt_conf == all_f32)
+ copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
+ (const float *)states, diff_dst_iter);
+ else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
+ copy_init_iter(rnn, ws_states, ws_c_states, ws_diff_states,
+ (const uint8_t *)states, diff_dst_iter);
+ else
+ assert(!"unimplemented");
+
+ // run the execution on the grid
+ (this->*grid_computation)(rnn, ptr_wei_layer, ptr_wei_iter, ptr_bias,
+ ws_states, ws_c_states, ws_diff_states, ws_gates, ws_cell, ws_grid,
+ diff_weights_layer, diff_weights_iter, diff_bias);
+
+ // Finally we copy the results to the result buffers
+ if (rnn.dt_conf == u8u8u8f32 || rnn.dt_conf == f32u8f32f32
+ || rnn.dt_conf == all_f32)
+ copy_res_layer(rnn, (float *)dst_last_layer, diff_src_layer, ws_states,
+ ws_diff_states);
+ else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == f32u8f32u8)
+ copy_res_layer(rnn, (uint8_t *)dst_last_layer, diff_src_layer,
+ ws_states, ws_diff_states);
+ else
+ assert(!"unimplemented");
+
+ if (rnn.dt_conf == f32u8f32u8 || rnn.dt_conf == f32u8f32f32
+ || rnn.dt_conf == all_f32)
+ copy_res_iter(rnn, (float *)dst_last_iter, diff_src_iter, ws_states,
+ ws_c_states, ws_diff_states);
+ else if (rnn.dt_conf == u8u8u8u8 || rnn.dt_conf == u8u8u8f32)
+ copy_res_iter(rnn, (uint8_t *)dst_last_iter, diff_src_iter, ws_states,
+ ws_c_states, ws_diff_states);
+ else
+ assert(!"unimplemented");
+};
+
+/* Fix for MSVS warning C4661 */
+template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution);
+template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru);
+template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_f32_t::cell_execution_gru_lbr);
+template<> rnn_cell_execution_sig(ref_rnn_fwd_u8s8_t::cell_execution_gru_lbr);
+template<> rnn_cell_execution_sig(ref_rnn_bwd_f32_t::cell_execution_gru_lbr);
+template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::rnn_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::rnn_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::rnn_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::lstm_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::lstm_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::lstm_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_f32_t::gru_lbr_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_fwd_u8s8_t::gru_lbr_elemwise);
+template<> rnn_elemwise_sig(ref_rnn_bwd_f32_t::gru_lbr_elemwise);
+
+template struct _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
+template struct _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
+template struct _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;
+
+#undef AOC
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp
new file mode 100644
index 0000000000..6f449a9016
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/ref_rnn.hpp
@@ -0,0 +1,328 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_REF_RNN_HPP
+#define CPU_REF_RNN_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "../cpu_isa_traits.hpp"
+#include "../gemm/os_blas.hpp"
+
+#include "cpu_rnn_pd.hpp"
+#include "../cpu_primitive.hpp"
+#include "rnn_utils.hpp"
+#include "jit_uni_rnn_postgemm.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <alg_kind_t alg_kind, prop_kind_t prop_kind>
+float activation(float s, float alpha, float clipping, float dd);
+
+template <prop_kind_t aprop, impl::data_type_t src_type,
+ impl::data_type_t weights_type>
+struct _ref_rnn_common_t : public cpu_primitive_t {
+ typedef typename prec_traits<src_type>::type src_data_t;
+ typedef typename prec_traits<weights_type>::type weights_data_t;
+ typedef typename utils::conditional<src_type == data_type::u8, int32_t,
+ float>::type acc_data_t;
+
+ using class_name = _ref_rnn_common_t<aprop, src_type, weights_type>;
+
+ typedef rnn_elemwise_sig((class_name::*elemwise_f));
+ typedef rnn_cell_execution_sig((class_name::*cell_execution_f));
+ typedef rnn_grid_execution_sig((class_name::*grid_execution_f));
+
+ typedef rnn_gemm_sig((class_name::*gemm_t));
+ typedef rnn_bias_prepare_sig((class_name::*bias_prepare_t));
+ typedef rnn_bias_finalize_sig((class_name::*bias_finalize_t));
+ typedef rnn_weights_assign_sig((class_name::*weights_assign_t));
+
+ using base_pd_t =
+ typename utils::conditional<false || aprop == prop_kind::forward,
+ cpu_rnn_fwd_pd_t, cpu_rnn_bwd_pd_t>::type;
+
+ struct pd_t : public base_pd_t {
+ using base_pd_t::base_pd_t;
+
+ DECLARE_COMMON_PD_T("ref:any", class_name);
+
+ status_t init() {
+ using namespace prop_kind;
+ using namespace utils;
+ using namespace format_tag;
+ using namespace rnn_utils;
+ const alg_kind_t cell_kind = this->desc()->cell_desc.cell_kind;
+
+ data_type_t src_layer_dt = this->desc()->src_layer_desc.data_type;
+ data_type_t weights_iter_dt
+ = this->desc()->weights_iter_desc.data_type;
+ data_type_t weights_layer_dt
+ = this->desc()->weights_layer_desc.data_type;
+
+ bool ok = true
+ && one_of(cell_kind, alg_kind::vanilla_rnn,
+ alg_kind::vanilla_lstm, alg_kind::vanilla_gru,
+ alg_kind::gru_linear_before_reset)
+ && IMPLICATION(aprop == prop_kind::forward,
+ one_of(this->desc()->prop_kind, forward_training,
+ forward_inference))
+ && IMPLICATION(aprop == backward,
+ one_of(this->desc()->prop_kind, backward))
+ && src_layer_dt == src_type
+ && everyone_is(
+ weights_type, weights_iter_dt, weights_layer_dt)
+ && this->set_default_params() == status::success
+ && this->with_bias();
+ if (!ok)
+ return status::unimplemented;
+
+ init_conf(rnn_, *this->desc(), this->src_md(0), this->src_md(1),
+ this->weights_md(0), this->weights_md(1), this->dst_md(0));
+
+ if (rnn_.dt_conf == all_f32)
+ ok = ok && this->attr()->has_default_values();
+
+ // Set weights descriptors to desired format
+ memory_desc_t new_weights_layer_md = *this->weights_md(0);
+ CHECK(set_expected_desc(rnn_, new_weights_layer_md, false));
+ if (this->weights_layer_md_.format_kind == format_kind::any) {
+ this->weights_layer_md_ = new_weights_layer_md;
+ } else if (this->weights_layer_md_.format_kind
+ == format_kind::rnn_packed) {
+ if (this->weights_layer_md_ != new_weights_layer_md)
+ return status::unimplemented;
+ }
+
+ memory_desc_t new_weights_iter_md = *this->weights_md(1);
+ CHECK(set_expected_desc(rnn_, new_weights_iter_md, true));
+ if (this->weights_iter_md_.format_kind == format_kind::any) {
+ this->weights_iter_md_ = new_weights_iter_md;
+ } else if (this->weights_iter_md_.format_kind
+ == format_kind::rnn_packed) {
+ if (this->weights_iter_md_ != new_weights_iter_md)
+ return status::unimplemented;
+ }
+
+ CHECK(this->check_layout_consistency());
+
+ set_conf(rnn_, *this->desc(), this->weights_md(0),
+ this->weights_md(1), this->diff_weights_md(0),
+ this->diff_weights_md(1));
+
+ size_t scratchpad_sz{0}, ws_sz{0};
+ get_scratchpad_and_workspace_sizes(rnn_, scratchpad_sz, ws_sz);
+
+ // initialize the workspace if needed
+ if (rnn_.is_training) {
+ dims_t ws_dims = { (int)ws_sz };
+ mkldnn_memory_desc_init_by_tag(&this->ws_md_, 1, ws_dims,
+ data_type::u8, format_tag::x);
+ }
+
+ init_scratchpad(scratchpad_sz);
+
+ return status::success;
+ }
+
+ rnn_utils::rnn_conf_t rnn_;
+
+ private:
+ void init_scratchpad(size_t scratchpad_sz) {
+ using namespace memory_tracking::names;
+ auto scratchpad = this->scratchpad_registry().registrar();
+ scratchpad.book(key_rnn_space, sizeof(float) * scratchpad_sz, 4096);
+
+ int max_nparts = this->cell_kind() == alg_kind::vanilla_gru ? 2 : 1;
+ int ptr_wei_sz = rnn_.n_layer * rnn_.n_dir * max_nparts;
+ scratchpad.book(key_rnn_ptrs_wei_layer,
+ sizeof(float *) * ptr_wei_sz);
+ scratchpad.book(key_rnn_ptrs_wei_iter,
+ sizeof(float *) * ptr_wei_sz);
+ scratchpad.book(key_rnn_ptrs_bia,
+ sizeof(float *) * ptr_wei_sz);
+ }
+ };
+
+ _ref_rnn_common_t(const pd_t *apd)
+ : cpu_primitive_t(apd, true), rnn_postgemm_(nullptr) {
+ /// @todo set max_feature_size assuming that we limit the number of
+ /// iterations and layers to one if slc != dic and sic != dic
+ /// respectively
+
+ bias_preparation_func = &class_name::bias_prepare;
+ bias_finalization_func = &class_name::bias_finalize;
+
+ auto set_gemm_funcs
+ = [](bool packed_gemm, gemm_t &g, weights_assign_t &a) {
+ if (packed_gemm) {
+ g = &class_name::packed_gemm;
+ a = &class_name::assign_packed_weights;
+ } else {
+ g = &class_name::gemm;
+ a = &class_name::assign_weights;
+ }
+ };
+ set_gemm_funcs(pd()->rnn_.use_iter_packed_gemm, gemm_iter_func,
+ weights_iter_assign_func);
+
+ set_gemm_funcs(pd()->rnn_.use_layer_packed_gemm, gemm_layer_func,
+ weights_layer_assign_func);
+
+ switch (pd()->cell_kind()) {
+ case alg_kind::vanilla_lstm:
+ cell_func = &class_name::cell_execution;
+ if (aprop == prop_kind::forward) {
+ if (mayiuse(avx512_core))
+ rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx512_core, src_type>(
+ pd()->rnn_, pd()->attr());
+ else if (mayiuse(avx2))
+ rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<avx2, src_type>(
+ pd()->rnn_, pd()->attr());
+ else if (mayiuse(sse42))
+ rnn_postgemm_ = new jit_uni_lstm_postgemm_kernel_fwd<sse42, src_type>(
+ pd()->rnn_, pd()->attr());
+ assert(rnn_postgemm_ != nullptr);
+ rnn_postgemm_->init();
+ }
+ elemwise_func = &class_name::lstm_elemwise;
+ break;
+ case alg_kind::vanilla_rnn: // @todo switch on cell kind
+ cell_func = &class_name::cell_execution;
+ elemwise_func = &class_name::rnn_elemwise;
+ switch (pd()->activation_kind()) {
+ case alg_kind::eltwise_relu:
+ activation_func = &activation<alg_kind::eltwise_relu, aprop>;
+ break;
+ case alg_kind::eltwise_tanh:
+ activation_func = &activation<alg_kind::eltwise_tanh, aprop>;
+ break;
+ case alg_kind::eltwise_logistic:
+ activation_func = &activation<alg_kind::eltwise_logistic, aprop>;
+ break;
+ default: break;
+ }
+ break;
+ case alg_kind::vanilla_gru:
+ cell_func = &class_name::cell_execution_gru;
+ break;
+ case alg_kind::gru_linear_before_reset:
+ cell_func = &class_name::cell_execution_gru_lbr;
+ elemwise_func = &class_name::gru_lbr_elemwise;
+ break;
+ default: break;
+ }
+
+ grid_computation = &class_name::linear_execution;
+
+ size_t scratchpad_size, workspace_size;
+ rnn_utils::set_offsets(pd()->rnn_, ws_gates_offset_, ws_states_offset_,
+ ws_c_states_offset_, ws_diff_states_offset_,
+ ws_grid_comp_offset_, ws_cell_comp_offset_,
+ ws_bias_offset_, scratchpad_size, workspace_size);
+ }
+
+ ~_ref_rnn_common_t() {}
+
+ // typedef typename prec_traits::type data_t;
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ execute_(ctx);
+ return status::success;
+ }
+
+private:
+ void execute_(const exec_ctx_t &ctx) const;
+ rnn_grid_execution_sig(linear_execution);
+ rnn_cell_execution_sig(cell_execution);
+ rnn_cell_execution_sig(cell_execution_gru);
+ rnn_cell_execution_sig(cell_execution_gru_lbr);
+ rnn_elemwise_sig(rnn_elemwise);
+ rnn_elemwise_sig(lstm_elemwise);
+ rnn_elemwise_sig(gru_lbr_elemwise);
+ rnn_gemm_sig(gemm);
+ rnn_gemm_sig(packed_gemm);
+ rnn_bias_prepare_sig(bias_prepare);
+ rnn_bias_finalize_sig(bias_finalize);
+ rnn_weights_assign_sig(assign_weights);
+ rnn_weights_assign_sig(assign_packed_weights);
+
+ float (*activation_func)(float dd, float s, float alpha, float clipping);
+
+ void copy_init_layer(const rnn_utils::rnn_conf_t &rnn,
+ src_data_t *ws_states_, float *ws_diff_states_,
+ const src_data_t *xt_, const float *diff_dst_layer) const;
+
+ template <typename input_data_t>
+ void copy_init_iter(const rnn_utils::rnn_conf_t &rnn,
+ src_data_t *ws_states_, float *ws_c_states, float *ws_diff_states_,
+ const input_data_t *firstit_states_,
+ const float *diff_dst_iter) const;
+
+ template <typename dst_data_t>
+ void copy_res_layer(const rnn_utils::rnn_conf_t &rnn,
+ dst_data_t *dst_layer_, float *diff_src_layer,
+ const src_data_t *ws_states_, const float *ws_diff_states_) const;
+
+ template <typename output_data_t>
+ void copy_res_iter(const rnn_utils::rnn_conf_t &rnn,
+ output_data_t *dst_iter_, float *diff_src_iter,
+ const src_data_t *ws_states_, float *ws_c_states,
+ const float *ws_diff_states_) const;
+
+ void gates_reduction(const rnn_utils::rnn_conf_t &rnn,
+ const acc_data_t *ws_gates_, float *diff_bias_) const;
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+
+ size_t ws_gates_offset_;
+ size_t ws_states_offset_;
+ size_t ws_c_states_offset_;
+ size_t ws_bias_offset_;
+ size_t ws_diff_states_offset_;
+ size_t ws_grid_comp_offset_;
+ size_t ws_cell_comp_offset_;
+ jit_uni_rnn_postgemm_kernel *rnn_postgemm_;
+
+ grid_execution_f grid_computation;
+ cell_execution_f cell_func;
+
+ bias_prepare_t bias_preparation_func;
+ bias_finalize_t bias_finalization_func;
+ weights_assign_t weights_layer_assign_func;
+ weights_assign_t weights_iter_assign_func;
+
+ gemm_t gemm_layer_func;
+ gemm_t gemm_iter_func;
+ elemwise_f elemwise_func;
+};
+
+using ref_rnn_fwd_f32_t = _ref_rnn_common_t<prop_kind::forward, data_type::f32, data_type::f32>;
+using ref_rnn_bwd_f32_t = _ref_rnn_common_t<prop_kind::backward, data_type::f32, data_type::f32>;
+using ref_rnn_fwd_u8s8_t = _ref_rnn_common_t<prop_kind::forward, data_type::u8, data_type::s8>;
+}
+}
+}
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp
new file mode 100644
index 0000000000..597c63e3f8
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_reorders.hpp
@@ -0,0 +1,380 @@
+/*******************************************************************************
+ * Copyright 2018 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#ifndef CPU_RNN_REORDERS_HPP
+#define CPU_RNN_REORDERS_HPP
+
+#include <assert.h>
+
+#include "type_helpers.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+#include "simple_q10n.hpp"
+#include "cpu_reorder_pd.hpp"
+#include "../gemm/os_blas.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t type_i, data_type_t type_o>
+struct rnn_data_reorder_t : public cpu_primitive_t {
+ struct pd_t : public cpu_reorder_pd_t {
+ using cpu_reorder_pd_t::cpu_reorder_pd_t;
+
+ DECLARE_COMMON_PD_T("rnn_data_reorder", rnn_data_reorder_t);
+
+ static status_t create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+ const memory_desc_wrapper id(src_md), od(dst_md);
+ bool args_ok = true
+ && id.data_type() == type_i
+ && od.data_type() == type_o
+ && id.matches_one_of_tag(format_tag::tnc, format_tag::ldsnc)
+ && od == id;
+ if (!args_ok) return status::invalid_arguments;
+
+ auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
+ dst_md);
+ if (_pd == nullptr) return out_of_memory;
+ if (_pd->init() != success) { delete _pd; return unimplemented; }
+ return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
+ }
+ };
+
+private:
+ typedef typename prec_traits<type_i>::type in_data_t;
+ typedef typename prec_traits<type_o>::type out_data_t;
+
+ rnn_data_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
+ auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO);
+ const memory_desc_wrapper &input_d = pd()->src_md();
+ const memory_desc_wrapper &output_d = pd()->dst_md();
+ const size_t nelems = input_d.nelems();
+ const float scale = pd()->attr()->rnn_data_qparams_.scale_;
+ const float shift = pd()->attr()->rnn_data_qparams_.shift_;
+
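+ /* Elementwise affine quantization: out[i] = qz(in[i] * scale + shift),
+ * i.e. apply the scale and shift in f32, then round and saturate to
+ * out_data_t. */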
+ parallel_nd(nelems, [&](size_t i) {
+ float in = (float)input[input_d.off_l(i)] * scale + shift;
+ output[output_d.off_l(i)] = qz_a1b0<float, out_data_t>()(in);
+ });
+
+ return status::success;
+ }
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <data_type_t type_i, data_type_t type_o>
+struct rnn_weights_reorder_t : public cpu_primitive_t {
+ struct pd_t : public cpu_reorder_pd_t {
+ using cpu_reorder_pd_t::cpu_reorder_pd_t;
+
+ DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
+
+ static status_t create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+#if !USE_MKL_PACKED_GEMM
+ return status::unimplemented;
+#endif
+ const memory_desc_wrapper id(src_md), od(dst_md);
+ bool args_ok = true
+ && id.data_type() == type_i
+ && od.data_type() == type_o
+ && od.format_kind() == format_kind::rnn_packed
+ && od.rnn_packed_desc().format == mkldnn_ldigo_p
+ && od.rnn_packed_desc().n_parts == 1
+ && attr != nullptr;
+ if (!args_ok) return status::invalid_arguments;
+
+ format_tag_t itag = id.matches_one_of_tag(
+ format_tag::ldigo, format_tag::ldgoi);
+ if (itag == format_tag::undef) return status::invalid_arguments;
+
+ const int mask = attr->rnn_weights_qparams_.mask_;
+ if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
+
+ auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
+ dst_md);
+ if (_pd == nullptr) return out_of_memory;
+ _pd->itag_ = itag;
+ if (_pd->init() != success) { delete _pd; return unimplemented; }
+ return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
+ }
+
+ status_t init() {
+ status_t status = cpu_reorder_pd_t::init();
+ if (status != status::success) return status;
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ format_tag_t itag_;
+
+ private:
+ void init_scratchpad() {
+ const memory_desc_wrapper id(src_md());
+ const size_t nelems = id.nelems();
+ const auto &dims = id.dims();
+
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ size_t quantization_size = sizeof(int8_t) * nelems;
+ size_t reduction_size = itag_ == ldigo
+ ? sizeof(int32_t) * mkldnn_get_max_threads() * dims[0]
+ * dims[1] * dims[3] * dims[4]
+ : 0;
+ scratchpad.book(
+ key_reorder_rnn_weights_quantization, quantization_size);
+ scratchpad.book(key_reorder_rnn_weights_reduction, reduction_size);
+ }
+ };
+
+private:
+ typedef typename prec_traits<type_i>::type in_data_t;
+ typedef typename prec_traits<type_o>::type out_data_t;
+
+ rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+#if USE_MKL_PACKED_GEMM
+ auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
+ auto output = CTX_OUT_MEM(char *, MKLDNN_ARG_TO);
+ const memory_desc_wrapper &input_d = pd()->src_md();
+ const memory_desc_wrapper &output_d = pd()->dst_md();
+ const auto &dims = input_d.dims();
+
+ const int L = dims[0];
+ const int D = dims[1];
+ const int I = dims[2];
+ const int G = dims[3];
+ const int O = dims[4];
+
+ const bool is_igo = pd()->itag_ == format_tag::ldigo;
+
+ /* Quantize input & compute compensation */
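+ /* The reduction below accumulates, for every (l, d, g, o), the sum of the
+ * quantized weights over the input channels; it is stored behind the packed
+ * weights (at offset_compensation) so the shift applied when activations
+ * are quantized to u8 can be compensated for later. */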
+ auto quantized = (int8_t * __restrict)scratchpad(ctx).template get<void>(
+ memory_tracking::names::key_reorder_rnn_weights_quantization);
+ auto reduction = (int32_t * __restrict)scratchpad(ctx).template get<void>(
+ memory_tracking::names::key_reorder_rnn_weights_reduction);
+ float *comp = reinterpret_cast<float *>(
+ output + output_d.rnn_packed_desc().offset_compensation);
+ const float *scales = pd()->attr()->rnn_weights_qparams_.scales_;
+ const int mask = pd()->attr()->rnn_weights_qparams_.mask_;
+
+ if (is_igo) {
+ int nthr = mkldnn_get_max_threads();
+ int LD_nthr = nstl::min(L * D, nthr);
+ int I_nthr = nstl::min(I, nthr / LD_nthr);
+ parallel(nthr, [&](const int ithr, const int nthr) {
+ int LD_ithr = -1, LD_s = -1, LD_e = -1;
+ int I_ithr = -1, I_s = -1, I_e = -1;
+ if (ithr < LD_nthr * I_nthr) {
+ LD_ithr = ithr % LD_nthr;
+ I_ithr = ithr / LD_nthr;
+ balance211(L * D, LD_nthr, LD_ithr, LD_s, LD_e);
+ balance211(I, I_nthr, I_ithr, I_s, I_e);
+ }
+ int32_t *comp_ithr = reduction + I_ithr * L * D * G * O;
+ for (int ld = LD_s; ld < LD_e; ld++) {
+ for (int go = 0; go < G * O; go++)
+ comp_ithr[ld * G * O + go] = 0;
+ for (int i = I_s; i < I_e; i++) {
+ PRAGMA_OMP_SIMD()
+ for (int go = 0; go < G * O; go++) {
+ const float s = scales[(mask == 0) ? 0 : go];
+ int8_t q = qz_b0<in_data_t, out_data_t>()(
+ input[ld * I * G * O + i * G * O + go], s);
+ quantized[ld * I * G * O + i * G * O + go]
+ = (int32_t)q;
+ comp_ithr[ld * G * O + go] += (int32_t)q;
+ }
+ }
+ }
+ });
+ parallel_nd(L * D * G * O,
+ [&](int s) { comp[s] = saturate<float>(reduction[s]); });
+ for (int i = 1; i < I_nthr; i++) {
+ parallel_nd(L * D * G * O, [&](int s) {
+ comp[s] += saturate<float>(
+ reduction[i * L * D * G * O + s]);
+ });
+ }
+ } else {
+ parallel_nd(L * D, G * O, [&](int ld, int go) {
+ int32_t compensation = 0;
+ const float s = scales[(mask == 0) ? 0 : go];
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < I; i++) {
+ int8_t q = qz_b0<in_data_t, out_data_t>()(
+ input[ld * G * O * I + go * I + i], s);
+ compensation += (int32_t)q;
+ quantized[ld * G * O * I + go * I + i] = q;
+ }
+ comp[ld * G * O + go] = saturate<float>(compensation);
+ });
+ }
+
+ /* Pack */
+ auto off_igo = [&](int l, int d, int i, int g, int o) {
+ return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
+ };
+ auto off_goi = [&](int l, int d, int i, int g, int o) {
+ return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
+ };
+ int n_parts = output_d.rnn_packed_desc().n_parts;
+ const size_t *size_packed_cell
+ = output_d.rnn_packed_desc().part_pack_size;
+ const int *parts = output_d.rnn_packed_desc().parts;
+ const int n = output_d.rnn_packed_desc().n;
+ char *to_pack = output;
+ for (int l = 0; l < L; l++) {
+ for (int d = 0; d < D; d++) {
+ for (int p = 0; p < n_parts; p++) {
+ int g = (p > 0) ? parts[p - 1] : 0;
+ int m_p = parts[p] * O;
+ int k_p = I;
+ cblas_gemm_s8u8s32_pack(CblasColMajor, CblasAMatrix,
+ is_igo ? CblasNoTrans : CblasTrans, m_p, n, k_p,
+ &quantized[is_igo ? off_igo(l, d, 0, g, 0) :
+ off_goi(l, d, g, 0, 0)],
+ is_igo ? G * O : I, to_pack);
+ to_pack += size_packed_cell[p];
+ }
+ }
+ }
+#endif
+ return status::success;
+ }
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+template <>
+struct rnn_weights_reorder_t<data_type::f32, data_type::f32>
+ : public cpu_primitive_t {
+ struct pd_t : public cpu_reorder_pd_t {
+ using cpu_reorder_pd_t::cpu_reorder_pd_t;
+
+ DECLARE_COMMON_PD_T("rnn_weights_reorder", rnn_weights_reorder_t);
+
+ static status_t create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+#if !USE_MKL_PACKED_GEMM
+ return status::unimplemented;
+#endif
+ const memory_desc_wrapper id(src_md), od(dst_md);
+ bool args_ok = true
+ && id.data_type() == data_type::f32
+ && od.data_type() == data_type::f32
+ && od.format_kind() == format_kind::rnn_packed
+ && utils::one_of(od.rnn_packed_desc().format,
+ mkldnn_ldigo_p, mkldnn_ldgoi_p)
+ && attr->has_default_values();
+ if (!args_ok) return status::invalid_arguments;
+
+ format_tag_t itag = id.matches_one_of_tag(
+ format_tag::ldigo, format_tag::ldgoi);
+ if (itag == format_tag::undef) return status::invalid_arguments;
+
+ const int mask = attr->rnn_weights_qparams_.mask_;
+ if (!utils::one_of(mask, 0, 3)) return status::unimplemented;
+
+ auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
+ dst_md);
+ if (_pd == nullptr) return out_of_memory;
+ if (_pd->init() != success) { delete _pd; return unimplemented; }
+ _pd->itag_ = itag;
+ return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
+ }
+
+ format_tag_t itag_;
+ };
+
+private:
+ rnn_weights_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+#if USE_MKL_PACKED_GEMM
+ auto input = CTX_IN_MEM(const float *, MKLDNN_ARG_FROM);
+ auto output = CTX_OUT_MEM(float *, MKLDNN_ARG_TO);
+ const memory_desc_wrapper &input_d = pd()->src_md();
+ const memory_desc_wrapper &output_d = pd()->dst_md();
+ const auto &dims = input_d.dims();
+ const rnn_packed_desc_t &rnn_pdata = output_d.rnn_packed_desc();
+ const int L = dims[0];
+ const int D = dims[1];
+ const int I = dims[2];
+ const int G = dims[3];
+ const int O = dims[4];
+
+ /* Pack */
+ bool cross_case = false
+ || (pd()->itag_ == format_tag::ldigo
+ && rnn_pdata.format == mkldnn_ldgoi_p)
+ || (pd()->itag_ == format_tag::ldgoi
+ && rnn_pdata.format == mkldnn_ldigo_p);
+ auto trans = cross_case ? CblasTrans : CblasNoTrans;
+ int n_parts = rnn_pdata.n_parts;
+ const size_t *size_packed_cell = rnn_pdata.part_pack_size;
+ const int *parts = rnn_pdata.parts;
+ const int n = rnn_pdata.n;
+
+ const bool is_igo = pd()->itag_ == format_tag::ldigo;
+ auto off_igo = [&](int l, int d, int i, int g, int o) {
+ return l * D * I * G * O + d * I * G * O + i * G * O + g * O + o;
+ };
+ auto off_goi = [&](int l, int d, int i, int g, int o) {
+ return l * D * G * O * I + d * G * O * I + g * O * I + o * I + i;
+ };
+ for (int l = 0; l < L; l++) {
+ for (int d = 0; d < D; d++) {
+ for (int p = 0; p < n_parts; p++) {
+ int g = (p > 0) ? parts[p - 1] : 0;
+ int m_p = is_igo ? parts[p] * O : I;
+ int k_p = is_igo ? I : parts[p] * O;
+ int ld = is_igo ? G * O : I;
+ cblas_sgemm_pack(CblasColMajor, CblasAMatrix, trans, m_p, n,
+ k_p, 1.0f, &input[is_igo ? off_igo(l, d, 0, g, 0) :
+ off_goi(l, d, 0, g, 0)],
+ ld, output);
+ output += size_packed_cell[p] / sizeof(float);
+ }
+ }
+ }
+#endif
+ return status::success;
+ }
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+} // namespace cpu
+} // namespace impl
+} // namespace mkldnn
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp
new file mode 100644
index 0000000000..1d60415cbc
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.cpp
@@ -0,0 +1,426 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "c_types_map.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+
+#include "ref_rnn.hpp"
+#include "rnn_utils.hpp"
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::utils;
+using namespace rnn_utils;
+using namespace format_tag;
+using namespace rnn_packed_format;
+using namespace data_type;
+
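+/* ldigo/ldgoi denote the 5D weights layouts {layers, directions, input
+ * channels, gates, output channels}; the checks below detect them from the
+ * plain (non-blocked) strides. */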
+bool rnn_utils::is_ldigo(const memory_desc_wrapper &md) {
+ if (md.format_kind() != format_kind::blocked)
+ return false;
+
+ auto blk = md.blocking_desc();
+ auto str = blk.strides;
+ auto dims = md.dims();
+ return md.ndims() == 5 && blk.inner_nblks == 0 && str[4] == 1
+ && str[3] == dims[4] && str[1] == str[2] * dims[2]
+ && str[0] == str[1] * dims[1];
+};
+
+bool rnn_utils::is_ldgoi(const memory_desc_wrapper &md) {
+ if (md.format_kind() != format_kind::blocked)
+ return false;
+
+ auto blk = md.blocking_desc();
+ auto str = blk.strides;
+ auto dims = md.dims();
+ return md.ndims() == 5 && blk.inner_nblks == 0 && str[2] == 1
+ && str[3] == dims[4] * str[4] && str[1] == str[3] * dims[3]
+ && str[0] == str[1] * dims[1];
+};
+
+void rnn_utils::init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
+ const memory_desc_wrapper &src_layer_d,
+ const memory_desc_wrapper &src_iter_d,
+ const memory_desc_wrapper &weights_layer_d,
+ const memory_desc_wrapper &weights_iter_d,
+ const memory_desc_wrapper &dst_layer_d) {
+ rnn.is_fwd = utils::one_of(rd.prop_kind, prop_kind::forward_training,
+ prop_kind::forward_inference);
+ rnn.is_training = utils::one_of(
+ rd.prop_kind, prop_kind::forward_training, prop_kind::backward);
+ rnn.is_lbr = rd.cell_desc.cell_kind == mkldnn_gru_linear_before_reset;
+
+ switch (rd.direction) {
+ case mkldnn_unidirectional_left2right: rnn.exec_dir = l2r; break;
+ case mkldnn_unidirectional_right2left: rnn.exec_dir = r2l; break;
+ case mkldnn_bidirectional_concat: rnn.exec_dir = bi_concat; break;
+ case mkldnn_bidirectional_sum: rnn.exec_dir = bi_sum; break;
+ default: break;
+ }
+
+ if (everyone_is(f32, src_layer_d.data_type(), dst_layer_d.data_type(),
+ weights_layer_d.data_type()))
+ rnn.dt_conf = all_f32;
+ else if (dst_layer_d.data_type() == u8) {
+ if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
+ rnn.dt_conf = u8u8u8u8;
+ else
+ rnn.dt_conf = f32u8f32u8;
+ } else {
+ if (IMPLICATION(src_iter_d.md_, src_iter_d.data_type() == u8))
+ rnn.dt_conf = u8u8u8f32;
+ else
+ rnn.dt_conf = f32u8f32f32;
+ }
+
+ rnn.n_layer = weights_layer_d.dims()[0];
+ rnn.n_iter = src_layer_d.dims()[0];
+ rnn.n_dir = weights_layer_d.dims()[1];
+ rnn.n_gates = weights_layer_d.dims()[3];
+ rnn.n_states = mkldnn_rnn_cell_get_states_count(&rd.cell_desc);
+ rnn.n_bias = rnn.n_gates + rnn.is_lbr;
+ rnn.mb = src_layer_d.dims()[1];
+ rnn.sic = weights_iter_d.dims()[2];
+ rnn.slc = weights_layer_d.dims()[2];
+ rnn.dic = weights_layer_d.dims()[4];
+ rnn.dlc = dst_layer_d.dims()[2];
+
+ rnn.gates_ld = rnn.dic * rnn.n_gates;
+ rnn.gates_nld = rnn.mb;
+ rnn.states_nld = rnn.mb;
+
+ /* Set the correct number of weights parts */
+ bool is_orig_gru = rd.cell_desc.cell_kind == alg_kind::vanilla_gru;
+ rnn.n_parts_weights_layer = 1;
+ rnn.parts_weights_layer[0] = rnn.n_gates;
+ rnn.parts_weights_layer[1] = 0;
+
+ rnn.n_parts_weights_iter = is_orig_gru ? 2 : 1;
+ rnn.parts_weights_iter[0] = is_orig_gru ? 2 : rnn.n_gates;
+ rnn.parts_weights_iter[1] = is_orig_gru ? 1 : 0;
+
+ rnn.n_parts_bias = 1;
+ rnn.parts_bias[0] = rnn.n_bias;
+ rnn.parts_bias[1] = 0;
+
+ /* Decide which gemm implementation to use: packed/non-packed, jit/cblas,
+ * and whether to merge gemm across iterations */
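+ /* Merging means issuing a single gemm over mb * n_iter columns instead of
+ * n_iter separate gemm calls with mb columns each. */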
+ bool is_int8 = rnn.dt_conf != all_f32;
+ rnn.merge_gemm_layer = ((rnn.is_fwd && rnn.mb < 128) || !rnn.is_fwd)
+ || is_int8;
+ bool is_gru = utils::one_of(rd.cell_desc.cell_kind, alg_kind::vanilla_gru,
+ alg_kind::gru_linear_before_reset);
+ rnn.merge_gemm_iter = !(rnn.is_fwd || is_gru) || is_int8;
+ bool is_inference = !rnn.is_training;
+
+ rnn.use_jit_gemm = !mayiuse(avx512_mic)
+ && ((is_inference && (rnn.n_layer > 1 || rnn.mb < 100))
+ || (rnn.is_training && rnn.dic < 500));
+
+ /* Decide whether to copy bias */
+ rnn.copy_bias = rnn.dt_conf != all_f32;
+
+#if USE_MKL_PACKED_GEMM
+ rnn.use_layer_packed_gemm
+ = (weights_layer_d.format_kind() == format_kind::any
+ && rnn.slc > 760 && rnn.dic > 760 && is_inference)
+ || is_int8; // packed gemm is the only supported option for int8
+ rnn.use_iter_packed_gemm
+ = (weights_iter_d.format_kind() == format_kind::any && rnn.sic > 760
+ && rnn.dic > 760 && is_inference)
+ || is_int8;
+#else
+ rnn.use_layer_packed_gemm = false;
+ rnn.use_iter_packed_gemm = false;
+#endif
+
+ /* Set packed gemm sizes */
+ if (rnn.use_layer_packed_gemm) {
+ rnn.weights_layer_pack_size = 0;
+ for (int p = 0; p < rnn.n_parts_weights_layer; p++) {
+ int m_p = rnn.is_fwd
+ ? (rnn.parts_weights_layer[p] * rnn.dic)
+ : rnn.slc;
+ int k_p = rnn.is_fwd
+ ? rnn.slc
+ : (rnn.parts_weights_layer[p] * rnn.dic);
+ int n_p = rnn.merge_gemm_layer ? rnn.mb * rnn.n_iter : rnn.mb;
+
+#if USE_MKL_PACKED_GEMM
+ if (rnn.dt_conf == all_f32)
+ rnn.part_weights_layer_pack_size[p] = cblas_sgemm_pack_get_size(
+ CblasAMatrix, m_p, n_p, k_p);
+ else
+ rnn.part_weights_layer_pack_size[p]
+ = cblas_gemm_s8u8s32_pack_get_size(
+ CblasAMatrix, m_p, n_p, k_p);
+#else
+ UNUSED(m_p);
+ UNUSED(k_p);
+ UNUSED(n_p);
+ rnn.part_weights_layer_pack_size[p] = 0;
+#endif
+ rnn.weights_layer_pack_size += rnn.n_layer * rnn.n_dir
+ * rnn.part_weights_layer_pack_size[p];
+ }
+ rnn.weights_layer_comp_offset = rnn.weights_layer_pack_size;
+ rnn.weights_layer_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
+ * rnn.n_dir * rnn.n_gates * rnn.dlc * sizeof(float);
+ }
+
+ if (rnn.use_iter_packed_gemm) {
+ rnn.weights_iter_pack_size = 0;
+ for (int p = 0; p < rnn.n_parts_weights_iter; p++) {
+ int m_p = rnn.is_fwd ? (rnn.parts_weights_iter[p] * rnn.dic) :
+ rnn.sic;
+ int k_p = rnn.is_fwd ? rnn.sic :
+ (rnn.parts_weights_iter[p] * rnn.dic);
+ int n_p = rnn.merge_gemm_iter ? rnn.mb * rnn.n_iter : rnn.mb;
+
+#if USE_MKL_PACKED_GEMM
+ if (rnn.dt_conf == all_f32)
+ rnn.part_weights_iter_pack_size[p] = cblas_sgemm_pack_get_size(
+ CblasAMatrix, m_p, n_p, k_p);
+ else
+ rnn.part_weights_iter_pack_size[p]
+ = cblas_gemm_s8u8s32_pack_get_size(
+ CblasAMatrix, m_p, n_p, k_p);
+#else
+ UNUSED(m_p);
+ UNUSED(k_p);
+ UNUSED(n_p);
+ rnn.part_weights_iter_pack_size[p] = 0;
+#endif
+ rnn.weights_iter_pack_size += rnn.n_layer * rnn.n_dir
+ * rnn.part_weights_iter_pack_size[p];
+ }
+ rnn.weights_iter_comp_offset = rnn.weights_iter_pack_size;
+ rnn.weights_iter_pack_size += rnn.dt_conf == all_f32 ? 0 : rnn.n_layer
+ * rnn.n_dir * rnn.n_gates * rnn.dic * sizeof(float);
+ }
+
+}
+
+void rnn_utils::set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
+ const memory_desc_wrapper &weights_layer_d,
+ const memory_desc_wrapper &weights_iter_d,
+ const memory_desc_wrapper &diff_weights_layer_d,
+ const memory_desc_wrapper &diff_weights_iter_d) {
+
+ /* Set leading dimensions for input weights arrays depending on input format
+ */
+ rnn.weights_layer_is_packed
+ = weights_layer_d.format_kind() == format_kind::rnn_packed;
+ rnn.weights_iter_is_packed
+ = weights_iter_d.format_kind() == format_kind::rnn_packed;
+
+ auto set_dims = [&](const memory_desc_wrapper &md, int &ld, int &nld) {
+ ld = 0; nld = 0;
+ if (md.is_blocking_desc()) {
+ if (is_ldigo(md)) {
+ ld = (int)md.blocking_desc().strides[2];
+ nld = md.dims()[2];
+ } else if (is_ldgoi(md)) {
+ ld = (int)md.blocking_desc().strides[4];
+ nld = md.dims()[3] * md.dims()[4];
+ } else
+ assert(!"unsupported weights format");
+ }
+ };
+ set_dims(weights_layer_d, rnn.weights_layer_ld, rnn.weights_layer_nld);
+ set_dims(weights_iter_d, rnn.weights_iter_ld, rnn.weights_iter_nld);
+ if (!rnn.is_fwd) {
+ set_dims(diff_weights_layer_d, rnn.diff_weights_layer_ld,
+ rnn.diff_weights_layer_nld);
+ set_dims(diff_weights_iter_d, rnn.diff_weights_iter_ld,
+ rnn.diff_weights_iter_nld);
+ }
+
+ int sizeof_states_dt
+ = rnn.dt_conf == all_f32 ? sizeof(float) : sizeof(uint8_t);
+ rnn.states_ws_ld
+ = get_good_ld(nstl::max(rnn.slc, nstl::max(rnn.sic, rnn.dic)),
+ sizeof_states_dt);
+ rnn.gates_ws_ld = get_good_ld(rnn.gates_ld, sizeof(float));
+
+ /* Set workspace sizes to store:
+ * states to compute a pass
+ * diff states to compute the bwd pass (training only)
+ * intermediate results from the gates
+ */
+ rnn.use_workspace = rnn.is_training;
+ rnn.ws_states_size = (size_t)(rnn.n_layer + 1) * rnn.n_dir
+ * (rnn.n_iter + 1) * rnn.mb * rnn.states_ws_ld * sizeof_states_dt;
+ bool is_lstm = rd.cell_desc.cell_kind == mkldnn_vanilla_lstm;
+ rnn.ws_c_states_size = is_lstm
+ ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1) * rnn.mb
+ * rnn.states_ws_ld * sizeof(float)
+ : 0;
+ rnn.ws_diff_states_size = rnn.is_training
+ ? (size_t)(rnn.n_layer + 1) * rnn.n_dir * (rnn.n_iter + 1)
+ * (rnn.n_states + 1) * rnn.mb * rnn.states_ws_ld
+ * sizeof(float)
+ : (size_t)0;
+ rnn.ws_gates_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_iter * rnn.mb
+ * rnn.gates_ws_ld * sizeof(float);
+
+ /* set other sizes */
+ rnn.ws_per_cell = (size_t)rnn.is_lbr * rnn.mb * rnn.dic * sizeof(float);
+ rnn.ws_cell_comp_size
+ = rnn.is_lbr || rnn.dt_conf != all_f32
+ ? (size_t) rnn.gates_nld * rnn.gates_ws_ld * sizeof(float)
+ : 0;
+ rnn.ws_grid_comp_size = (size_t)rnn.is_lbr * rnn.is_training * rnn.n_layer
+ * rnn.n_dir * rnn.n_iter * rnn.ws_per_cell * sizeof(float);
+ rnn.ws_bias_size = (size_t)rnn.n_layer * rnn.n_dir * rnn.n_bias * rnn.dic
+ * sizeof(float);
+}
+
+int rnn_utils::get_good_ld(int dim, int sizeof_dt) {
+ // we want matrices' leading dimensions to be 64-byte aligned,
+ // and not divisible by 256 to avoid 4K aliasing effects
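+ // e.g. dim = 512 floats (sizeof_dt = 4): rnd_up(512, 16) = 512, which is
+ // a multiple of 256, so 528 is returned instead.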
+ int ld = rnd_up(dim, 64 / sizeof_dt);
+ return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
+}
+
+void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
+ size_t &ws_states_offset, size_t &ws_c_states_offset,
+ size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
+ size_t &ws_cell_comp_offset, size_t &ws_bias_offset,
+ size_t &scratchpad_size, size_t &workspace_size) {
+
+ const size_t page_size = 4096; // 2097152;
+ size_t current_offset;
+ /* Mandatory workspaces: go to workspace if use_workspace, scratchpad
+ * otherwise */
+ current_offset = 0; // assumes the workspace base pointer is page aligned
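+ /* Layout: gates | states | c_states | diff_states | grid_comp | cell_comp,
+ * each section starting on a 4 KiB page boundary. */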
+ ws_gates_offset = current_offset;
+ current_offset += rnn.ws_gates_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_states_offset = current_offset;
+ current_offset += rnn.ws_states_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_c_states_offset = current_offset;
+ current_offset += rnn.ws_c_states_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_diff_states_offset = current_offset;
+ current_offset += rnn.ws_diff_states_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_grid_comp_offset = current_offset;
+ current_offset += rnn.ws_grid_comp_size;
+
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_cell_comp_offset = current_offset;
+ current_offset += rnn.ws_cell_comp_size;
+
+ workspace_size = rnn.use_workspace ? current_offset : 0;
+
+ /* Optional scratchpads */
+ // Assumes the scratchpad base pointer is page aligned.
+ // If use_workspace, only the following goes to the scratchpad;
+ // otherwise, everything goes to the scratchpad and we keep incrementing
+ // the same offset.
+ current_offset = rnn.use_workspace ? 0 : current_offset;
+
+ if (rnn.copy_bias) {
+ current_offset = utils::rnd_up(current_offset, page_size);
+ ws_bias_offset = current_offset;
+ current_offset += rnn.ws_bias_size;
+ }
+
+ scratchpad_size = current_offset;
+}
+
+void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
+ size_t &scratchpad_size, size_t &workspace_size) {
+ size_t ws_gates_offset, ws_states_offset, ws_c_states_offset,
+ ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
+ ws_bias_offset;
+ set_offsets(rnn, ws_gates_offset, ws_states_offset, ws_c_states_offset,
+ ws_diff_states_offset, ws_grid_comp_offset, ws_cell_comp_offset,
+ ws_bias_offset, scratchpad_size, workspace_size);
+}
+
+status_t rnn_utils::set_good_strides(
+ memory_desc_t &weights_md, format_tag_t tag) {
+ auto &strides = weights_md.format_desc.blocking.strides;
+ auto dims = weights_md.dims;
+
+ if (tag == ldigo) {
+ strides[2] = rnn_utils::get_good_ld((int)strides[2],
+ (int)types::data_type_size(weights_md.data_type));
+ strides[1] = dims[2] * strides[2];
+ strides[0] = dims[1] * strides[1];
+ } else if (tag == ldgoi) {
+ strides[4] = rnn_utils::get_good_ld((int)strides[4],
+ (int)types::data_type_size(weights_md.data_type));
+ strides[3] = dims[4] * strides[4];
+ strides[1] = dims[3] * strides[3];
+ strides[0] = dims[1] * strides[1];
+ } else
+ return status::unimplemented;
+
+ return status::success;
+}
+
+status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
+ memory_desc_t &weights_md, bool is_iter) {
+ using namespace format_tag;
+ bool use_packed_gemm = is_iter
+ ? rnn.use_iter_packed_gemm
+ : rnn.use_layer_packed_gemm;
+ if (use_packed_gemm) {
+ weights_md.format_kind = format_kind::rnn_packed;
+ rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc;
+ rnn_pdata.format = rnn.is_fwd ? mkldnn_ldigo_p : mkldnn_ldgoi_p;
+ if (is_iter) {
+ rnn_pdata.n = rnn.mb;
+ rnn_pdata.n_parts = rnn.n_parts_weights_iter;
+ array_copy(rnn_pdata.parts, rnn.parts_weights_iter,
+ MKLDNN_RNN_MAX_N_PARTS);
+ array_copy(rnn_pdata.part_pack_size,
+ rnn.part_weights_iter_pack_size, MKLDNN_RNN_MAX_N_PARTS);
+ rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset;
+ rnn_pdata.size = rnn.weights_iter_pack_size;
+ } else {
+ rnn_pdata.n = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb;
+ rnn_pdata.n_parts = rnn.n_parts_weights_layer;
+ array_copy(rnn_pdata.parts, rnn.parts_weights_layer,
+ MKLDNN_RNN_MAX_N_PARTS);
+ array_copy(rnn_pdata.part_pack_size,
+ rnn.part_weights_layer_pack_size, MKLDNN_RNN_MAX_N_PARTS);
+ rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset;
+ rnn_pdata.size = rnn.weights_layer_pack_size;
+ }
+ } else {
+ CHECK(memory_desc_init_by_tag(weights_md, rnn.is_fwd ? ldigo : ldgoi));
+ // Adjust strides for good leading dimension in GEMM
+ CHECK(set_good_strides(weights_md, rnn.is_fwd ? ldigo : ldgoi));
+ }
+ return status::success;
+}
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp
new file mode 100644
index 0000000000..99eb787a64
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/rnn/rnn_utils.hpp
@@ -0,0 +1,225 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef RNN_UTILS_HPP
+#define RNN_UTILS_HPP
+
+#include "mkldnn.h"
+
+#include "cpu_rnn_pd.hpp"
+
+
+#define rnn_elemwise_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, acc_data_t *ws_gates_, \
+ src_data_t *states_t_l_, float *c_states_t_l_, \
+ src_data_t *states_tm1_l_, float *c_states_tm1_l_, \
+ float *diff_states_t_l_, float *diff_states_t_lp1_, \
+ float *diff_states_tp1_l_, float *bias_, float *ws_grid_, \
+ float *ws_cell_) const
+
+#define rnn_cell_execution_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, src_data_t *states_t_l_, \
+ float *c_states_t_l_, float *diff_states_t_l_, \
+ weights_data_t **w_layer_, weights_data_t **w_iter_, \
+ float **bias_, src_data_t *states_t_lm1_, \
+ src_data_t *states_tm1_l_, float *c_states_tm1_l_, \
+ float *diff_states_t_lp1_, float *diff_states_tp1_l_, \
+ float *diff_w_layer_, float *diff_w_iter_, float *diff_bias_, \
+ acc_data_t *ws_gates_, float *ws_grid_, float *ws_cell_) const
+
+#define rnn_grid_execution_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, weights_data_t **weights_layer_, \
+ weights_data_t **weights_states_, float **bias_, \
+ src_data_t *ws_states_, float *ws_c_states_, \
+ float *ws_diff_states_, acc_data_t *ws_gates_, float *ws_cell_, \
+ float *ws_grid_, float *diff_weights_layer_, \
+ float *diff_weights_iter_, float *diff_bias_) const
+
+#define rnn_gemm_sig(f) \
+ void f(const char transA, const char transB, int m, int n, int k, \
+ const float alpha, const weights_data_t *a_, const int ldA, \
+ const src_data_t *b_, const int ldB, const float beta, \
+ acc_data_t *c_, const int ldC) const
+
+#define rnn_bias_prepare_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, float **bias_, const float *b_, \
+ float *scratch_bias_) const
+
+#define rnn_bias_finalize_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, float *scratch_bias_, \
+ const float *w_iter_comp, const float *w_layer_comp) const
+
+#define rnn_weights_assign_sig(f) \
+ void f(const rnn_utils::rnn_conf_t &rnn, const memory_desc_t *md, int nld, \
+ int ld, int OC_size, int IC_size, const int n_parts, \
+ const int *gates_per_part, const size_t *part_weights_pack_size, \
+ weights_data_t **weights_, const weights_data_t *w_, \
+ float **bias_, const float *b_, float *scratch_bias_) const
+
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+namespace rnn_utils {
+
+using namespace mkldnn::impl::utils;
+
+enum execution_direction_t {
+ l2r,
+ r2l,
+ bi_concat,
+ bi_sum,
+};
+
+enum data_type_conf_t {
+ all_f32,
+ u8u8u8f32,
+ f32u8f32f32,
+ u8u8u8u8,
+ f32u8f32u8
+};
+
+struct rnn_conf_t {
+ execution_direction_t exec_dir;
+ data_type_conf_t dt_conf;
+ int n_layer, n_iter, n_dir, n_gates, n_states;
+ int mb;
+ int slc, sic, dic, dlc;
+ int gates_ld, gates_nld, gates_ws_ld;
+ int n_parts_weights_layer, parts_weights_layer[MKLDNN_RNN_MAX_N_PARTS];
+ int n_parts_weights_iter, parts_weights_iter[MKLDNN_RNN_MAX_N_PARTS];
+ int n_bias, n_parts_bias, parts_bias[MKLDNN_RNN_MAX_N_PARTS];
+ size_t part_weights_iter_pack_size[MKLDNN_RNN_MAX_N_PARTS],
+ part_weights_layer_pack_size[MKLDNN_RNN_MAX_N_PARTS];
+ bool weights_layer_is_packed, weights_iter_is_packed;
+ /* Size of packed data in bytes */
+ size_t weights_layer_comp_offset, weights_layer_pack_size,
+ weights_iter_comp_offset, weights_iter_pack_size;
+
+ bool copy_bias;
+ int weights_layer_ld, weights_layer_nld;
+ int diff_weights_layer_ld, diff_weights_layer_nld;
+ int weights_iter_ld, weights_iter_nld;
+ int diff_weights_iter_ld, diff_weights_iter_nld;
+ int states_nld, states_ws_ld;
+ int weights_iter_compensation_size, weights_layer_compensation_size;
+ bool is_fwd, is_training, is_lbr;
+ bool use_workspace;
+
+ /* Size of workspace for each tensor in bytes */
+ size_t ws_gates_size, ws_states_size, ws_c_states_size, ws_diff_states_size,
+ ws_cell_comp_size, ws_grid_comp_size, ws_per_cell, ws_bias_size;
+ bool merge_gemm_iter, merge_gemm_layer, use_jit_gemm, use_layer_packed_gemm,
+ use_iter_packed_gemm;
+};
+
+bool is_ldigo(const memory_desc_wrapper &md);
+bool is_ldgoi(const memory_desc_wrapper &md);
+
+int get_good_ld(int dim, int sizeof_dt);
+
+void init_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
+ const memory_desc_wrapper &src_layer_d,
+ const memory_desc_wrapper &src_iter_d,
+ const memory_desc_wrapper &weights_layer_d,
+ const memory_desc_wrapper &weights_iter_d,
+ const memory_desc_wrapper &dst_layer_d);
+
+void set_conf(rnn_conf_t &rnn, const rnn_desc_t &rd,
+ const memory_desc_wrapper &weights_layer_d,
+ const memory_desc_wrapper &weights_iter_d,
+ const memory_desc_wrapper &diff_weights_layer_d,
+ const memory_desc_wrapper &diff_weights_iter_d);
+
+void set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
+ size_t &ws_h_state_offset, size_t &ws_c_state_offset,
+ size_t &ws_diff_states_offset, size_t &ws_grid_comp_offset,
+ size_t &ws_cell_comp_offset, size_t &ws_bias_offset,
+ size_t &scratchpad_size, size_t &workspace_size);
+
+void get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
+ size_t &scratchpad_size, size_t &workspace_size);
+status_t set_expected_desc(
+ rnn_conf_t &rnn, memory_desc_t &weights_md, bool is_iter);
+status_t set_good_strides(memory_desc_t &weights_md, format_tag_t tag);
+
+template <typename T>
+struct ws_gates_aoc {
+ ws_gates_aoc(const rnn_conf_t &rnn, T *data)
+ : gates_(data, rnn.gates_nld, rnn.gates_ws_ld), DIC_(rnn.dic) {}
+ T &operator()(int batch, int gate, int dic) {
+ return gates_(batch, gate * DIC_ + dic);
+ }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<T, 2> gates_;
+ int DIC_;
+};
+using ws_gates_aoc_t = ws_gates_aoc<float>;
+using ws_gates_aoc_s32_t = ws_gates_aoc<int32_t>;
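+// Usage sketch: ws_gates_aoc_t gates(rnn, ws_gates_); gates(b, g, d) then
+// addresses gate g, channel d of batch row b within the
+// (gates_nld x gates_ws_ld) workspace slab.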
+
+struct bias_aoc_t {
+ bias_aoc_t(const rnn_conf_t &rnn, const float *data)
+ : bias_(data, rnn.n_bias, rnn.dic) {}
+ const float &operator()(int bias_n, int dic) { return bias_(bias_n, dic); }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<const float, 2> bias_;
+};
+
+template <typename T>
+struct ws_states_aoc {
+ ws_states_aoc(const rnn_conf_t &rnn, T *data)
+ : state_(data, rnn.states_nld, rnn.states_ws_ld) {}
+ T &operator()(int batch, int dic) { return state_(batch, dic); }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<T, 2> state_;
+};
+using ws_states_aoc_t = ws_states_aoc<float>;
+using ws_states_aoc_u8_t = ws_states_aoc<uint8_t>;
+
+struct ws_diff_states_aoc_t {
+ ws_diff_states_aoc_t(const rnn_conf_t &rnn, float *data)
+ : diff_states_(data, rnn.n_states + 1, rnn.n_iter + 1, rnn.states_nld,
+ rnn.states_ws_ld) {}
+ float &operator()(int state_n, int batch, int dic) {
+ return diff_states_(state_n, 0, batch, dic);
+ }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<float, 4> diff_states_;
+};
+
+struct ws_diff_w_iter_aoc_t {
+ ws_diff_w_iter_aoc_t(const rnn_conf_t &rnn, float *data)
+ : diff_weights_iter_(
+ data, rnn.diff_weights_iter_nld, rnn.diff_weights_iter_ld)
+ , DIC_(rnn.dic) {}
+ float &operator()(int sic, int gate, int dic) {
+ return diff_weights_iter_(sic, gate * DIC_ + dic);
+ }
+
+private:
+ mkldnn::impl::utils::array_offset_calculator<float, 2> diff_weights_iter_;
+ int DIC_;
+};
+}
+}
+}
+}
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp
new file mode 100644
index 0000000000..0420f87aa5
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.cpp
@@ -0,0 +1,126 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_thread.hpp"
+
+#include "simple_concat.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace memory_tracking::names;
+
+template <data_type_t data_type>
+status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const {
+ auto scratchpad = this->scratchpad(ctx);
+ auto iptrs = scratchpad.template get<const data_t *>(key_concat_iptrs);
+ auto optrs = scratchpad.template get<data_t *>(key_concat_optrs);
+ auto nelems_to_copy = scratchpad.template get<dim_t>(key_concat_nelems);
+ auto is = scratchpad.template get<strides_t>(key_concat_istrides);
+
+ const int num_arrs = pd()->n_inputs();
+ const int *perm = pd()->perm_, *iperm = pd()->iperm_;
+ const int concat_dim = pd()->concat_dim();
+ auto o_base_ptr = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ for (int a = 0; a < num_arrs; ++a) {
+ const memory_desc_wrapper i_d(pd()->src_md(a));
+ const memory_desc_wrapper o_d(pd()->src_image_md(a));
+
+ iptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a)
+ + i_d.blk_off(0);
+ optrs[a] = o_base_ptr + o_d.blk_off(0);
+ nelems_to_copy[a] = pd()->nelems_to_concat(i_d);
+ for (int i = 0; i < MKLDNN_MAX_NDIMS; i++) {
+ if (i < perm[concat_dim])
+ is[a][i] = size_t(i_d.blocking_desc().strides[iperm[i]]);
+ else
+ is[a][i] = 0;
+ }
+ }
+
+ const memory_desc_wrapper o_d(pd()->src_image_md(0));
+
+ strides_t os = { 0 };
+ for (int i = 0; i < perm[concat_dim]; i++)
+ os[i] = o_d.blocking_desc().strides[iperm[i]];
+
+ dims_t phys_dims;
+ for (size_t i = 0; i < sizeof(phys_dims)/sizeof(phys_dims[0]); i++)
+ phys_dims[i] = (i < (size_t)perm[concat_dim])
+ ? o_d.dims()[iperm[i]] / pd()->blocks_[iperm[i]] : 1;
+
+ if (perm[concat_dim] == 0) {
+ for (int a = 0; a < num_arrs; ++a) {
+ const data_t *i = &iptrs[a][0];
+ data_t *o = &optrs[a][0];
+ parallel_nd((ptrdiff_t)nelems_to_copy[a],
+ [&](ptrdiff_t e) { o[e] = i[e]; });
+ }
+ } else {
+ parallel_nd(phys_dims[0], phys_dims[1], phys_dims[2], phys_dims[3],
+ phys_dims[4], num_arrs,
+ [&](dim_t n0, dim_t n1, dim_t n2, dim_t n3, dim_t n4, int a) {
+ // XXX: this code may read entries of is[*][0-4] that are not
+ // meaningful for the given array -- that's why they are
+ // zero-initialized above, even though reading them is probably benign
+ size_t in_off = is[a][0] * n0 + is[a][1] * n1 + is[a][2] * n2
+ + is[a][3] * n3 + is[a][4] * n4;
+ size_t out_off = os[0] * n0 + os[1] * n1 + os[2] * n2
+ + os[3] * n3 + os[4] * n4;
+ const data_t *i = &iptrs[a][in_off];
+ data_t *o = &optrs[a][out_off];
+#if defined(__GNUC__) && !defined(__INTEL_COMPILER)
+ // The code below performs data copying: o[e] = i[e]
+ // and uses a workaround to make GNU compilers optimize it
+ uint8_t *ptro = reinterpret_cast<uint8_t *>(o);
+ const uint8_t *ptri = reinterpret_cast<const uint8_t *>(i);
+ const dim_t main_part =
+ nelems_to_copy[a] * sizeof(data_t) / sizeof(uint32_t);
+ const dim_t tail_part =
+ nelems_to_copy[a] * sizeof(data_t) % sizeof(uint32_t);
+
+ PRAGMA_OMP_SIMD()
+ for (dim_t e = 0; e < main_part; ++e) {
+ *(reinterpret_cast<uint32_t *>(ptro))
+ = *(reinterpret_cast<const uint32_t *>(ptri));
+ ptro += sizeof(uint32_t);
+ ptri += sizeof(uint32_t);
+ }
+ for (dim_t e = 0; e < tail_part; ++e) {
+ *ptro = *ptri;
+ ++ptro;
+ ++ptri;
+ }
+#else
+ PRAGMA_OMP_SIMD()
+ for (dim_t e = 0; e < nelems_to_copy[a]; ++e) o[e] = i[e];
+#endif
+ });
+ }
+
+ return status::success;
+}
+
+template struct simple_concat_t<data_type::f32>;
+template struct simple_concat_t<data_type::u8>;
+template struct simple_concat_t<data_type::s8>;
+template struct simple_concat_t<data_type::s32>;
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
new file mode 100644
index 0000000000..5177275452
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_concat.hpp
@@ -0,0 +1,155 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SIMPLE_CONCAT_HPP
+#define SIMPLE_CONCAT_HPP
+
+#include "memory_tracking.hpp"
+
+#include "cpu_concat_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t data_type>
+struct simple_concat_t: public cpu_primitive_t {
+ struct pd_t: public cpu_concat_pd_t {
+ using cpu_concat_pd_t::cpu_concat_pd_t;
+
+ pd_t(const pd_t &rhs): cpu_concat_pd_t(rhs) {
+ int ndims = rhs.dst_md_.ndims;
+ utils::array_copy(perm_, rhs.perm_, ndims);
+ utils::array_copy(iperm_, rhs.iperm_, ndims);
+ utils::array_copy(blocks_, rhs.blocks_, ndims);
+ }
+
+ DECLARE_CONCAT_PD_T("simple:any", simple_concat_t);
+
+ status_t init() {
+ const memory_desc_wrapper dst_d(dst_md());
+ bool ok = true
+ && cpu_concat_pd_t::init() == status::success
+ && dst_d.ndims() <= 6;
+ if (!ok) return status::unimplemented;
+
+ for (size_t i = 0; i < src_mds_.size(); ++i) {
+ const memory_desc_wrapper i_d(&src_mds_[i]);
+ const memory_desc_wrapper o_d(&src_image_mds_[i]);
+
+ const int ignore_strides = 0;
+
+ ok = ok
+ && utils::everyone_is(data_type, i_d.data_type(),
+ o_d.data_type())
+ && utils::everyone_is(format_kind::blocked,
+ i_d.format_kind(), o_d.format_kind())
+ && types::blocking_desc_is_equal(i_d.blocking_desc(),
+ o_d.blocking_desc(), ignore_strides)
+ && types::blocking_desc_is_equal(i_d.blocking_desc(),
+ dst_d.blocking_desc(), ignore_strides)
+ && !i_d.is_additional_buffer();
+ if (!ok) return status::unimplemented;
+ }
+
+ dst_d.compute_blocks(blocks_);
+ format_perm();
+
+ // start dim is the first dimension after which the concatenation
+ // would happen contiguously
+ const int start_dim = perm_[concat_dim()];
+
+ // check that contiguous part is indeed contiguous (i.e. dense)
+ if (nelems_to_concat(dst_d) !=
+ dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()]
+ * dst_d.blocking_desc().strides[concat_dim()])
+ return status::unimplemented;
+
+ // check that all inputs have the same strides for the
+ // contiguous part [concat_dim .. ndims] for the *major* dims.
+ // the block part is already checked above
+ for (size_t i = 0; i < src_mds_.size(); ++i) {
+ const memory_desc_wrapper i_d(&src_mds_[i]);
+ for (int d = start_dim; d < dst_d.ndims(); ++d) {
+ if (dst_d.blocking_desc().strides[iperm_[d]]
+ != i_d.blocking_desc().strides[iperm_[d]])
+ return status::unimplemented;
+ }
+ }
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ int perm_[MKLDNN_MAX_NDIMS];
+ int iperm_[MKLDNN_MAX_NDIMS];
+ dims_t blocks_;
+
+ dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
+ const int ndims = data_d.ndims();
+
+ dim_t nelems = 1;
+ for (int i = perm_[concat_dim()]; i < ndims; i++)
+ nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]];
+ for (int i = 0; i < ndims; i++)
+ nelems *= blocks_[i];
+
+ return nelems;
+ }
+
+ private:
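+ // format_perm() orders the dims of dst by decreasing stride: iperm_[pos]
+ // gives the logical dim at position pos (outermost first), and perm_[dim]
+ // gives the inverse, i.e. a logical dim's position in memory order.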
+ void format_perm() {
+ const memory_desc_wrapper dst_d(dst_md());
+ const int ndims = dst_d.ndims();
+
+ strides_t strides;
+ utils::array_copy(strides, dst_d.blocking_desc().strides, ndims);
+ for (int i = 0; i < ndims; i++) iperm_[i] = i;
+
+ utils::simultaneous_sort(strides, iperm_, ndims,
+ [](stride_t a, stride_t b) { return b - a; });
+
+ for (int i = 0; i < ndims; i++) perm_[iperm_[i]] = i;
+ }
+
+ void init_scratchpad() {
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ scratchpad.book(key_concat_iptrs, sizeof(data_t *) * n_inputs());
+ scratchpad.book(key_concat_optrs, sizeof(data_t *) * n_inputs());
+ scratchpad.book(key_concat_nelems, sizeof(dim_t) * n_inputs());
+ scratchpad.book(key_concat_istrides,
+ sizeof(strides_t) * n_inputs());
+ }
+ };
+
+ simple_concat_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override;
+
+ typedef typename prec_traits<data_type>::type data_t;
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp
new file mode 100644
index 0000000000..e6c3b8d7af
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_q10n.hpp
@@ -0,0 +1,98 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_SIMPLE_Q10N_HPP
+#define CPU_SIMPLE_Q10N_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "math_utils.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::math;
+
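+/* Round to the nearest integer, then clamp to the range of out_t,
+ * e.g. round_and_saturate<int8_t>(130.6f) -> 127. */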
+template <typename out_t>
+inline out_t round_and_saturate(float f)
+{ return math::saturate<out_t>(out_round<int>(f)); }
+
+/* Quantization with alpha == 1 and beta == 0 */
+template <typename in_t, typename out_t, typename enabled = void>
+struct qz_a1b0 {
+ out_t operator()(in_t in)
+ { return round_and_saturate<out_t>((float)in); }
+};
+
+template <typename in_t, typename out_t>
+struct qz_a1b0<in_t, out_t,
+ typename utils::enable_if<true
+ && nstl::is_integral<in_t>::value
+ && !is_subset<in_t, out_t>::value
+ >::type> {
+ out_t operator()(in_t in) { return math::saturate<out_t>(in); }
+};
+
+template <typename in_t, typename out_t>
+struct qz_a1b0<in_t, out_t,
+ typename utils::enable_if<is_subset<in_t, out_t>::value>::type> {
+ out_t operator()(in_t in) { return (out_t)in; }
+};
+
+/* Quantization with alpha == 1 */
+template <typename in_t, typename out_t> struct qz_a1 {
+ out_t operator()(in_t in, out_t out, float beta)
+ { return round_and_saturate<out_t>((float)in + beta * out); }
+};
+
+template <typename in_t> struct qz_a1<in_t, float> {
+ float operator()(in_t in, float out, float beta)
+ { return (float)in + beta * out; }
+};
+
+/* Quantization with beta == 0 */
+template <typename in_t, typename out_t> struct qz_b0 {
+ out_t operator()(in_t in, float alpha)
+ { return round_and_saturate<out_t>(alpha * in); }
+};
+
+template <typename in_t> struct qz_b0<in_t, float> {
+ float operator()(in_t in, float alpha) { return alpha * in; }
+};
+
+/* Quantization */
+template <typename in_t, typename out_t> struct qz {
+ out_t operator()(in_t in, out_t out, float alpha, float beta) {
+ return round_and_saturate<out_t>(
+ alpha * in + (beta ? beta * out : 0));
+ }
+};
+
+template <typename in_t> struct qz<in_t, float> {
+ float operator()(in_t in, float out, float alpha, float beta)
+ { return alpha * in + (beta ? beta * out : 0); }
+};
+
+}
+}
+}
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp
new file mode 100644
index 0000000000..ff845f5bd3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_reorder.hpp
@@ -0,0 +1,1022 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_SIMPLE_REORDER_HPP
+#define CPU_SIMPLE_REORDER_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "math_utils.hpp"
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "tag_traits.hpp"
+#include "cpu_reorder_pd.hpp"
+#include "cpu_primitive.hpp"
+
+#include "simple_q10n.hpp"
+#include "cpu_isa_traits.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::format_tag;
+using namespace mkldnn::impl::data_type;
+
+using bd = block_dim_t;
+using ib = inner_blk_t;
+
+using namespace mkldnn::impl::utils;
+using math::saturate;
+
+template<impl::data_type_t type>
+using data_t = typename prec_traits<type>::type;
+
+template<impl::data_type_t type_i, impl::data_type_t type_o>
+using _qz_a1b0 = qz_a1b0<data_t<type_i>, data_t<type_o>>;
+
+template<impl::data_type_t type_i, impl::data_type_t type_o>
+using _qz = qz<data_t<type_i>, data_t<type_o>>;
+
+namespace fmt_order {
+ const bool keep = true;
+ const bool reverse = false;
+ const bool any = keep;
+}
+
+namespace spec {
+struct direct_copy {};
+struct direct_copy_except_dim_0 {};
+struct reference {};
+struct conv_s8s8 {};
+}
+
+#define SIMPLE_REORDER_TEMPL_DECL \
+ impl::data_type_t type_i, impl::format_tag_t tag_i, \
+ impl::data_type_t type_o, impl::format_tag_t tag_o, bool order_keep
+#define SIMPLE_REORDER_TEMPL_CALL \
+ type_i, tag_i, type_o, tag_o, order_keep
+
+#define DECLARE_COMMON_PARAMS() \
+ const memory_desc_wrapper &input_d = pd->src_md(); \
+ const memory_desc_wrapper &output_d = pd->dst_md(); \
+ const float alpha = pd->alpha(); MAYBE_UNUSED(alpha); \
+ const float beta = pd->beta(); MAYBE_UNUSED(beta);
+
+/* specific reorders: common template */
+template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
+struct simple_reorder_impl {};
+
+namespace {
+inline bool simple_fmt_check(bool order_keep, impl::format_tag_t tag_i,
+ impl::format_tag_t tag_o, const memory_desc_wrapper &input_d,
+ const memory_desc_wrapper &output_d) {
+ return input_d.matches_tag(order_keep ? tag_i : tag_o)
+ && output_d.matches_tag(order_keep ? tag_o : tag_i);
+}
+inline bool simple_attr_check(const primitive_attr_t *attr, bool many_scales_support) {
+ if (many_scales_support)
+ return true;
+ return IMPLICATION(attr, attr->output_scales_.mask_ == 0);
+}
+}
+
+/* specific reorders: implementation */
+template <SIMPLE_REORDER_TEMPL_DECL>
+struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
+typename utils::enable_if<tag_i == any && (false
+ || tag_o == hwio
+ || tag_o == hwigo)
+ , spec::conv_s8s8>::type>
+{
+ static bool is_applicable(const memory_desc_wrapper &input_d,
+ const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
+ {
+ const size_t D_mask = utils::array_product(input_d.dims(),
+ math::ilog2q(attr->output_scales_.mask_ + 1));
+ const int oc = (input_d.dims()[(tag_o == hwigo) + 0]);
+ const int g = (tag_o == hwigo) ? (input_d.dims()[0]) : 1;
+
+ return output_d.matches_tag(tag_o)
+ && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8)
+ && (input_d.data_type() == f32 || input_d.data_type() == s8)
+ && output_d.data_type() == s8
+ && (D_mask == 1 || D_mask == (size_t)g * oc);
+ }
+
+ static status_t execute(const cpu_reorder_pd_t *pd,
+ const data_t<type_i> *input, data_t<type_o> *output) {
+ DECLARE_COMMON_PARAMS();
+
+ static constexpr bool w_groups = tag_o == hwigo;
+
+ const auto &dims = input_d.dims();
+ const auto &pdims = output_d.padded_dims();
+
+ const int G = w_groups ? dims[0] : 1;
+ const int OC = dims[w_groups + 0];
+ const int IC = dims[w_groups + 1];
+ const int H = dims[w_groups + 2];
+ const int W = dims[w_groups + 3];
+
+ const float *scales = pd->attr()->output_scales_.scales_;
+ const size_t D_mask = utils::array_product(input_d.dims(),
+ math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
+
+ assert(output_d.extra().flags
+ & memory_extra_flags::compensation_conv_s8s8);
+ float adj_scale =
+ (output_d.extra().flags & memory_extra_flags::scale_adjust)
+ ? output_d.extra().scale_adjust : 1.f;
+
+ size_t offset = G * pdims[w_groups + 0] * pdims[w_groups + 1] * H * W;
+ int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
+
+ parallel_nd(G, OC, [&](int g, int oc) {
+ cp[g * OC + oc] = 0;
+ for (int ic = 0; ic < IC; ic++)
+ for (int h = 0; h < H; h++)
+ for (int w = 0; w < W; w++) {
+ auto i = input[input_d.blk_off<!w_groups>(g, oc, ic, h, w)];
+ auto &o = output[output_d.blk_off<!w_groups>(g, oc, ic, h, w)];
+ const float s = scales[(D_mask == 1) ? 0 : g * OC + oc];
+
+ o = qz_b0<data_t<type_i>, data_t<type_o>>()(
+ i, s * adj_scale);
+ cp[g * OC + oc] -= (int32_t)o;
+ }
+ cp[g * OC + oc] *= 128;
+ });
+ return success;
+ }
+};
+
+template <SIMPLE_REORDER_TEMPL_DECL>
+struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
+ typename utils::enable_if<
+ (tag_i == goiw && tag_o == gOIw4i16o4i)
+ || (tag_i == oiw && tag_o == OIw4i16o4i)
+ || (tag_i == goihw && tag_o == gOIhw4i16o4i)
+ || (tag_i == oihw && tag_o == OIhw4i16o4i)
+ || (tag_i == goihw && tag_o == gOIhw2i8o4i)
+ || (tag_i == goihw && tag_o == gOIhw4o4i)
+ , spec::conv_s8s8>::type>
+{
+ static bool is_applicable(const memory_desc_wrapper &input_d,
+ const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
+ {
+ const size_t D_mask = utils::array_product(input_d.dims(),
+ math::ilog2q(attr->output_scales_.mask_ + 1));
+ const bool w_groups = !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i);
+ const int oc = (input_d.dims()[w_groups ? 1 : 0]);
+ const int g = w_groups ? input_d.dims()[0] : 1;
+
+ return input_d.matches_tag(tag_i)
+ && output_d.matches_tag(tag_o)
+ && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8)
+ && (input_d.data_type() == f32 || input_d.data_type() == s8)
+ && output_d.data_type() == s8
+ && (D_mask == 1 || D_mask == (size_t)g * oc);
+ }
+
+ static status_t execute(const cpu_reorder_pd_t *pd,
+ const data_t<type_i> *input, data_t<type_o> *output) {
+ DECLARE_COMMON_PARAMS();
+
+ static constexpr bool w_groups =
+ !utils::one_of(tag_o, OIw4i16o4i, OIhw4i16o4i);
+ constexpr int is_1d =
+ utils::one_of(tag_o, gOIw4i16o4i, OIw4i16o4i);
+ constexpr int blksize = tag_traits<tag_o>::inner_blks == ib::_4b4c
+ ? 4
+ : tag_traits<tag_o>::inner_blks == ib::_2c8b4c
+ ? 8
+ : 16;
+
+ const auto &_g_oihw_d = order_keep ? input_d : output_d;
+ const auto &dims = input_d.dims();
+ const auto &pdims = order_keep
+ ? output_d.padded_dims()
+ : input_d.padded_dims();
+
+ const int G = w_groups ? dims[0] : 1;
+ const int OC = dims[w_groups + 0];
+ const int NB_OC = pdims[w_groups + 0] / blksize;
+ const int IC = dims[w_groups + 1];
+ const int NB_IC = pdims[w_groups + 1] / blksize;
+ const int H = is_1d ? 1 : dims[w_groups + 2];
+ const int W = dims[w_groups + 3 - is_1d];
+
+ const float *scales = pd->attr()->output_scales_.scales_;
+ const size_t D_mask = utils::array_product(input_d.dims(),
+ math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
+
+ assert(output_d.extra().flags
+ & memory_extra_flags::compensation_conv_s8s8);
+ float adj_scale =
+ (output_d.extra().flags & memory_extra_flags::scale_adjust)
+ ? output_d.extra().scale_adjust : 1.f;
+
+ auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
+ int32_t *c, const float *s, const int oc_block, const int ic_block) {
+# define index AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
+
+ for (int ic = 0; ic < ic_block; ++ic) {
+ for (int oc = 0; oc < oc_block; ++oc) {
+ const auto _g_oihw_off =
+ oc * _g_oihw_d.blocking_desc().strides[w_groups + 0]
+ + ic * _g_oihw_d.blocking_desc().strides[w_groups + 1];
+ out[index(oc, ic)]
+ = qz_b0<data_t<type_i>, data_t<type_o>>()(
+ inp[_g_oihw_off], s[oc] * adj_scale);
+ c[oc] -= (128 * (int32_t)(out[index(oc, ic)]));
+ }
+ }
+# undef index
+ };
+
+ constexpr int i_mult = blksize;
+ constexpr int o_mult = 1;
+
+ size_t offset = G * pdims[w_groups+0] * pdims[w_groups+1] * H * W;
+ int32_t *cp = reinterpret_cast<int32_t *>(output + offset);
+ parallel_nd(G * NB_OC * blksize, [&](int i) {
+ cp[i] = 0;
+ });
+
+# define wei_blk_off(md, g, o, i, h, w) \
+ (is_1d ? (md).blk_off<!w_groups>(g, o, i, w) \
+ : (md).blk_off<!w_groups>(g, o, i, h, w))
+
+ parallel_nd(G, NB_OC, [&](int g, int O) {
+ for (int I = 0; I < NB_IC; I++)
+ for (int h = 0; h < H; h++)
+ for (int w = 0; w < W; w++) {
+ auto i = &input[wei_blk_off(
+ input_d, g, i_mult * O, i_mult * I, h, w)];
+ auto o = &output[wei_blk_off(
+ output_d, g, o_mult * O, o_mult * I, h, w)];
+ const int oc_block = nstl::min(blksize, OC - O * blksize);
+ const int ic_block = nstl::min(blksize, IC - I * blksize);
+
+ int _offset = (g * NB_OC + O) * blksize;
+ ker(i, o, (order_keep) ? &cp[_offset] : nullptr,
+ &scales[(D_mask == 1) ? 0 : _offset],
+ oc_block, ic_block);
+ }
+ });
+
+# undef wei_blk_off
+
+ return success;
+ }
+};
+
+template <SIMPLE_REORDER_TEMPL_DECL>
+struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
+ typename utils::enable_if<false
+ ||(tag_i == goiw && tag_o == Goiw16g)
+ ||(tag_i == goihw && tag_o == Goihw16g)
+ , spec::conv_s8s8>::type>
+{
+ static bool is_applicable(const memory_desc_wrapper &input_d,
+ const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
+ const size_t D_mask = utils::array_product(input_d.dims(),
+ math::ilog2q(attr->output_scales_.mask_ + 1));
+ const int oc = input_d.dims()[1];
+ const int g = input_d.dims()[0];
+
+ return true
+ && order_keep
+ && input_d.matches_tag(tag_i)
+ && output_d.matches_tag(tag_o)
+ && (output_d.extra().flags & memory_extra_flags::compensation_conv_s8s8)
+ && (input_d.data_type() == f32 || input_d.data_type() == s8)
+ && output_d.data_type() == s8
+ && (D_mask == 1 || D_mask == (size_t)g * oc);
+ }
+
+ static status_t execute(const cpu_reorder_pd_t *pd,
+ const data_t<type_i> *input, data_t<type_o> *output) {
+ DECLARE_COMMON_PARAMS();
+
+ constexpr bool is_1d = tag_i == goiw;
+ constexpr int blksize = 16;
+
+ const auto &dims = input_d.dims();
+ const auto &pdims = output_d.padded_dims();
+ const int G = dims[0];
+ const int Gp = pdims[0];
+ const int OC = dims[1];
+ const int IC = dims[2];
+ const int H = is_1d ? 1 : dims[3];
+ const int W = dims[4 - is_1d];
+
+ const size_t D_mask = utils::array_product(input_d.dims(),
+ math::ilog2q(pd->attr()->output_scales_.mask_ + 1));
+ const float *scales = pd->attr()->output_scales_.scales_;
+
+ assert(output_d.extra().flags
+ & memory_extra_flags::compensation_conv_s8s8);
+ float adj_scale =
+ (output_d.extra().flags & memory_extra_flags::scale_adjust)
+ ? output_d.extra().scale_adjust : 1.f;
+
+ auto ker = [&](const data_t<type_i> *inp, data_t<type_o> *out,
+ int32_t *cp, const float *s, const int g_block) {
+ PRAGMA_OMP_SIMD()
+ for (int g = 0; g < g_block; g++) {
+ const auto i_off = g * input_d.blocking_desc().strides[0];
+ out[g] = qz_b0<data_t<type_i>, data_t<type_o>>()(
+ inp[i_off], s[g * OC] * adj_scale);
+ cp[g * OC] -= 128 * (int32_t)(out[g]);
+ }
+ };
+
+ size_t cp_offset = output_d.size() - output_d.additional_buffer_size();
+ int32_t *cp = reinterpret_cast<int32_t *>(output + cp_offset);
+ parallel_nd((Gp/blksize) * OC, [&](int ib) {
+ PRAGMA_OMP_SIMD()
+ for (int i = 0; i < blksize; i++)
+ cp[ib * blksize + i] = 0;
+ });
+
+# define wei_blk_off(md, g, o, i, h, w) \
+ (is_1d ? (md).blk_off(g, o, i, w) : (md).blk_off(g, o, i, h, w))
+
+ parallel_nd(Gp/blksize, OC, [&](int gb, int O) {
+ for (int I = 0; I < IC; I++) {
+ for (int h = 0; h < H; h++)
+ for (int w = 0; w < W; w++)
+ {
+ const int g_block = nstl::min(G - gb * blksize, blksize);
+ const auto inp = &input[wei_blk_off(
+ input_d, gb * blksize, O, I, h, w)];
+ const auto out = &output[wei_blk_off(
+ output_d, gb, O, I, h, w)];
+ int offset = gb * blksize + O;
+ ker(inp, out, &cp[offset],
+ &scales[(D_mask == 1) ? 0 : offset], g_block);
+ }
+ }
+ });
+
+# undef wei_blk_off
+
+ return success;
+ }
+};
+
+/* reorders with tail support */
+
+template <SIMPLE_REORDER_TEMPL_DECL>
+struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
+typename utils::enable_if<false
+ || (tag_i == nCdhw8c && tag_o == nCdhw16c)
+ || (tag_i == nChw8c && tag_o == nChw16c)
+ || (tag_i == nCw8c && tag_o == nCw16c)
+ >::type>
+{
+ static bool is_applicable(const memory_desc_wrapper &input_d,
+ const memory_desc_wrapper &output_d, const primitive_attr_t *attr)
+ {
+ return simple_fmt_check(order_keep, tag_i, tag_o, input_d, output_d)
+ && simple_attr_check(attr, false);
+ }
+
+ static status_t execute(const cpu_reorder_pd_t *pd,
+ const data_t<type_i> *input, data_t<type_o> *output) {
+ DECLARE_COMMON_PARAMS();
+
+ constexpr int is_1d = tag_i == nCw8c;
+ constexpr int is_3d = tag_i == nCdhw8c;
+ constexpr int blksize_16 = 16;
+ constexpr int blksize_8 = 8;
+ constexpr int ic_mult = order_keep ? 2 : 1;
+ constexpr int oc_mult = order_keep ? 1 : 2;
+
+ const auto &dims = input_d.dims();
+ const auto &pdims = order_keep ? output_d.padded_dims()
+ : input_d.padded_dims();
+
+ const int C = dims[1];
+ const int D = is_3d ? dims[2] : 1;
+ const int H = is_1d ? 1 : dims[2 + is_3d];
+ const int W = dims[3 + is_3d - is_1d];
+
+ auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
+ const int block_16) {
+ const int nb = (block_16 - 1) / blksize_8 + 1;
+ if (alpha == 1.0 && beta == 0.0) {
+ for (int b = 0; b < nb; ++b) {
+ const ptrdiff_t i_off = order_keep ? b : b * blksize_8;
+ const ptrdiff_t o_off = order_keep ? b * blksize_8 : b;
+ const int block_8 = nstl::min(blksize_8,
+ block_16 - b * blksize_8);
+ for (int c = 0; c < block_8; ++c) {
+ o[o_off + c] = _qz_a1b0<type_i, type_o>()(
+ i[i_off + c]);
+ }
+ }
+ } else {
+ for (int b = 0; b < nb; ++b) {
+ const ptrdiff_t i_off = order_keep ? b : b * blksize_8;
+ const ptrdiff_t o_off = order_keep ? b * blksize_8 : b;
+ const int block_8 = nstl::min(blksize_8,
+ block_16 - b * blksize_8);
+ for (int c = 0; c < block_8; ++c) {
+ o[o_off + c] = _qz<type_i, type_o>()(i[i_off + c],
+ o[o_off + c], alpha, beta);
+ }
+ }
+ }
+ };
+
+# define data_blk_off(md, n, c, d, h, w) \
+ ( is_1d ? (md).blk_off(n, c, w) \
+ : is_3d ? (md).blk_off(n, c, d, h, w) : (md).blk_off(n, c, h, w))
+
+ parallel_nd(dims[0], pdims[1] / blksize_16, D, H, W,
+ [&](int n, int nb_c, int d, int h, int w) {
+ auto i = &input[data_blk_off(input_d, n, ic_mult * nb_c, d, h, w)];
+ auto o = &output[data_blk_off(output_d, n, oc_mult * nb_c, d, h, w)];
+ const int block_16 = nstl::min(blksize_16, C - nb_c * blksize_16);
+ ker(i, o, block_16);
+ });
+
+# undef data_blk_off
+
+ return success;
+ }
+};
+
+#define PLAIN_TO_BLOCKED_IS_APPLICABLE() \
+ static bool is_applicable(const memory_desc_wrapper &input_d, \
+ const memory_desc_wrapper &output_d, const primitive_attr_t *attr) { \
+ return simple_attr_check(attr, false) && (order_keep \
+ ? output_d.matches_tag(tag_o) && input_d.is_plain() \
+ : input_d.matches_tag(tag_o) && output_d.is_plain()); \
+ }
+
+template <SIMPLE_REORDER_TEMPL_DECL>
+struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
+typename utils::enable_if<tag_i == any
+ && (tag_traits<tag_o>::block_dims == bd::_A
+ || tag_traits<tag_o>::block_dims == bd::_B)
+ && tag_traits<tag_o>::ndims >= 3
+ && tag_traits<tag_o>::ndims <= 6
+ >::type>
+{
+ PLAIN_TO_BLOCKED_IS_APPLICABLE();
+
+ static status_t execute(const cpu_reorder_pd_t *pd,
+ const data_t<type_i> *input, data_t<type_o> *output) {
+ DECLARE_COMMON_PARAMS();
+
+ const auto &flat_d = order_keep ? input_d : output_d;
+ const auto &block_d = order_keep ? output_d : input_d;
+ const auto &dims = input_d.dims();
+ const auto &pdims = block_d.padded_dims();
+
+ constexpr int ndims = tag_traits<tag_o>::ndims;
+ constexpr int blk_idx = tag_traits<tag_o>::block_dims == bd::_A ? 0 : 1;
+
+ const dim_t H0 = dims[0];
+ const dim_t H1 = dims[1];
+ const dim_t M0 = ndims >= 6 ? dims[ndims - 4] : 1;
+ const dim_t M1 = ndims >= 5 ? dims[ndims - 3] : 1;
+ const dim_t M2 = ndims >= 4 ? dims[ndims - 2] : 1;
+ const dim_t L = dims[ndims - 1];
+ const dim_t l_blk_stride = block_d.blocking_desc().strides[ndims - 1];
+
+ constexpr int blksize = false ? 0
+ : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_4a, ib::_4b) ? 4
+ : utils::one_of(tag_traits<tag_o>::inner_blks, ib::_8a, ib::_8b) ? 8
+ : 16;
+
+ auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o, int block) {
+ if (alpha == 1.0 && beta == 0.0) {
+ for (int l = 0; l < L; ++l)
+ for (int blk = 0; blk < block; ++blk) {
+ const dim_t flat_off = 0
+ + blk * flat_d.blocking_desc().strides[blk_idx]
+ + l * flat_d.blocking_desc().strides[ndims - 1];
+ if (order_keep) {
+ o[l * l_blk_stride + blk] = _qz_a1b0<type_i, type_o>()(
+ i[flat_off]);
+ } else {
+ o[flat_off] = _qz_a1b0<type_i, type_o>()(
+ i[l * l_blk_stride + blk]);
+ }
+ }
+ } else {
+ for (int l = 0; l < L; ++l)
+ for (int blk = 0; blk < block; ++blk) {
+ const dim_t flat_off = 0
+ + blk * flat_d.blocking_desc().strides[blk_idx]
+ + l * flat_d.blocking_desc().strides[ndims - 1];
+ if (order_keep) {
+ o[l * l_blk_stride + blk] = _qz<type_i, type_o>()(
+ i[flat_off], o[l * blksize + blk],
+ alpha, beta);
+ } else {
+ o[flat_off] = _qz<type_i, type_o>()(
+ i[l * l_blk_stride + blk], o[flat_off],
+ alpha, beta);
+ }
+ }
+ }
+ };
+
+# define off(md, h0, h1, m0, m1, m2) \
+ (ndims >= 6 ? (md).blk_off(h0, h1, m0, m1, m2) \
+ : ndims >= 5 ? (md).blk_off(h0, h1, m1, m2) \
+ : ndims >= 4 ? (md).blk_off(h0, h1, m2) \
+ : /* ndims >= 3 ? */ (md).blk_off(h0, h1))
+
+ constexpr int i_mult = order_keep ? blksize : 1;
+ constexpr int o_mult = order_keep ? 1 : blksize;
+
+ if (blk_idx == 0) {
+ const dim_t BH0 = pdims[0] / blksize;
+ parallel_nd(BH0, H1, M0, M1, M2,
+ [&](dim_t bh0, dim_t h1, dim_t m0, dim_t m1, dim_t m2) {
+ auto i = &input[off(input_d, bh0 * i_mult, h1, m0, m1, m2)];
+ auto o = &output[off(output_d, bh0 * o_mult, h1, m0, m1, m2)];
+ const int block = nstl::min<int>(blksize, H0 - bh0 * blksize);
+ ker(i, o, block);
+ });
+ } else if (blk_idx == 1) {
+ const dim_t BH1 = pdims[1] / blksize;
+ parallel_nd(H0, BH1, M0, M1, M2,
+ [&](dim_t h0, dim_t bh1, dim_t m0, dim_t m1, dim_t m2) {
+ auto i = &input[off(input_d, h0, bh1 * i_mult, m0, m1, m2)];
+ auto o = &output[off(output_d, h0, bh1 * o_mult, m0, m1, m2)];
+ const int block = nstl::min<int>(blksize, H1 - bh1 * blksize);
+ ker(i, o, block);
+ });
+ } else {
+ assert(!"unimplemented");
+ }
+
+# undef off
+
+ return success;
+ }
+};
+
+template <SIMPLE_REORDER_TEMPL_DECL>
+struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
+typename utils::enable_if<tag_i == any
+ && (tag_traits<tag_o>::block_dims == bd::_AB
+ || tag_traits<tag_o>::block_dims == bd::_BC)
+ && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_AB,
+ tag_traits<tag_o>::ndims >= 3 && tag_traits<tag_o>::ndims <= 5)
+ && IMPLICATION(tag_traits<tag_o>::block_dims == bd::_BC,
+ tag_traits<tag_o>::ndims >= 4 && tag_traits<tag_o>::ndims <= 6)
+ >::type>
+{
+ PLAIN_TO_BLOCKED_IS_APPLICABLE();
+
+ static status_t execute(const cpu_reorder_pd_t *pd,
+ const data_t<type_i> *input, data_t<type_o> *output) {
+ DECLARE_COMMON_PARAMS();
+
+ const auto &flat_d = order_keep ? input_d : output_d;
+ const auto &dims = input_d.dims();
+ const auto &pdims = order_keep
+ ? output_d.padded_dims()
+ : input_d.padded_dims();
+
+ constexpr int ndims = tag_traits<tag_o>::ndims;
+
+ static constexpr bool with_g = tag_traits<tag_o>::block_dims == bd::_BC;
+ const dim_t G = with_g ? dims[0] : 1;
+
+ const dim_t H0 = dims[0 + with_g];
+ const dim_t H1 = dims[1 + with_g];
+
+ const dim_t M0 = ndims >= 5 + with_g ? dims[ndims - 3] : 1;
+ const dim_t M1 = ndims >= 4 + with_g ? dims[ndims - 2] : 1;
+ const dim_t M2 = ndims >= 3 + with_g ? dims[ndims - 1] : 1;
+
+ constexpr int blksize_0 = false ? 0
+ : utils::one_of(tag_traits<tag_o>::inner_blks,
+ ib::_4b4a, ib::_4b4c, ib::_4c4b)
+ ? 4
+ : utils::one_of(tag_traits<tag_o>::inner_blks,
+ ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c)
+ ? 8
+ : utils::one_of(tag_traits<tag_o>::inner_blks,
+ ib::_16a16b, ib::_16a4b, ib::_16b16a, ib::_16b4c,
+ ib::_16b16c, ib::_16c16b, ib::_8a16b2a, ib::_4b16a4b,
+ ib::_8b16a2b, ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c)
+ ? 16 : INT_MIN;
+
+ constexpr int blksize_1 = utils::one_of(tag_traits<tag_o>::inner_blks,
+ ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_2c8b4c)
+ ? 8
+ : utils::one_of(tag_traits<tag_o>::inner_blks,
+ ib::_16a16b, ib::_16b16a, ib::_16b16c, ib::_16c16b,
+ ib::_8a16b2a, ib::_4b16a4b, ib::_8b16a2b, ib::_8b16c2b,
+ ib::_4c16b4c, ib::_8c16b2c)
+ ? 16
+ : utils::one_of(tag_traits<tag_o>::inner_blks,
+ ib::_4b4a, ib::_4b4c, ib::_4c4b,
+ ib::_16a4b, ib::_16b4c)
+ ? 4
+ : INT_MIN;
+
+ const dim_t NB_H0 = pdims[0 + with_g] / blksize_0;
+ const dim_t NB_H1 = pdims[1 + with_g] / blksize_1;
+
+ auto ker = [&](const data_t<type_i> *i, data_t<type_o> *o,
+ const int block_h0, const int block_h1) {
+# define blk_off AB_or_BC_blk_off<tag_traits<tag_o>::inner_blks>
+
+ if (alpha == 1.0 && beta == 0.0) {
+ for (int h0 = 0; h0 < block_h0; ++h0)
+ for (int h1 = 0; h1 < block_h1; ++h1) {
+ const dim_t flat_off = 0
+ + h0 * flat_d.blocking_desc().strides[with_g + 0]
+ + h1 * flat_d.blocking_desc().strides[with_g + 1];
+ if (order_keep) {
+ o[blk_off(h0, h1)] = _qz_a1b0<type_i, type_o>()(
+ i[flat_off]);
+ } else {
+ o[flat_off] = _qz_a1b0<type_i, type_o>()(
+ i[blk_off(h0, h1)]);
+ }
+ }
+ } else {
+ for (int h0 = 0; h0 < block_h0; ++h0)
+ for (int h1 = 0; h1 < block_h1; ++h1) {
+ const dim_t flat_off = 0
+ + h0 * flat_d.blocking_desc().strides[with_g + 0]
+ + h1 * flat_d.blocking_desc().strides[with_g + 1];
+ if (order_keep) {
+ o[blk_off(h0, h1)] = _qz<type_i, type_o>()(i[flat_off],
+ o[blk_off(h0, h1)], alpha, beta);
+ } else {
+ o[flat_off] = _qz<type_i, type_o>()(i[blk_off(h0, h1)],
+ o[flat_off], alpha, beta);
+ }
+ }
+ }
+
+# undef blk_off
+ };
+
+ constexpr int i_mult_0 = order_keep ? blksize_0 : 1;
+ constexpr int o_mult_0 = order_keep ? 1 : blksize_0;
+
+ constexpr int i_mult_1 = order_keep ? blksize_1 : 1;
+ constexpr int o_mult_1 = order_keep ? 1 : blksize_1;
+
+# define off(md, g, h0, h1, m0, m1, m2) \
+ (ndims >= 5 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m0, m1, m2) \
+ : ndims >= 4 + with_g ? (md).blk_off<!with_g>(g, h0, h1, m1, m2) \
+ : /* ndims >= 3 + with_g ? */ (md).blk_off<!with_g>(g, h0, h1, m2))
+
+ parallel_nd(G, NB_H0, NB_H1, M0, M1, M2,
+ [&](dim_t g, dim_t nb_h0, dim_t nb_h1, dim_t m0, dim_t m1, dim_t m2) {
+ auto i = &input[off(input_d,
+ g, i_mult_0 * nb_h0, i_mult_1 * nb_h1, m0, m1, m2)];
+ auto o = &output[off(output_d,
+ g, o_mult_0 * nb_h0, o_mult_1 * nb_h1, m0, m1, m2)];
+ const int block_h0 = nstl::min<int>(blksize_0, H0 - nb_h0 * blksize_0);
+ const int block_h1 = nstl::min<int>(blksize_1, H1 - nb_h1 * blksize_1);
+ ker(i, o, block_h0, block_h1);
+ });
+
+# undef off
+
+ return success;
+ }
+};
+
+/* generic and direct-copy reorders */
+
+template <SIMPLE_REORDER_TEMPL_DECL>
+struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
+ typename utils::enable_if<
+ tag_i == any && tag_o == any && order_keep == fmt_order::any,
+ spec::direct_copy>::type>
+{
+ static bool is_applicable(const memory_desc_wrapper &input_d,
+ const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
+ /* FIXME: is the formula correct? */
+ return input_d.similar_to(output_d, true, false, 0)
+ && input_d.is_dense() && output_d.is_dense()
+ && simple_attr_check(attr, false);
+ }
+
+ static status_t execute(const cpu_reorder_pd_t *pd,
+ const data_t<type_i> *input, data_t<type_o> *output) {
+ DECLARE_COMMON_PARAMS();
+
+ assert(input_d.is_dense());
+
+ input += input_d.blk_off(0);
+ output += output_d.blk_off(0);
+
+ const size_t nelems = input_d.nelems();
+
+ constexpr int block_size = 16;
+ const auto num_blocks = nelems / block_size;
+ const auto rem_elems = nelems % block_size;
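+        // balance211 below hands out whole 16-element blocks to the threads;
+        // the remainder (rem_elems) is finished by the last thread.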
+
+ parallel(0, [&](const int ithr, const int nthr) {
+ size_t start{0}, end{0};
+ balance211(num_blocks, nthr, ithr, start, end);
+ start = start * block_size;
+ end = end * block_size;
+
+ if (alpha == 1.0 && beta == 0.0) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start; e < end; ++e) {
+ output[e] = qz_a1b0<data_t<type_i>, data_t<type_o>>()
+ (input[e]);
+ }
+ } else if (alpha == 1.0) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start; e < end; ++e) {
+ output[e] = qz_a1<data_t<type_i>, data_t<type_o>>()
+ (input[e], output[e], beta);
+ }
+ } else if (beta == 0.0) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start; e < end; ++e) {
+ output[e] = qz_b0<data_t<type_i>, data_t<type_o>>()
+ (input[e], alpha);
+ }
+ } else {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start; e < end; ++e) {
+ output[e] = qz<data_t<type_i>, data_t<type_o>>()
+ (input[e], output[e], alpha, beta);
+ }
+ }
+
+ if (rem_elems != 0 && ithr == nthr - 1){
+ if (alpha == 1.0 && beta == 0.0) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = nelems - rem_elems; e < nelems; ++e) {
+ output[e] = qz_a1b0<data_t<type_i>,
+ data_t<type_o>>()(input[e]);
+ }
+ } else if (alpha == 1.0) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = nelems - rem_elems; e < nelems; ++e) {
+ output[e] = qz_a1<data_t<type_i>,
+ data_t<type_o>>()(input[e], output[e], beta);
+ }
+ } else if (beta == 0.0) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = nelems - rem_elems; e < nelems; ++e) {
+ output[e] = qz_b0<data_t<type_i>,
+ data_t<type_o>>()(input[e], alpha);
+ }
+ } else {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = nelems - rem_elems; e < nelems; ++e) {
+ output[e] = qz<data_t<type_i>, data_t<type_o>>()
+ (input[e], output[e], alpha, beta);
+ }
+ }
+ }
+ });
+ return success;
+ }
+};
+
+template <SIMPLE_REORDER_TEMPL_DECL>
+struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
+ typename utils::enable_if<
+ tag_i == any && tag_o == any && order_keep == fmt_order::any,
+ spec::direct_copy_except_dim_0>::type>
+{
+ static bool is_applicable(const memory_desc_wrapper &input_d,
+ const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
+ auto is_dense_no_0 = [](const memory_desc_wrapper &data_d) {
+ return nelems_no_dim_0(data_d) == _size_no_dim_0(data_d);
+ };
+ /* FIXME: is the formula correct? */
+ return input_d.similar_to(output_d, true, false, 1)
+ && is_dense_no_0(input_d) && is_dense_no_0(output_d)
+ && simple_attr_check(attr, false);
+ }
+
+ static status_t execute(const cpu_reorder_pd_t *pd,
+ const data_t<type_i> *input, data_t<type_o> *output) {
+ DECLARE_COMMON_PARAMS();
+
+ input += input_d.blk_off(0);
+ output += output_d.blk_off(0);
+
+ const int N = input_d.dims()[0];
+ const dim_t is = input_d.blocking_desc().strides[0];
+ const dim_t os = output_d.blocking_desc().strides[0];
+ const dim_t nelems_no_d0 = nelems_no_dim_0(input_d);
+ const dim_t work_amount = N * nelems_no_d0;
+
+ if (alpha == 1.0 && beta == 0.0) {
+ parallel(0, [&](const int ithr, const int nthr) {
+ dim_t n{0}, dim1_s{0};
+ dim_t start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+ nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
+ while(start < end) {
+ dim_t work_rem = end - start;
+ dim_t dim1_e = dim1_s + work_rem > nelems_no_d0
+ ? nelems_no_d0 : dim1_s + work_rem;
+ PRAGMA_OMP_SIMD()
+ for (dim_t e = dim1_s; e < dim1_e; ++e) {
+ output[os * n + e] = _qz_a1b0<type_i, type_o>()(
+ input[is * n + e]);
+ }
+ nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
+ }
+ });
+ } else {
+ parallel(0, [&](const int ithr, const int nthr) {
+ dim_t n{0}, dim1_s{0};
+ dim_t start{0}, end{0};
+ balance211(work_amount, nthr, ithr, start, end);
+ nd_iterator_init(start, n, N, dim1_s, nelems_no_d0);
+ while(start < end) {
+ dim_t work_rem = end - start;
+ dim_t dim1_e =
+ dim1_s + work_rem > nelems_no_d0 ? nelems_no_d0
+ : dim1_s + work_rem;
+ PRAGMA_OMP_SIMD()
+ for (dim_t e = dim1_s; e < dim1_e; ++e){
+ output[os * n + e] = _qz<type_i, type_o>()(
+ input[is * n + e], output[os * n + e], alpha,
+ beta);
+ }
+ nd_iterator_jump(start, end, n, N, dim1_s, nelems_no_d0);
+ }
+ });
+ }
+
+ return success;
+ }
+
+private:
+ static dim_t nelems_no_dim_0(const memory_desc_wrapper &data_d) {
+ const int ndims = data_d.ndims();
+ if (ndims <= 1) return 1;
+ return utils::array_product(data_d.dims() + 1, data_d.ndims() - 1);
+ }
+
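+    // Physical extent (in elements) of one dim-0 slice, derived from padded dims,
+    // blocking and strides; matching nelems_no_dim_0 implies that slice is dense.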
+ static dim_t _size_no_dim_0(const memory_desc_wrapper &data_d) {
+ dims_t blocks;
+ data_d.compute_blocks(blocks);
+
+ const auto &blk = data_d.blocking_desc();
+
+ dim_t blk_size = 1;
+ for (int iblk = 0; iblk < blk.inner_nblks; ++iblk)
+ blk_size *= blk.inner_blks[iblk];
+
+ dim_t max_size = blk_size;
+ for (int d = 1; d < data_d.ndims(); ++d) {
+ max_size = nstl::max(max_size,
+ data_d.padded_dims()[d] / blocks[d] * blk.strides[d]);
+ }
+
+ return max_size;
+ }
+};
+
+template <SIMPLE_REORDER_TEMPL_DECL>
+struct simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL,
+ typename utils::enable_if<
+ tag_i == any && tag_o == any && order_keep == fmt_order::any,
+ spec::reference>::type>
+{
+ static bool is_applicable(const memory_desc_wrapper &input_d,
+ const memory_desc_wrapper &output_d, const primitive_attr_t *attr) {
+        /* supported smask: 0x0...011..10...0,
+         * i.e. the set bits must form a single contiguous run */
+ int smask = attr ? attr->output_scales_.mask_ : 0;
+ for (; smask > 0 && !(smask & 0x1); smask >>= 1);
+ for (; smask > 0 && smask & 0x1; smask >>= 1);
+ return true
+ && input_d.is_blocking_desc()
+ && output_d.is_blocking_desc()
+ && !output_d.is_additional_buffer()
+ && !input_d.is_additional_buffer()
+ && smask == 0;
+ }
+
+ static status_t execute(const cpu_reorder_pd_t *pd,
+ const data_t<type_i> *input, data_t<type_o> *output) {
+ DECLARE_COMMON_PARAMS();
+
+ const size_t nelems = input_d.nelems();
+
+ int ndims_start = 0, ndims_mask = 0;
+ int smask = pd->attr()->output_scales_.mask_;
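+        // peel the trailing zero bits (ndims_start) and then the contiguous run of
+        // set bits (ndims_mask) off the scales mask; a non-zero remainder would mean
+        // a non-contiguous mask, which is_applicable() already rejected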
+ for (; smask > 0 && !(smask & 0x1); smask >>= 1) ++ndims_start;
+ for (; smask > 0 && smask & 0x1; smask >>= 1) ++ndims_mask;
+ assert(smask == 0);
+
+ const ptrdiff_t D_start
+ = utils::array_product(input_d.dims(), ndims_start);
+ const ptrdiff_t D_mask
+ = utils::array_product(input_d.dims() + ndims_start, ndims_mask);
+ const ptrdiff_t D_rest = nelems / D_start / D_mask;
+
+ const float *scales = pd->attr()->output_scales_.scales_;
+
+ parallel_nd(D_start, D_mask, D_rest,
+ [&](ptrdiff_t ds, ptrdiff_t dm, ptrdiff_t dr) {
+ const float scale = scales[dm];
+
+ const size_t e = (ds * D_mask + dm) * D_rest + dr;
+ const auto &i = input[input_d.off_l(e)];
+ auto &o = output[output_d.off_l(e)];
+
+ o = _qz<type_i, type_o>()(i, o, scale, beta);
+ });
+
+ return success;
+ }
+};
+
+
+/* high level class declaration */
+
+template <SIMPLE_REORDER_TEMPL_DECL, typename spec = void>
+struct simple_reorder_t: public cpu_primitive_t {
+ struct pd_t: public cpu_reorder_pd_t {
+ using cpu_reorder_pd_t::cpu_reorder_pd_t;
+
+ DECLARE_COMMON_PD_T("simple:any", simple_reorder_t);
+
+ static status_t create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+ bool args_ok = true
+ && src_md->data_type == type_i
+ && dst_md->data_type == type_o
+ && simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::
+ is_applicable(src_md, dst_md, attr);
+ if (!args_ok)
+ return status::invalid_arguments;
+
+ auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
+ dst_md);
+ if (_pd == nullptr) return status::out_of_memory;
+ if (_pd->init() != status::success) {
+ delete _pd;
+ return status::unimplemented;
+ }
+ return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
+ }
+ };
+
+ simple_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto input = CTX_IN_MEM(const data_t<type_i> *, MKLDNN_ARG_FROM);
+ auto output = CTX_OUT_MEM(data_t<type_o> *, MKLDNN_ARG_TO);
+ simple_reorder_impl<SIMPLE_REORDER_TEMPL_CALL, spec>::execute(
+ pd(), input, output);
+ return status::success;
+ }
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+#undef SIMPLE_REORDER_TEMPL_DECL
+#undef SIMPLE_REORDER_TEMPL_CALL
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp
new file mode 100644
index 0000000000..f0947573a9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.cpp
@@ -0,0 +1,91 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_thread.hpp"
+
+#include "simple_sum.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t data_type>
+status_t simple_sum_t<data_type>::execute(const exec_ctx_t &ctx) const {
+ auto output = CTX_OUT_MEM(data_t *, MKLDNN_ARG_DST);
+
+ const memory_desc_wrapper o_d(pd()->dst_md());
+ output += o_d.blk_off(0);
+
+ const int num_arrs = pd()->n_inputs();
+ const data_t *input_ptrs[max_num_arrs];
+ const size_t nelems = o_d.nelems();
+
+ for (int a = 0; a < num_arrs; ++a) {
+ const memory_desc_wrapper i_d(pd()->src_md(a));
+ input_ptrs[a] = CTX_IN_MEM(const data_t *, MKLDNN_ARG_MULTIPLE_SRC + a)
+ + i_d.blk_off(0);
+ }
+
+    const size_t block_size = 16 * 1024 / sizeof(data_t);
+ const size_t blocks_number = nelems / block_size;
+ const size_t tail = nelems % block_size;
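+    // each thread accumulates whole ~16 KB blocks; the tail elements are folded in
+    // by the last thread after the main loop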
+
+ const auto scales = pd()->scales();
+ parallel(0, [&](const int ithr, const int nthr) {
+ size_t start{0}, end{0};
+ balance211(blocks_number, nthr, ithr, start, end);
+
+ for (size_t nb = start; nb < end; ++nb) {
+ size_t start_e = nb * block_size;
+ size_t end_e = start_e + block_size;
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < end_e; e++) {
+ output[e] = data_t(scales[0] * input_ptrs[0][e]);
+ }
+ for (int a = 1; a < num_arrs; a++) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < end_e; e++) {
+ output[e] += data_t(scales[a] * input_ptrs[a][e]);
+ }
+ }
+ }
+
+ if (tail != 0 && ithr == nthr - 1) {
+ size_t start_e = nelems - tail;
+ size_t end_e = nelems;
+
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < end_e; e++) {
+ output[e] = data_t(scales[0] * input_ptrs[0][e]);
+ }
+ for (int a = 1; a < num_arrs; a++) {
+ PRAGMA_OMP_SIMD()
+ for (size_t e = start_e; e < end_e; e++) {
+ output[e] += data_t(scales[a] * input_ptrs[a][e]);
+ }
+ }
+ }
+ });
+
+ return status::success;
+}
+
+template struct simple_sum_t<data_type::f32>;
+
+}
+}
+}
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp
new file mode 100644
index 0000000000..2a0187a184
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/simple_sum.hpp
@@ -0,0 +1,74 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef SIMPLE_SUM_HPP
+#define SIMPLE_SUM_HPP
+
+#include "cpu_sum_pd.hpp"
+#include "cpu_primitive.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t data_type>
+struct simple_sum_t: public cpu_primitive_t {
+ struct pd_t: public cpu_sum_pd_t {
+ using cpu_sum_pd_t::cpu_sum_pd_t;
+
+ DECLARE_SUM_PD_T("simple:any", simple_sum_t);
+
+ status_t init() {
+ const int n = n_inputs();
+
+ bool ok = true
+ && cpu_sum_pd_t::init() == status::success
+ && n <= max_num_arrs;
+ if (!ok) return status::unimplemented;
+
+ const memory_desc_wrapper o_d(dst_md());
+ ok = ok
+ && o_d.data_type() == data_type
+ && o_d.is_dense();
+ if (!ok) return status::unimplemented;
+
+ for (int i = 0; i < n; ++i) {
+ const memory_desc_wrapper i_d(src_md(i));
+ if (i_d != o_d) return status::unimplemented;
+ }
+
+ return status::success;
+ }
+ };
+
+ simple_sum_t(const pd_t *apd): cpu_primitive_t(apd) {}
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override;
+
+ enum {max_num_arrs = 16 };
+ typedef typename prec_traits<data_type>::type data_t;
+
+private:
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp
new file mode 100644
index 0000000000..c2082d7d62
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/wino_reorder.hpp
@@ -0,0 +1,376 @@
+/*******************************************************************************
+ * Copyright 2017-2018 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *******************************************************************************/
+
+#ifndef CPU_WINO_REORDER_HPP
+#define CPU_WINO_REORDER_HPP
+
+#include "mkldnn_thread.hpp"
+
+#include "simple_q10n.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace cpu {
+
+template <data_type_t type_i, data_type_t type_o>
+struct wino_reorder_t : public cpu_primitive_t {
+ struct pd_t : public cpu_reorder_pd_t {
+ using cpu_reorder_pd_t::cpu_reorder_pd_t;
+
+ DECLARE_COMMON_PD_T("wino_reorder", wino_reorder_t);
+
+ static status_t create(reorder_pd_t **reorder_pd,
+ engine_t *engine, const primitive_attr_t *attr,
+ engine_t *src_engine, const memory_desc_t *src_md,
+ engine_t *dst_engine, const memory_desc_t *dst_md) {
+ const memory_desc_wrapper id(src_md), od(dst_md);
+ bool args_ok = true
+ && id.data_type() == type_i
+ && od.data_type() == type_o
+ && id.matches_tag(utils::pick(id.ndims() - 4,
+ format_tag::oihw, format_tag::goihw))
+ && od.format_kind() == format_kind::wino
+ && utils::one_of(od.wino_desc().wino_format,
+ mkldnn_wino_wei_aaOIoi, mkldnn_wino_wei_aaOio,
+ mkldnn_wino_wei_aaOBiOo, mkldnn_wino_wei_OBaaIBOIio);
+ if (!args_ok) return status::invalid_arguments;
+
+ auto _pd = new pd_t(engine, attr, src_engine, src_md, dst_engine,
+ dst_md);
+ if (_pd == nullptr) return status::out_of_memory;
+ if (_pd->init() != status::success) {
+ delete _pd;
+ return status::unimplemented;
+ }
+ return safe_ptr_assign<reorder_pd_t>(*reorder_pd, _pd);
+ }
+
+ status_t init() {
+ status_t status = cpu_reorder_pd_t::init();
+ if (status != status::success) return status;
+
+ init_scratchpad();
+
+ return status::success;
+ }
+
+ private:
+ void init_scratchpad() {
+ auto &o = memory_desc_wrapper(dst_md()).wino_desc();
+ size_t transform_space_size = (size_t)o.r * o.alpha * o.oc_block;
+ size_t plain_size = (size_t)o.alpha * o.alpha * o.oc * o.ic;
+
+ using namespace memory_tracking::names;
+ auto scratchpad = scratchpad_registry().registrar();
+ scratchpad.book(key_reorder_wino_transform_space,
+ sizeof(in_data_t) * transform_space_size);
+ scratchpad.book(key_reorder_wino_plain,
+ sizeof(out_data_t) * plain_size);
+ }
+ };
+
+private:
+ typedef typename prec_traits<type_i>::type in_data_t;
+ typedef typename prec_traits<type_o>::type out_data_t;
+ const int unsign_val_in_wino_domain_ = 5;
+
+ wino_reorder_t(const pd_t *apd): cpu_primitive_t(apd) {
+ const memory_desc_wrapper src_d(pd()->src_md());
+ const memory_desc_wrapper dst_d(pd()->dst_md());
+
+ r_ = dst_d.wino_desc().r;
+ w_alpha_ = dst_d.wino_desc().alpha;
+ wino_format_ = dst_d.wino_desc().wino_format;
+
+ const auto &in_dims = src_d.dims();
+ int groups;
+ int groups_offset;
+ if (src_d.ndims() == 5) {
+ groups = in_dims[0];
+ groups_offset = 1;
+ } else {
+ groups = 1;
+ groups_offset = 0;
+ }
+ assert(groups == 1); // groups are not supported now
+ MAYBE_UNUSED(groups);
+
+ or_oc_ = in_dims[0 + groups_offset];
+ or_ic_ = in_dims[1 + groups_offset];
+ kh_ = in_dims[2 + groups_offset];
+ kw_ = in_dims[3 + groups_offset];
+
+ oc_ = dst_d.wino_desc().oc;
+ ic_ = dst_d.wino_desc().ic;
+ oc_block_ = dst_d.wino_desc().oc_block;
+ ic_block_ = dst_d.wino_desc().ic_block;
+ assert(oc_ % oc_block_ == 0 && ic_ % ic_block_ == 0);
+ nb_oc_ = oc_ / oc_block_;
+ nb_ic_ = ic_ / ic_block_;
+ ic2_block_ = 1;
+ if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio)
+ ic2_block_ = dst_d.wino_desc().ic2_block;
+ oc2_block_ = dst_d.wino_desc().oc2_block;
+ assert(nb_ic_ % ic2_block_ == 0 && nb_oc_ % oc2_block_ == 0);
+
+ adj_scale_ = dst_d.wino_desc().adj_scale;
+
+ size_wino_wei_ = w_alpha_ * w_alpha_ * oc_ * ic_;
+ size_wspace_ = r_ * w_alpha_ * oc_block_;
+ }
+
+ void transform(out_data_t *__restrict tmp_wei,
+ const in_data_t *__restrict input,
+ in_data_t *__restrict wspace) const {
+ const memory_desc_wrapper src_d(pd()->src_md());
+
+ const int smask = pd()->attr()->output_scales_.mask_;
+ const int ndims_mask = math::ilog2q(smask + 1);
+ const size_t D_mask = utils::array_product(src_d.dims(), ndims_mask);
+ const float *__restrict scales = pd()->attr()->output_scales_.scales_;
+ assert(D_mask == 1 || D_mask == (size_t)oc_);
+
+ /* transform weights to winograd domain */
+ const float G_2x2_3x3[4][3] = { { 1.0, 0.0, 0.0 }, { 0.5, 0.5, 0.5 },
+ { 0.5, -0.5, 0.5 }, { 0.0, 0.0, 1.0 } };
+
+ const float G_4x4_3x3[6][3] = { { 1.13777777777778f, 0.f, 0.f },
+ { -0.688403361344538f, -0.430252100840336f, -0.26890756302521f },
+ { -0.688403361344538f, 0.430252100840336f, -0.26890756302521f },
+ { 0.119514472455649f, 0.179271708683473f, 0.26890756302521f },
+ { 0.119514472455649f, -0.179271708683473f, 0.26890756302521f },
+ { 0.f, 0.f, 1.f } };
+
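+        // G above is the standard Winograd weight-transform matrix for 3x3 kernels
+        // (F(2x2,3x3) or F(4x4,3x3)); the loops below compute U = G * w * G^T
+        // one oc block at a time, accumulating the partial product in wspace.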
+ float *__restrict g;
+ if (utils::one_of(wino_format_, mkldnn_wino_wei_aaOIoi,
+ mkldnn_wino_wei_aaOio, mkldnn_wino_wei_aaOBiOo))
+ g = (float *)G_2x2_3x3;
+ else if (wino_format_ == mkldnn_wino_wei_OBaaIBOIio)
+ g = (float *)G_4x4_3x3;
+ else {
+ assert("Unknown winograd weights target layout");
+ return;
+ }
+
+ int Z = oc_ * ic_;
+ assert(r_ == kh_ && r_ == kw_);
+
+ for (int iic = 0; iic < ic_; iic++) {
+ for (int ob = 0; ob < nb_oc_; ob++) {
+ const in_data_t *__restrict _inp
+ = input + (ob * oc_block_ * or_ic_ + iic) * kh_ * kw_;
+ out_data_t *__restrict _out
+ = tmp_wei + (iic * nb_oc_ + ob) * oc_block_;
+
+ for_nd(0, 1, size_wspace_, [&](int i) { wspace[i] = 0.f; });
+
+ for_nd(0, 1, r_, w_alpha_, oc_block_,
+ [&](int ih, int j, int ioc) {
+ for (int iw = 0; iw < r_; ++iw) {
+ int inp_oc = ob * oc_block_ + ioc;
+ int inp_ic = iic;
+ in_data_t inp_v = (inp_ic < or_ic_ && inp_oc < or_oc_)
+ ? _inp[ioc * or_ic_ * kh_ * kw_ + ih * kw_ + iw]
+ : 0.f;
+ wspace[(ih * w_alpha_ + j) * oc_block_ + ioc]
+ += inp_v * g[j * r_ + iw];
+ }
+ });
+
+ for_nd(0, 1, w_alpha_, w_alpha_, oc_block_,
+ [&](int i, int j, int ioc) {
+ float t = 0;
+ for (int k = 0; k < r_; ++k)
+ t += g[i * r_ + k]
+ * wspace[(k * w_alpha_ + j) * oc_block_ + ioc];
+ if (type_o == data_type::s8) {
+ const float scale = (D_mask == 1)
+ ? scales[0]
+ : scales[ob * oc_block_ + ioc];
+ _out[(i * w_alpha_ + j) * Z + ioc]
+ = qz_b0<in_data_t, out_data_t>()(
+ (in_data_t)t, scale * adj_scale_);
+ } else {
+ _out[(i * w_alpha_ + j) * Z + ioc] = (out_data_t)t;
+ }
+ });
+ }}
+ }
+
+ void reorder_to_aaOIoi(out_data_t *__restrict output,
+ const out_data_t *__restrict tmp_wei) const {
+ int32_t *__restrict dst_bias = nullptr;
+ if (type_o == data_type::s8) {
+ const auto bias_shift = sizeof(out_data_t) * size_wino_wei_;
+ const size_t bias_size = w_alpha_ * w_alpha_ * oc_;
+
+ dst_bias = (int32_t *)(output + bias_shift);
+ utils::array_set((int32_t *)dst_bias, 0, bias_size);
+ }
+ int index = 0;
+ for (int u_h = 0; u_h < w_alpha_; u_h++) {
+ for (int u_w = 0; u_w < w_alpha_; u_w++) {
+ for_nd(0, 1, nb_oc_, oc_block_, [&](int ob, int o) {
+ int u_h_shift = u_h * w_alpha_ * ic_ * oc_;
+ int u_w_shift = u_w * ic_ * oc_;
+ int u_h_shift_b = u_h * w_alpha_ * oc_;
+ int u_w_shift_b = u_w * oc_;
+ int oc_block_shift = ob * oc_block_ * ic_ + o * ic_block_;
+ for (int ib = 0; ib < nb_ic_; ib++) {
+ for (int i = 0; i < ic_block_; i++) {
+ int _i = ib * ic_block_;
+ int _o = ob * oc_block_;
+ int ic_shift = (_i + i) * oc_;
+ int oc_shift = (_o + o);
+ int ic_block_shift = ib * oc_block_ * ic_block_ + i;
+ int src_offset =
+ u_h_shift + u_w_shift + ic_shift + oc_shift;
+ int dst_offset = u_h_shift + u_w_shift + oc_block_shift
+ + ic_block_shift;
+
+ output[dst_offset] = tmp_wei[src_offset];
+ if (type_o == data_type::s8) {
+ int bias_offset = u_h_shift_b + u_w_shift_b + oc_shift;
+ if (index != unsign_val_in_wino_domain_)
+ dst_bias[bias_offset]
+ -= (128 * (int32_t)output[dst_offset]);
+ else
+ dst_bias[bias_offset] = 0;
+ }
+ }}
+ });
+ index++;
+ }}
+ }
+
+ void reorder_to_aaOio(out_data_t *__restrict output,
+ const out_data_t *__restrict tmp_wei) const {
+ for_nd(0, 1, w_alpha_, w_alpha_, nb_oc_,
+ [&](int u_h, int u_w, int ob) {
+ for (int ib = 0; ib < nb_ic_; ib++) {
+ for (int i = 0; i < ic_block_; i++) {
+ for (int o = 0; o < oc_block_; o++) {
+ int src_offset = u_h * w_alpha_ * ic_ * oc_ + u_w * ic_ * oc_
+ + (ib * ic_block_ + i) * oc_ + (ob * oc_block_ + o);
+
+ int dst_offset
+ = u_h * w_alpha_ * nb_oc_ * nb_ic_ * ic_block_ * oc_block_
+ + u_w * nb_oc_ * nb_ic_ * ic_block_ * oc_block_
+ + ob * nb_ic_ * ic_block_ * oc_block_
+ + ib * ic_block_ * oc_block_ + i * oc_block_ + o;
+ output[dst_offset] = tmp_wei[src_offset];
+ }}}
+ });
+ }
+
+ void reorder_to_aaOBiOo(out_data_t *__restrict output,
+ const out_data_t *__restrict tmp_wei) const {
+ int oc_chunks = nb_oc_ / oc2_block_;
+
+ for_nd(0, 1, w_alpha_, w_alpha_, oc_chunks,
+ [&](int u_h, int u_w, int occ) {
+ for (int ib = 0; ib < nb_ic_; ib++) {
+ out_data_t *__restrict wei_ptr = output
+ + (((u_h * w_alpha_ + u_w) * oc_chunks + occ) * nb_ic_ + ib)
+ * oc2_block_ * ic_block_ * oc_block_;
+ int wei_offset = 0;
+ for (int i = 0; i < ic_block_; i++) {
+ for (int ob2 = 0; ob2 < oc2_block_; ob2++) {
+ for (int o = 0; o < oc_block_; o++) {
+ int icp = ib * ic_block_ + i;
+ int ocp =
+ occ * oc2_block_ * oc_block_ + ob2 * oc_block_ + o;
+
+ int src_offset = u_h * w_alpha_ * ic_ * oc_
+ + u_w * ic_ * oc_ + icp * oc_ + ocp;
+ wei_ptr[wei_offset + o] = tmp_wei[src_offset];
+ }
+ wei_offset += oc_block_;
+ }}
+ }
+ });
+ }
+
+ void reorder_to_OBaaIBOIio(out_data_t *__restrict output,
+ const out_data_t *__restrict tmp_wei) const {
+ int ic_chunks = nb_ic_ / ic2_block_;
+ int oc_chunks = nb_oc_ / oc2_block_;
+
+ for_nd(0, 1, oc_chunks, w_alpha_, w_alpha_,
+ [&](int occ, int u_h, int u_w) {
+ for (int icc = 0; icc < ic_chunks; icc++) {
+ for (int ob = 0; ob < oc2_block_; ob++) {
+ int ocp = (occ * oc2_block_ + ob) * oc_block_;
+ for (int ib = 0; ib < ic2_block_; ib++) {
+ for (int i = 0; i < ic_block_; i++) {
+ int icp = (icc * ic2_block_ + ib) * ic_block_ + i;
+
+ int src_offset = u_h * w_alpha_ * ic_ * oc_
+ + u_w * ic_ * oc_ + icp * oc_ + ocp;
+ int wei_offset
+ = ((((((occ * w_alpha_ + u_h) * w_alpha_ + u_w)
+ * ic_chunks + icc) * oc2_block_ + ob) * ic2_block_
+ + ib) * ic_block_ + i) * oc_block_;
+ for (int o = 0; o < oc_block_; o++)
+ output[wei_offset + o] = tmp_wei[src_offset + o];
+ }}
+ }}
+ });
+ }
+
+ virtual status_t execute(const exec_ctx_t &ctx) const override {
+ auto input = CTX_IN_MEM(const in_data_t *, MKLDNN_ARG_FROM);
+ auto output = CTX_OUT_MEM(out_data_t *, MKLDNN_ARG_TO);
+
+ auto wspace = (in_data_t *__restrict)scratchpad(ctx).template get<void>(
+ memory_tracking::names::key_reorder_wino_transform_space);
+ auto tmp_wei = (out_data_t *__restrict)scratchpad(ctx).template get<void>(
+ memory_tracking::names::key_reorder_wino_plain);
+
+ transform(tmp_wei, input, wspace);
+
+ /* reorder to winograd domain */
+ switch (wino_format_) {
+ case mkldnn_wino_wei_aaOIoi:
+ reorder_to_aaOIoi(output, tmp_wei); break;
+ case mkldnn_wino_wei_aaOio:
+ reorder_to_aaOio(output, tmp_wei); break;
+ case mkldnn_wino_wei_aaOBiOo:
+ reorder_to_aaOBiOo(output, tmp_wei); break;
+ case mkldnn_wino_wei_OBaaIBOIio:
+ reorder_to_OBaaIBOIio(output, tmp_wei); break;
+ default: assert("Unknown wino format"); break;
+ }
+
+ return status::success;
+ }
+
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
+ int r_, w_alpha_;
+ int ic_, oc_, or_ic_, or_oc_, kh_, kw_;
+ int oc_block_, ic_block_, oc2_block_, ic2_block_;
+ float adj_scale_;
+ int nb_oc_, nb_ic_;
+ mkldnn_wino_memory_format_t wino_format_;
+ int size_wino_wei_;
+ int size_wspace_;
+};
+
+} // namespace cpu
+} // namespace impl
+} // namespace mkldnn
+
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT
new file mode 100644
index 0000000000..66b6ea55d0
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/COPYRIGHT
@@ -0,0 +1,47 @@
+
+Copyright (c) 2007 MITSUNARI Shigeo
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+Redistributions of source code must retain the above copyright notice, this
+list of conditions and the following disclaimer.
+Redistributions in binary form must reproduce the above copyright notice,
+this list of conditions and the following disclaimer in the documentation
+and/or other materials provided with the distribution.
+Neither the name of the copyright owner nor the names of its contributors may
+be used to endorse or promote products derived from this software without
+specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+THE POSSIBILITY OF SUCH DAMAGE.
+-----------------------------------------------------------------------------
+ソースコード形式かバイナリ形式か、変更するかしないかを問わず、以下の条件を満た
+す場合に限り、再頒布および使用が許可されます。
+
+ソースコードを再頒布する場合、上記の著作権表示、本条件一覧、および下記免責条項
+を含めること。
+バイナリ形式で再頒布する場合、頒布物に付属のドキュメント等の資料に、上記の著作
+権表示、本条件一覧、および下記免責条項を含めること。
+書面による特別の許可なしに、本ソフトウェアから派生した製品の宣伝または販売促進
+に、著作権者の名前またはコントリビューターの名前を使用してはならない。
+本ソフトウェアは、著作権者およびコントリビューターによって「現状のまま」提供さ
+れており、明示黙示を問わず、商業的な使用可能性、および特定の目的に対する適合性
+に関する暗黙の保証も含め、またそれに限定されない、いかなる保証もありません。
+著作権者もコントリビューターも、事由のいかんを問わず、 損害発生の原因いかんを
+問わず、かつ責任の根拠が契約であるか厳格責任であるか(過失その他の)不法行為で
+あるかを問わず、仮にそのような損害が発生する可能性を知らされていたとしても、
+本ソフトウェアの使用によって発生した(代替品または代用サービスの調達、使用の
+喪失、データの喪失、利益の喪失、業務の中断も含め、またそれに限定されない)直接
+損害、間接損害、偶発的な損害、特別損害、懲罰的損害、または結果損害について、
+一切責任を負わないものとします。
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h
new file mode 100644
index 0000000000..cf5771332f
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak.h
@@ -0,0 +1,2658 @@
+/*******************************************************************************
+* Copyright 2016-2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*******************************************************************************
+* Copyright (c) 2007 MITSUNARI Shigeo
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* Redistributions of source code must retain the above copyright notice, this
+* list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* Neither the name of the copyright owner nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+*******************************************************************************/
+
+#pragma once
+#ifndef XBYAK_XBYAK_H_
+#define XBYAK_XBYAK_H_
+/*!
+ @file xbyak.h
+ @brief Xbyak ; JIT assembler for x86(IA32)/x64 by C++
+ @author herumi
+ @url https://github.com/herumi/xbyak
+ @note modified new BSD license
+ http://opensource.org/licenses/BSD-3-Clause
+*/
+#ifndef XBYAK_NO_OP_NAMES
+ #if not +0 // trick to detect whether 'not' is operator or not
+ #error "use -fno-operator-names option if you want to use and(), or(), xor(), not() as function names, Or define XBYAK_NO_OP_NAMES and use and_(), or_(), xor_(), not_()."
+ #endif
+#endif
+
+#include <stdio.h> // for debug print
+#include <assert.h>
+#include <list>
+#include <string>
+#include <algorithm>
+#ifndef NDEBUG
+#include <iostream>
+#endif
+
+// #define XBYAK_DISABLE_AVX512
+
+//#define XBYAK_USE_MMAP_ALLOCATOR
+#if !defined(__GNUC__) || defined(__MINGW32__)
+ #undef XBYAK_USE_MMAP_ALLOCATOR
+#endif
+
+#ifdef __GNUC__
+ #define XBYAK_GNUC_PREREQ(major, minor) ((__GNUC__) * 100 + (__GNUC_MINOR__) >= (major) * 100 + (minor))
+#else
+ #define XBYAK_GNUC_PREREQ(major, minor) 0
+#endif
+
+// This covers -std=(gnu|c)++(0x|11|1y), -stdlib=libc++, and modern Microsoft.
+#if ((defined(_MSC_VER) && (_MSC_VER >= 1600)) || defined(_LIBCPP_VERSION) ||\
+ ((__cplusplus >= 201103) || defined(__GXX_EXPERIMENTAL_CXX0X__)))
+ #include <unordered_set>
+ #define XBYAK_STD_UNORDERED_SET std::unordered_set
+ #include <unordered_map>
+ #define XBYAK_STD_UNORDERED_MAP std::unordered_map
+ #define XBYAK_STD_UNORDERED_MULTIMAP std::unordered_multimap
+
+/*
+ Clang/llvm-gcc and ICC-EDG in 'GCC-mode' always claim to be GCC 4.2, using
+ libstdcxx 20070719 (from GCC 4.2.1, the last GPL 2 version).
+*/
+#elif XBYAK_GNUC_PREREQ(4, 5) || (XBYAK_GNUC_PREREQ(4, 2) && __GLIBCXX__ >= 20070719) || defined(__INTEL_COMPILER) || defined(__llvm__)
+ #include <tr1/unordered_set>
+ #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set
+ #include <tr1/unordered_map>
+ #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map
+ #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap
+
+#elif defined(_MSC_VER) && (_MSC_VER >= 1500) && (_MSC_VER < 1600)
+ #include <unordered_set>
+ #define XBYAK_STD_UNORDERED_SET std::tr1::unordered_set
+ #include <unordered_map>
+ #define XBYAK_STD_UNORDERED_MAP std::tr1::unordered_map
+ #define XBYAK_STD_UNORDERED_MULTIMAP std::tr1::unordered_multimap
+
+#else
+ #include <set>
+ #define XBYAK_STD_UNORDERED_SET std::set
+ #include <map>
+ #define XBYAK_STD_UNORDERED_MAP std::map
+ #define XBYAK_STD_UNORDERED_MULTIMAP std::multimap
+#endif
+#ifdef _WIN32
+ #include <winsock2.h>
+ #include <windows.h>
+ #include <malloc.h>
+#elif defined(__GNUC__)
+ #include <unistd.h>
+ #include <sys/mman.h>
+ #include <stdlib.h>
+#endif
+#if !defined(_MSC_VER) || (_MSC_VER >= 1600)
+ #include <stdint.h>
+#endif
+
+#if defined(_WIN64) || defined(__MINGW64__) || (defined(__CYGWIN__) && defined(__x86_64__))
+ #define XBYAK64_WIN
+#elif defined(__x86_64__)
+ #define XBYAK64_GCC
+#endif
+#if !defined(XBYAK64) && !defined(XBYAK32)
+ #if defined(XBYAK64_GCC) || defined(XBYAK64_WIN)
+ #define XBYAK64
+ #else
+ #define XBYAK32
+ #endif
+#endif
+
+#if (__cplusplus >= 201103) || (_MSC_VER >= 1800)
+ #define XBYAK_VARIADIC_TEMPLATE
+#endif
+
+#ifdef _MSC_VER
+ #pragma warning(push)
+ #pragma warning(disable : 4514) /* remove inline function */
+ #pragma warning(disable : 4786) /* identifier is too long */
+ #pragma warning(disable : 4503) /* name is too long */
+	#pragma warning(disable : 4127) /* constant expression */
+#endif
+
+namespace Xbyak {
+
+enum {
+ DEFAULT_MAX_CODE_SIZE = 4096,
+ VERSION = 0x5760 /* 0xABCD = A.BC(D) */
+};
+
+#ifndef MIE_INTEGER_TYPE_DEFINED
+#define MIE_INTEGER_TYPE_DEFINED
+#ifdef _MSC_VER
+ typedef unsigned __int64 uint64;
+ typedef __int64 sint64;
+#else
+ typedef uint64_t uint64;
+ typedef int64_t sint64;
+#endif
+typedef unsigned int uint32;
+typedef unsigned short uint16;
+typedef unsigned char uint8;
+#endif
+
+#ifndef MIE_ALIGN
+ #ifdef _MSC_VER
+ #define MIE_ALIGN(x) __declspec(align(x))
+ #else
+ #define MIE_ALIGN(x) __attribute__((aligned(x)))
+ #endif
+#endif
+#ifndef MIE_PACK // for shufps
+ #define MIE_PACK(x, y, z, w) ((x) * 64 + (y) * 16 + (z) * 4 + (w))
+#endif
+
+enum {
+ ERR_NONE = 0,
+ ERR_BAD_ADDRESSING,
+ ERR_CODE_IS_TOO_BIG,
+ ERR_BAD_SCALE,
+ ERR_ESP_CANT_BE_INDEX,
+ ERR_BAD_COMBINATION,
+ ERR_BAD_SIZE_OF_REGISTER,
+ ERR_IMM_IS_TOO_BIG,
+ ERR_BAD_ALIGN,
+ ERR_LABEL_IS_REDEFINED,
+ ERR_LABEL_IS_TOO_FAR,
+ ERR_LABEL_IS_NOT_FOUND,
+ ERR_CODE_ISNOT_COPYABLE,
+ ERR_BAD_PARAMETER,
+ ERR_CANT_PROTECT,
+ ERR_CANT_USE_64BIT_DISP,
+ ERR_OFFSET_IS_TOO_BIG,
+ ERR_MEM_SIZE_IS_NOT_SPECIFIED,
+ ERR_BAD_MEM_SIZE,
+ ERR_BAD_ST_COMBINATION,
+ ERR_OVER_LOCAL_LABEL, // not used
+ ERR_UNDER_LOCAL_LABEL,
+ ERR_CANT_ALLOC,
+ ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW,
+ ERR_BAD_PROTECT_MODE,
+ ERR_BAD_PNUM,
+ ERR_BAD_TNUM,
+ ERR_BAD_VSIB_ADDRESSING,
+ ERR_CANT_CONVERT,
+ ERR_LABEL_ISNOT_SET_BY_L,
+ ERR_LABEL_IS_ALREADY_SET_BY_L,
+ ERR_BAD_LABEL_STR,
+ ERR_MUNMAP,
+ ERR_OPMASK_IS_ALREADY_SET,
+ ERR_ROUNDING_IS_ALREADY_SET,
+ ERR_K0_IS_INVALID,
+ ERR_EVEX_IS_INVALID,
+ ERR_SAE_IS_INVALID,
+ ERR_ER_IS_INVALID,
+ ERR_INVALID_BROADCAST,
+ ERR_INVALID_OPMASK_WITH_MEMORY,
+ ERR_INVALID_ZERO,
+ ERR_INVALID_RIP_IN_AUTO_GROW,
+ ERR_INVALID_MIB_ADDRESS,
+ ERR_INTERNAL,
+ ERR_X2APIC_IS_NOT_SUPPORTED
+};
+
+class Error : public std::exception {
+ int err_;
+public:
+ explicit Error(int err) : err_(err)
+ {
+ if (err_ < 0 || err_ > ERR_INTERNAL) {
+ fprintf(stderr, "bad err=%d in Xbyak::Error\n", err_);
+ //exit(1);
+ }
+ }
+ operator int() const { return err_; }
+ const char *what() const throw()
+ {
+ static const char *errTbl[] = {
+ "none",
+ "bad addressing",
+ "code is too big",
+ "bad scale",
+ "esp can't be index",
+ "bad combination",
+ "bad size of register",
+ "imm is too big",
+ "bad align",
+ "label is redefined",
+ "label is too far",
+ "label is not found",
+ "code is not copyable",
+ "bad parameter",
+ "can't protect",
+ "can't use 64bit disp(use (void*))",
+ "offset is too big",
+ "MEM size is not specified",
+ "bad mem size",
+ "bad st combination",
+ "over local label",
+ "under local label",
+ "can't alloc",
+ "T_SHORT is not supported in AutoGrow",
+ "bad protect mode",
+ "bad pNum",
+ "bad tNum",
+ "bad vsib addressing",
+ "can't convert",
+ "label is not set by L()",
+ "label is already set by L()",
+ "bad label string",
+ "err munmap",
+ "opmask is already set",
+ "rounding is already set",
+ "k0 is invalid",
+ "evex is invalid",
+ "sae(suppress all exceptions) is invalid",
+ "er(embedded rounding) is invalid",
+ "invalid broadcast",
+ "invalid opmask with memory",
+ "invalid zero",
+ "invalid rip in AutoGrow",
+ "invalid mib address",
+ "internal error",
+ "x2APIC is not supported"
+ };
+ assert((size_t)err_ < sizeof(errTbl) / sizeof(*errTbl));
+ return errTbl[err_];
+ }
+};
+
+inline const char *ConvertErrorToString(const Error& err)
+{
+ return err.what();
+}
+
+inline void *AlignedMalloc(size_t size, size_t alignment)
+{
+#ifdef __MINGW32__
+ return __mingw_aligned_malloc(size, alignment);
+#elif defined(_WIN32)
+ return _aligned_malloc(size, alignment);
+#else
+ void *p;
+ int ret = posix_memalign(&p, alignment, size);
+ return (ret == 0) ? p : 0;
+#endif
+}
+
+inline void AlignedFree(void *p)
+{
+#ifdef __MINGW32__
+ __mingw_aligned_free(p);
+#elif defined(_MSC_VER)
+ _aligned_free(p);
+#else
+ free(p);
+#endif
+}
+
+template<class To, class From>
+inline const To CastTo(From p) throw()
+{
+ return (const To)(size_t)(p);
+}
+namespace inner {
+
+static const size_t ALIGN_PAGE_SIZE = 4096;
+
+inline bool IsInDisp8(uint32 x) { return 0xFFFFFF80 <= x || x <= 0x7F; }
+inline bool IsInInt32(uint64 x) { return ~uint64(0x7fffffffu) <= x || x <= 0x7FFFFFFFU; }
+
+inline uint32 VerifyInInt32(uint64 x)
+{
+#ifdef XBYAK64
+ if (!IsInInt32(x)) throw Error(ERR_OFFSET_IS_TOO_BIG);
+#endif
+ return static_cast<uint32>(x);
+}
+
+enum LabelMode {
+ LasIs, // as is
+ Labs, // absolute
+ LaddTop // (addr + top) for mov(reg, label) with AutoGrow
+};
+
+} // inner
+
+/*
+ custom allocator
+*/
+struct Allocator {
+ virtual uint8 *alloc(size_t size) { return reinterpret_cast<uint8*>(AlignedMalloc(size, inner::ALIGN_PAGE_SIZE)); }
+ virtual void free(uint8 *p) { AlignedFree(p); }
+ virtual ~Allocator() {}
+ /* override to return false if you call protect() manually */
+ virtual bool useProtect() const { return true; }
+};
+
+#ifdef XBYAK_USE_MMAP_ALLOCATOR
+class MmapAllocator : Allocator {
+ typedef XBYAK_STD_UNORDERED_MAP<uintptr_t, size_t> SizeList;
+ SizeList sizeList_;
+public:
+ uint8 *alloc(size_t size)
+ {
+ const size_t alignedSizeM1 = inner::ALIGN_PAGE_SIZE - 1;
+ size = (size + alignedSizeM1) & ~alignedSizeM1;
+#ifdef MAP_ANONYMOUS
+ const int mode = MAP_PRIVATE | MAP_ANONYMOUS;
+#elif defined(MAP_ANON)
+ const int mode = MAP_PRIVATE | MAP_ANON;
+#else
+ #error "not supported"
+#endif
+ void *p = mmap(NULL, size, PROT_READ | PROT_WRITE, mode, -1, 0);
+ if (p == MAP_FAILED) throw Error(ERR_CANT_ALLOC);
+ assert(p);
+ sizeList_[(uintptr_t)p] = size;
+ return (uint8*)p;
+ }
+ void free(uint8 *p)
+ {
+ if (p == 0) return;
+ SizeList::iterator i = sizeList_.find((uintptr_t)p);
+ if (i == sizeList_.end()) throw Error(ERR_BAD_PARAMETER);
+ if (munmap((void*)i->first, i->second) < 0) throw Error(ERR_MUNMAP);
+ sizeList_.erase(i);
+ }
+};
+#endif
+
+class Address;
+class Reg;
+
+class Operand {
+ static const uint8 EXT8BIT = 0x20;
+ unsigned int idx_:6; // 0..31 + EXT8BIT = 1 if spl/bpl/sil/dil
+ unsigned int kind_:9;
+ unsigned int bit_:10;
+protected:
+ unsigned int zero_:1;
+ unsigned int mask_:3;
+ unsigned int rounding_:3;
+ void setIdx(int idx) { idx_ = idx; }
+public:
+ enum Kind {
+ NONE = 0,
+ MEM = 1 << 0,
+ REG = 1 << 1,
+ MMX = 1 << 2,
+ FPU = 1 << 3,
+ XMM = 1 << 4,
+ YMM = 1 << 5,
+ ZMM = 1 << 6,
+ OPMASK = 1 << 7,
+ BNDREG = 1 << 8
+ };
+ enum Code {
+#ifdef XBYAK64
+ RAX = 0, RCX, RDX, RBX, RSP, RBP, RSI, RDI, R8, R9, R10, R11, R12, R13, R14, R15,
+ R8D = 8, R9D, R10D, R11D, R12D, R13D, R14D, R15D,
+ R8W = 8, R9W, R10W, R11W, R12W, R13W, R14W, R15W,
+ R8B = 8, R9B, R10B, R11B, R12B, R13B, R14B, R15B,
+ SPL = 4, BPL, SIL, DIL,
+#endif
+ EAX = 0, ECX, EDX, EBX, ESP, EBP, ESI, EDI,
+ AX = 0, CX, DX, BX, SP, BP, SI, DI,
+ AL = 0, CL, DL, BL, AH, CH, DH, BH
+ };
+ Operand() : idx_(0), kind_(0), bit_(0), zero_(0), mask_(0), rounding_(0) { }
+ Operand(int idx, Kind kind, int bit, bool ext8bit = 0)
+ : idx_(static_cast<uint8>(idx | (ext8bit ? EXT8BIT : 0)))
+ , kind_(kind)
+ , bit_(bit)
+ , zero_(0), mask_(0), rounding_(0)
+ {
+ assert((bit_ & (bit_ - 1)) == 0); // bit must be power of two
+ }
+ Kind getKind() const { return static_cast<Kind>(kind_); }
+ int getIdx() const { return idx_ & (EXT8BIT - 1); }
+ bool isNone() const { return kind_ == 0; }
+ bool isMMX() const { return is(MMX); }
+ bool isXMM() const { return is(XMM); }
+ bool isYMM() const { return is(YMM); }
+ bool isZMM() const { return is(ZMM); }
+ bool isXMEM() const { return is(XMM | MEM); }
+ bool isYMEM() const { return is(YMM | MEM); }
+ bool isZMEM() const { return is(ZMM | MEM); }
+ bool isOPMASK() const { return is(OPMASK); }
+ bool isBNDREG() const { return is(BNDREG); }
+ bool isREG(int bit = 0) const { return is(REG, bit); }
+ bool isMEM(int bit = 0) const { return is(MEM, bit); }
+ bool isFPU() const { return is(FPU); }
+ bool isExt8bit() const { return (idx_ & EXT8BIT) != 0; }
+ bool isExtIdx() const { return (getIdx() & 8) != 0; }
+ bool isExtIdx2() const { return (getIdx() & 16) != 0; }
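+ // hasEvex: the operand forces EVEX encoding (zmm, register index >= 16, or an opmask/rounding modifier is set)
+ // hasRex : the operand forces a REX prefix (spl/bpl/sil/dil, a 64-bit register, or register index >= 8)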
+ bool hasEvex() const { return isZMM() || isExtIdx2() || getOpmaskIdx() || getRounding(); }
+ bool hasRex() const { return isExt8bit() || isREG(64) || isExtIdx(); }
+ bool hasZero() const { return zero_; }
+ int getOpmaskIdx() const { return mask_; }
+ int getRounding() const { return rounding_; }
+ void setKind(Kind kind)
+ {
+ if ((kind & (XMM|YMM|ZMM)) == 0) return;
+ kind_ = kind;
+ bit_ = kind == XMM ? 128 : kind == YMM ? 256 : 512;
+ }
+ void setBit(int bit) { bit_ = bit; }
+ void setOpmaskIdx(int idx, bool ignore_idx0 = false)
+ {
+ if (!ignore_idx0 && idx == 0) throw Error(ERR_K0_IS_INVALID);
+ if (mask_) throw Error(ERR_OPMASK_IS_ALREADY_SET);
+ mask_ = idx;
+ }
+ void setRounding(int idx)
+ {
+ if (rounding_) throw Error(ERR_ROUNDING_IS_ALREADY_SET);
+ rounding_ = idx;
+ }
+ void setZero() { zero_ = true; }
+ // ah, ch, dh, bh?
+ bool isHigh8bit() const
+ {
+ if (!isBit(8)) return false;
+ if (isExt8bit()) return false;
+ const int idx = getIdx();
+ return AH <= idx && idx <= BH;
+ }
+ // any bit is acceptable if bit == 0
+ bool is(int kind, uint32 bit = 0) const
+ {
+ return (kind == 0 || (kind_ & kind)) && (bit == 0 || (bit_ & bit)); // cf. you can set (8|16)
+ }
+ bool isBit(uint32 bit) const { return (bit_ & bit) != 0; }
+ uint32 getBit() const { return bit_; }
+ const char *toString() const
+ {
+ const int idx = getIdx();
+ if (kind_ == REG) {
+ if (isExt8bit()) {
+ static const char *tbl[4] = { "spl", "bpl", "sil", "dil" };
+ return tbl[idx - 4];
+ }
+ static const char *tbl[4][16] = {
+ { "al", "cl", "dl", "bl", "ah", "ch", "dh", "bh", "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b" },
+ { "ax", "cx", "dx", "bx", "sp", "bp", "si", "di", "r8w", "r9w", "r10w", "r11w", "r12w", "r13w", "r14w", "r15w" },
+ { "eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi", "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d" },
+ { "rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15" },
+ };
+ return tbl[bit_ == 8 ? 0 : bit_ == 16 ? 1 : bit_ == 32 ? 2 : 3][idx];
+ } else if (isOPMASK()) {
+ static const char *tbl[8] = { "k0", "k1", "k2", "k3", "k4", "k5", "k6", "k7" };
+ return tbl[idx];
+ } else if (isZMM()) {
+ static const char *tbl[32] = {
+ "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", "zmm14", "zmm15",
+ "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", "zmm30", "zmm31"
+ };
+ return tbl[idx];
+ } else if (isYMM()) {
+ static const char *tbl[32] = {
+ "ymm0", "ymm1", "ymm2", "ymm3", "ymm4", "ymm5", "ymm6", "ymm7", "ymm8", "ymm9", "ymm10", "ymm11", "ymm12", "ymm13", "ymm14", "ymm15",
+ "ymm16", "ymm17", "ymm18", "ymm19", "ymm20", "ymm21", "ymm22", "ymm23", "ymm24", "ymm25", "ymm26", "ymm27", "ymm28", "ymm29", "ymm30", "ymm31"
+ };
+ return tbl[idx];
+ } else if (isXMM()) {
+ static const char *tbl[32] = {
+ "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15",
+ "xmm16", "xmm17", "xmm18", "xmm19", "xmm20", "xmm21", "xmm22", "xmm23", "xmm24", "xmm25", "xmm26", "xmm27", "xmm28", "xmm29", "xmm30", "xmm31"
+ };
+ return tbl[idx];
+ } else if (isMMX()) {
+ static const char *tbl[8] = { "mm0", "mm1", "mm2", "mm3", "mm4", "mm5", "mm6", "mm7" };
+ return tbl[idx];
+ } else if (isFPU()) {
+ static const char *tbl[8] = { "st0", "st1", "st2", "st3", "st4", "st5", "st6", "st7" };
+ return tbl[idx];
+ } else if (isBNDREG()) {
+ static const char *tbl[4] = { "bnd0", "bnd1", "bnd2", "bnd3" };
+ return tbl[idx];
+ }
+ throw Error(ERR_INTERNAL);
+ }
+ bool isEqualIfNotInherited(const Operand& rhs) const { return idx_ == rhs.idx_ && kind_ == rhs.kind_ && bit_ == rhs.bit_ && zero_ == rhs.zero_ && mask_ == rhs.mask_ && rounding_ == rhs.rounding_; }
+ bool operator==(const Operand& rhs) const;
+ bool operator!=(const Operand& rhs) const { return !operator==(rhs); }
+ const Address& getAddress() const;
+ const Reg& getReg() const;
+};
+
+class Label;
+
+struct Reg8;
+struct Reg16;
+struct Reg32;
+#ifdef XBYAK64
+struct Reg64;
+#endif
+class Reg : public Operand {
+public:
+ Reg() { }
+ Reg(int idx, Kind kind, int bit = 0, bool ext8bit = false) : Operand(idx, kind, bit, ext8bit) { }
+ Reg changeBit(int bit) const { return Reg(getIdx(), getKind(), bit, isExt8bit()); }
+ uint8 getRexW() const { return isREG(64) ? 8 : 0; }
+ uint8 getRexR() const { return isExtIdx() ? 4 : 0; }
+ uint8 getRexX() const { return isExtIdx() ? 2 : 0; }
+ uint8 getRexB() const { return isExtIdx() ? 1 : 0; }
+ uint8 getRex(const Reg& base = Reg()) const
+ {
+ uint8 rex = getRexW() | getRexR() | base.getRexW() | base.getRexB();
+ if (rex || isExt8bit() || base.isExt8bit()) rex |= 0x40;
+ return rex;
+ }
+ Reg8 cvt8() const;
+ Reg16 cvt16() const;
+ Reg32 cvt32() const;
+#ifdef XBYAK64
+ Reg64 cvt64() const;
+#endif
+};
+
+inline const Reg& Operand::getReg() const
+{
+ assert(!isMEM());
+ return static_cast<const Reg&>(*this);
+}
+
+struct Reg8 : public Reg {
+ explicit Reg8(int idx = 0, bool ext8bit = false) : Reg(idx, Operand::REG, 8, ext8bit) { }
+};
+
+struct Reg16 : public Reg {
+ explicit Reg16(int idx = 0) : Reg(idx, Operand::REG, 16) { }
+};
+
+struct Mmx : public Reg {
+ explicit Mmx(int idx = 0, Kind kind = Operand::MMX, int bit = 64) : Reg(idx, kind, bit) { }
+};
+
+struct EvexModifierRounding {
+ enum {
+ T_RN_SAE = 1,
+ T_RD_SAE = 2,
+ T_RU_SAE = 3,
+ T_RZ_SAE = 4,
+ T_SAE = 5
+ };
+ explicit EvexModifierRounding(int rounding) : rounding(rounding) {}
+ int rounding;
+};
+struct EvexModifierZero{EvexModifierZero() {}};
+
+struct Xmm : public Mmx {
+ explicit Xmm(int idx = 0, Kind kind = Operand::XMM, int bit = 128) : Mmx(idx, kind, bit) { }
+ Xmm(Kind kind, int idx) : Mmx(idx, kind, kind == XMM ? 128 : kind == YMM ? 256 : 512) { }
+ Xmm operator|(const EvexModifierRounding& emr) const { Xmm r(*this); r.setRounding(emr.rounding); return r; }
+ Xmm copyAndSetIdx(int idx) const { Xmm ret(*this); ret.setIdx(idx); return ret; }
+ Xmm copyAndSetKind(Operand::Kind kind) const { Xmm ret(*this); ret.setKind(kind); return ret; }
+};
+
+struct Ymm : public Xmm {
+ explicit Ymm(int idx = 0, Kind kind = Operand::YMM, int bit = 256) : Xmm(idx, kind, bit) { }
+ Ymm operator|(const EvexModifierRounding& emr) const { Ymm r(*this); r.setRounding(emr.rounding); return r; }
+};
+
+struct Zmm : public Ymm {
+ explicit Zmm(int idx = 0) : Ymm(idx, Operand::ZMM, 512) { }
+ Zmm operator|(const EvexModifierRounding& emr) const { Zmm r(*this); r.setRounding(emr.rounding); return r; }
+};
+
+struct Opmask : public Reg {
+ explicit Opmask(int idx = 0) : Reg(idx, Operand::OPMASK, 64) {}
+};
+
+struct BoundsReg : public Reg {
+ explicit BoundsReg(int idx = 0) : Reg(idx, Operand::BNDREG, 128) {}
+};
+
+template<class T>T operator|(const T& x, const Opmask& k) { T r(x); r.setOpmaskIdx(k.getIdx()); return r; }
+template<class T>T operator|(const T& x, const EvexModifierZero&) { T r(x); r.setZero(); return r; }
+template<class T>T operator|(const T& x, const EvexModifierRounding& emr) { T r(x); r.setRounding(emr.rounding); return r; }
+
+struct Fpu : public Reg {
+ explicit Fpu(int idx = 0) : Reg(idx, Operand::FPU, 32) { }
+};
+
+struct Reg32e : public Reg {
+ explicit Reg32e(int idx, int bit) : Reg(idx, Operand::REG, bit) {}
+};
+struct Reg32 : public Reg32e {
+ explicit Reg32(int idx = 0) : Reg32e(idx, 32) {}
+};
+#ifdef XBYAK64
+struct Reg64 : public Reg32e {
+ explicit Reg64(int idx = 0) : Reg32e(idx, 64) {}
+};
+struct RegRip {
+ sint64 disp_;
+ const Label* label_;
+ bool isAddr_;
+ explicit RegRip(sint64 disp = 0, const Label* label = 0, bool isAddr = false) : disp_(disp), label_(label), isAddr_(isAddr) {}
+ friend const RegRip operator+(const RegRip& r, int disp) {
+ return RegRip(r.disp_ + disp, r.label_, r.isAddr_);
+ }
+ friend const RegRip operator-(const RegRip& r, int disp) {
+ return RegRip(r.disp_ - disp, r.label_, r.isAddr_);
+ }
+ friend const RegRip operator+(const RegRip& r, sint64 disp) {
+ return RegRip(r.disp_ + disp, r.label_, r.isAddr_);
+ }
+ friend const RegRip operator-(const RegRip& r, sint64 disp) {
+ return RegRip(r.disp_ - disp, r.label_, r.isAddr_);
+ }
+ friend const RegRip operator+(const RegRip& r, const Label& label) {
+ if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING);
+ return RegRip(r.disp_, &label);
+ }
+ friend const RegRip operator+(const RegRip& r, const void *addr) {
+ if (r.label_ || r.isAddr_) throw Error(ERR_BAD_ADDRESSING);
+ return RegRip(r.disp_ + (sint64)addr, 0, true);
+ }
+};
+#endif
+
+inline Reg8 Reg::cvt8() const
+{
+ const int idx = getIdx();
+ if (isBit(8)) return Reg8(idx, isExt8bit());
+#ifdef XBYAK32
+ if (idx >= 4) throw Error(ERR_CANT_CONVERT);
+#endif
+ return Reg8(idx, 4 <= idx && idx < 8);
+}
+
+inline Reg16 Reg::cvt16() const
+{
+ const int idx = getIdx();
+ if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT);
+ return Reg16(idx);
+}
+
+inline Reg32 Reg::cvt32() const
+{
+ const int idx = getIdx();
+ if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT);
+ return Reg32(idx);
+}
+
+#ifdef XBYAK64
+inline Reg64 Reg::cvt64() const
+{
+ const int idx = getIdx();
+ if (isBit(8) && (4 <= idx && idx < 8) && !isExt8bit()) throw Error(ERR_CANT_CONVERT);
+ return Reg64(idx);
+}
+#endif
+
+#ifndef XBYAK_DISABLE_SEGMENT
+// not derived from Reg
+class Segment {
+ int idx_;
+public:
+ enum {
+ es, cs, ss, ds, fs, gs
+ };
+ explicit Segment(int idx) : idx_(idx) { assert(0 <= idx_ && idx_ < 6); }
+ int getIdx() const { return idx_; }
+ const char *toString() const
+ {
+ static const char tbl[][3] = {
+ "es", "cs", "ss", "ds", "fs", "gs"
+ };
+ return tbl[idx_];
+ }
+};
+#endif
+
+class RegExp {
+public:
+#ifdef XBYAK64
+ enum { i32e = 32 | 64 };
+#else
+ enum { i32e = 32 };
+#endif
+ RegExp(size_t disp = 0) : scale_(0), disp_(disp) { }
+ RegExp(const Reg& r, int scale = 1)
+ : scale_(scale)
+ , disp_(0)
+ {
+ if (!r.isREG(i32e) && !r.is(Reg::XMM|Reg::YMM|Reg::ZMM)) throw Error(ERR_BAD_SIZE_OF_REGISTER);
+ if (scale == 0) return;
+ if (scale != 1 && scale != 2 && scale != 4 && scale != 8) throw Error(ERR_BAD_SCALE);
+ if (r.getBit() >= 128 || scale != 1) { // xmm/ymm is always index
+ index_ = r;
+ } else {
+ base_ = r;
+ }
+ }
+ bool isVsib(int bit = 128 | 256 | 512) const { return index_.isBit(bit); }
+ RegExp optimize() const
+ {
+ RegExp exp = *this;
+ // [reg * 2] => [reg + reg]
+ if (index_.isBit(i32e) && !base_.getBit() && scale_ == 2) {
+ exp.base_ = index_;
+ exp.scale_ = 1;
+ }
+ return exp;
+ }
+ bool operator==(const RegExp& rhs) const
+ {
+ return base_ == rhs.base_ && index_ == rhs.index_ && disp_ == rhs.disp_ && scale_ == rhs.scale_;
+ }
+ const Reg& getBase() const { return base_; }
+ const Reg& getIndex() const { return index_; }
+ int getScale() const { return scale_; }
+ size_t getDisp() const { return disp_; }
+ void verify() const
+ {
+ if (base_.getBit() >= 128) throw Error(ERR_BAD_SIZE_OF_REGISTER);
+ if (index_.getBit() && index_.getBit() <= 64) {
+ if (index_.getIdx() == Operand::ESP) throw Error(ERR_ESP_CANT_BE_INDEX);
+ if (base_.getBit() && base_.getBit() != index_.getBit()) throw Error(ERR_BAD_SIZE_OF_REGISTER);
+ }
+ }
+ friend RegExp operator+(const RegExp& a, const RegExp& b);
+ friend RegExp operator-(const RegExp& e, size_t disp);
+ uint8 getRex() const
+ {
+ uint8 rex = index_.getRexX() | base_.getRexB();
+ return rex ? uint8(rex | 0x40) : 0;
+ }
+private:
+ /*
+ [base_ + index_ * scale_ + disp_]
+ base : Reg32e, index : Reg32e(w/o esp), Xmm, Ymm
+ */
+ Reg base_;
+ Reg index_;
+ int scale_;
+ size_t disp_;
+};
+
+inline RegExp operator+(const RegExp& a, const RegExp& b)
+{
+ if (a.index_.getBit() && b.index_.getBit()) throw Error(ERR_BAD_ADDRESSING);
+ RegExp ret = a;
+ if (!ret.index_.getBit()) { ret.index_ = b.index_; ret.scale_ = b.scale_; }
+ if (b.base_.getBit()) {
+ if (ret.base_.getBit()) {
+ if (ret.index_.getBit()) throw Error(ERR_BAD_ADDRESSING);
+ // base + base => base + index * 1
+ ret.index_ = b.base_;
+ // [reg + esp] => [esp + reg]
+ if (ret.index_.getIdx() == Operand::ESP) std::swap(ret.base_, ret.index_);
+ ret.scale_ = 1;
+ } else {
+ ret.base_ = b.base_;
+ }
+ }
+ ret.disp_ += b.disp_;
+ return ret;
+}
+inline RegExp operator*(const Reg& r, int scale)
+{
+ return RegExp(r, scale);
+}
+inline RegExp operator-(const RegExp& e, size_t disp)
+{
+ RegExp ret = e;
+ ret.disp_ -= disp;
+ return ret;
+}
+
+// 2nd parameter for constructor of CodeArray(maxSize, userPtr, alloc)
+void *const AutoGrow = (void*)1; //-V566
+void *const DontSetProtectRWE = (void*)2; //-V566
+
+class CodeArray {
+ enum Type {
+ USER_BUF = 1, // use userPtr (no alignment, no protection)
+ ALLOC_BUF, // allocated by the Allocator (aligned, protected)
+ AUTO_GROW // automatically move and grow memory if necessary
+ };
+ CodeArray(const CodeArray& rhs);
+ void operator=(const CodeArray&);
+ bool isAllocType() const { return type_ == ALLOC_BUF || type_ == AUTO_GROW; }
+ struct AddrInfo {
+ size_t codeOffset; // position to write
+ size_t jmpAddr; // value to write
+ int jmpSize; // size of jmpAddr
+ inner::LabelMode mode;
+ AddrInfo(size_t _codeOffset, size_t _jmpAddr, int _jmpSize, inner::LabelMode _mode)
+ : codeOffset(_codeOffset), jmpAddr(_jmpAddr), jmpSize(_jmpSize), mode(_mode) {}
+ uint64 getVal(const uint8 *top) const
+ {
+ uint64 disp = (mode == inner::LaddTop) ? jmpAddr + size_t(top) : (mode == inner::LasIs) ? jmpAddr : jmpAddr - size_t(top);
+ if (jmpSize == 4) disp = inner::VerifyInInt32(disp);
+ return disp;
+ }
+ };
+ typedef std::list<AddrInfo> AddrInfoList;
+ AddrInfoList addrInfoList_;
+ const Type type_;
+#ifdef XBYAK_USE_MMAP_ALLOCATOR
+ MmapAllocator defaultAllocator_;
+#else
+ Allocator defaultAllocator_;
+#endif
+ Allocator *alloc_;
+protected:
+ size_t maxSize_;
+ uint8 *top_;
+ size_t size_;
+ bool isCalledCalcJmpAddress_;
+
+ bool useProtect() const { return alloc_->useProtect(); }
+ /*
+ allocate new memory and copy old data to the new area
+ */
+ void growMemory()
+ {
+ const size_t newSize = (std::max<size_t>)(DEFAULT_MAX_CODE_SIZE, maxSize_ * 2);
+ uint8 *newTop = alloc_->alloc(newSize);
+ if (newTop == 0) throw Error(ERR_CANT_ALLOC);
+ for (size_t i = 0; i < size_; i++) newTop[i] = top_[i];
+ alloc_->free(top_);
+ top_ = newTop;
+ maxSize_ = newSize;
+ }
+ /*
+ calc jmp address for AutoGrow mode
+ */
+ void calcJmpAddress()
+ {
+ if (isCalledCalcJmpAddress_) return;
+ for (AddrInfoList::const_iterator i = addrInfoList_.begin(), ie = addrInfoList_.end(); i != ie; ++i) {
+ uint64 disp = i->getVal(top_);
+ rewrite(i->codeOffset, disp, i->jmpSize);
+ }
+ isCalledCalcJmpAddress_ = true;
+ }
+public:
+ enum ProtectMode {
+ PROTECT_RW = 0, // read/write
+ PROTECT_RWE = 1, // read/write/exec
+ PROTECT_RE = 2 // read/exec
+ };
+ explicit CodeArray(size_t maxSize, void *userPtr = 0, Allocator *allocator = 0)
+ : type_(userPtr == AutoGrow ? AUTO_GROW : (userPtr == 0 || userPtr == DontSetProtectRWE) ? ALLOC_BUF : USER_BUF)
+ , alloc_(allocator ? allocator : (Allocator*)&defaultAllocator_)
+ , maxSize_(maxSize)
+ , top_(type_ == USER_BUF ? reinterpret_cast<uint8*>(userPtr) : alloc_->alloc((std::max<size_t>)(maxSize, 1)))
+ , size_(0)
+ , isCalledCalcJmpAddress_(false)
+ {
+ if (maxSize_ > 0 && top_ == 0) throw Error(ERR_CANT_ALLOC);
+ if ((type_ == ALLOC_BUF && userPtr != DontSetProtectRWE && useProtect()) && !setProtectMode(PROTECT_RWE, false)) {
+ alloc_->free(top_);
+ throw Error(ERR_CANT_PROTECT);
+ }
+ }
+ virtual ~CodeArray()
+ {
+ if (isAllocType()) {
+ if (useProtect()) setProtectModeRW(false);
+ alloc_->free(top_);
+ }
+ }
+ bool setProtectMode(ProtectMode mode, bool throwException = true)
+ {
+ bool isOK = protect(top_, maxSize_, mode);
+ if (isOK) return true;
+ if (throwException) throw Error(ERR_CANT_PROTECT);
+ return false;
+ }
+ bool setProtectModeRE(bool throwException = true) { return setProtectMode(PROTECT_RE, throwException); }
+ bool setProtectModeRW(bool throwException = true) { return setProtectMode(PROTECT_RW, throwException); }
+ void resetSize()
+ {
+ size_ = 0;
+ addrInfoList_.clear();
+ isCalledCalcJmpAddress_ = false;
+ }
+ void db(int code)
+ {
+ if (size_ >= maxSize_) {
+ if (type_ == AUTO_GROW) {
+ growMemory();
+ } else {
+ throw Error(ERR_CODE_IS_TOO_BIG);
+ }
+ }
+ top_[size_++] = static_cast<uint8>(code);
+ }
+ void db(const uint8 *code, size_t codeSize)
+ {
+ for (size_t i = 0; i < codeSize; i++) db(code[i]);
+ }
+ void db(uint64 code, size_t codeSize)
+ {
+ if (codeSize > 8) throw Error(ERR_BAD_PARAMETER);
+ for (size_t i = 0; i < codeSize; i++) db(static_cast<uint8>(code >> (i * 8)));
+ }
+ void dw(uint32 code) { db(code, 2); }
+ void dd(uint32 code) { db(code, 4); }
+ void dq(uint64 code) { db(code, 8); }
+ const uint8 *getCode() const { return top_; }
+ template<class F>
+ const F getCode() const { return reinterpret_cast<F>(top_); }
+ const uint8 *getCurr() const { return &top_[size_]; }
+ template<class F>
+ const F getCurr() const { return reinterpret_cast<F>(&top_[size_]); }
+ size_t getSize() const { return size_; }
+ void setSize(size_t size)
+ {
+ if (size > maxSize_) throw Error(ERR_OFFSET_IS_TOO_BIG);
+ size_ = size;
+ }
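+ // print up to 64 bytes of the generated code in hex, 16 bytes per line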
+ void dump() const
+ {
+ const uint8 *p = getCode();
+ size_t bufSize = getSize();
+ size_t remain = bufSize;
+ for (int i = 0; i < 4; i++) {
+ size_t disp = 16;
+ if (remain < 16) {
+ disp = remain;
+ }
+ for (size_t j = 0; j < 16; j++) {
+ if (j < disp) {
+ printf("%02X", p[i * 16 + j]);
+ }
+ }
+ putchar('\n');
+ remain -= disp;
+ if (remain == 0) {
+ break;
+ }
+ }
+ }
+ /*
+ @param offset [in] offset from top
+ @param disp [in] offset from the next of jmp
+ @param size [in] write size(1, 2, 4, 8)
+ */
+ void rewrite(size_t offset, uint64 disp, size_t size)
+ {
+ assert(offset < maxSize_);
+ if (size != 1 && size != 2 && size != 4 && size != 8) throw Error(ERR_BAD_PARAMETER);
+ uint8 *const data = top_ + offset;
+ for (size_t i = 0; i < size; i++) {
+ data[i] = static_cast<uint8>(disp >> (i * 8));
+ }
+ }
+ void save(size_t offset, size_t val, int size, inner::LabelMode mode)
+ {
+ addrInfoList_.push_back(AddrInfo(offset, val, size, mode));
+ }
+ bool isAutoGrow() const { return type_ == AUTO_GROW; }
+ bool isCalledCalcJmpAddress() const { return isCalledCalcJmpAddress_; }
+ /**
+ change exec permission of memory
+ @param addr [in] buffer address
+ @param size [in] buffer size
+ @param protectMode [in] mode(RW/RWE/RE)
+ @return true(success), false(failure)
+ */
+ static inline bool protect(const void *addr, size_t size, int protectMode)
+ {
+#if defined(_WIN32)
+ const DWORD c_rw = PAGE_READWRITE;
+ const DWORD c_rwe = PAGE_EXECUTE_READWRITE;
+ const DWORD c_re = PAGE_EXECUTE_READ;
+ DWORD mode;
+#else
+ const int c_rw = PROT_READ | PROT_WRITE;
+ const int c_rwe = PROT_READ | PROT_WRITE | PROT_EXEC;
+ const int c_re = PROT_READ | PROT_EXEC;
+ int mode;
+#endif
+ switch (protectMode) {
+ case PROTECT_RW: mode = c_rw; break;
+ case PROTECT_RWE: mode = c_rwe; break;
+ case PROTECT_RE: mode = c_re; break;
+ default:
+ return false;
+ }
+#if defined(_WIN32)
+ DWORD oldProtect;
+ return VirtualProtect(const_cast<void*>(addr), size, mode, &oldProtect) != 0;
+#elif defined(__GNUC__)
+ size_t pageSize = sysconf(_SC_PAGESIZE);
+ size_t iaddr = reinterpret_cast<size_t>(addr);
+ size_t roundAddr = iaddr & ~(pageSize - static_cast<size_t>(1));
+#ifndef NDEBUG
+ if (pageSize != 4096) fprintf(stderr, "large page(%zd) is used. not tested enough.\n", pageSize);
+#endif
+ return mprotect(reinterpret_cast<void*>(roundAddr), size + (iaddr - roundAddr), mode) == 0;
+#else
+ return true;
+#endif
+ }
+ /**
+ get aligned memory pointer
+ @param addr [in] address
+ @param alignedSize [in] power of two
+ @return addr aligned by alignedSize
+ */
+ static inline uint8 *getAlignedAddress(uint8 *addr, size_t alignedSize = 16)
+ {
+ return reinterpret_cast<uint8*>((reinterpret_cast<size_t>(addr) + alignedSize - 1) & ~(alignedSize - static_cast<size_t>(1)));
+ }
+};
+
+class Address : public Operand {
+public:
+ enum Mode {
+ M_ModRM,
+ M_64bitDisp,
+ M_rip,
+ M_ripAddr
+ };
+ Address(uint32 sizeBit, bool broadcast, const RegExp& e)
+ : Operand(0, MEM, sizeBit), e_(e), label_(0), mode_(M_ModRM), broadcast_(broadcast)
+ {
+ e_.verify();
+ }
+#ifdef XBYAK64
+ explicit Address(size_t disp)
+ : Operand(0, MEM, 64), e_(disp), label_(0), mode_(M_64bitDisp), broadcast_(false){ }
+ Address(uint32 sizeBit, bool broadcast, const RegRip& addr)
+ : Operand(0, MEM, sizeBit), e_(addr.disp_), label_(addr.label_), mode_(addr.isAddr_ ? M_ripAddr : M_rip), broadcast_(broadcast) { }
+#endif
+ RegExp getRegExp(bool optimize = true) const
+ {
+ return optimize ? e_.optimize() : e_;
+ }
+ Mode getMode() const { return mode_; }
+ bool is32bit() const { return e_.getBase().getBit() == 32 || e_.getIndex().getBit() == 32; }
+ bool isOnlyDisp() const { return !e_.getBase().getBit() && !e_.getIndex().getBit(); } // for mov eax
+ size_t getDisp() const { return e_.getDisp(); }
+ uint8 getRex() const
+ {
+ if (mode_ != M_ModRM) return 0;
+ return getRegExp().getRex();
+ }
+ bool is64bitDisp() const { return mode_ == M_64bitDisp; } // for moffset
+ bool isBroadcast() const { return broadcast_; }
+ const Label* getLabel() const { return label_; }
+ bool operator==(const Address& rhs) const
+ {
+ return getBit() == rhs.getBit() && e_ == rhs.e_ && label_ == rhs.label_ && mode_ == rhs.mode_ && broadcast_ == rhs.broadcast_;
+ }
+ bool operator!=(const Address& rhs) const { return !operator==(rhs); }
+ bool isVsib() const { return e_.isVsib(); }
+private:
+ RegExp e_;
+ const Label* label_;
+ Mode mode_;
+ bool broadcast_;
+};
+
+inline const Address& Operand::getAddress() const
+{
+ assert(isMEM());
+ return static_cast<const Address&>(*this);
+}
+
+inline bool Operand::operator==(const Operand& rhs) const
+{
+ if (isMEM() && rhs.isMEM()) return this->getAddress() == rhs.getAddress();
+ return isEqualIfNotInherited(rhs);
+}
+
+class AddressFrame {
+ void operator=(const AddressFrame&);
+ AddressFrame(const AddressFrame&);
+public:
+ const uint32 bit_;
+ const bool broadcast_;
+ explicit AddressFrame(uint32 bit, bool broadcast = false) : bit_(bit), broadcast_(broadcast) { }
+ Address operator[](const RegExp& e) const
+ {
+ return Address(bit_, broadcast_, e);
+ }
+ Address operator[](const void *disp) const
+ {
+ return Address(bit_, broadcast_, RegExp(reinterpret_cast<size_t>(disp)));
+ }
+#ifdef XBYAK64
+ Address operator[](uint64 disp) const { return Address(disp); }
+ Address operator[](const RegRip& addr) const { return Address(bit_, broadcast_, addr); }
+#endif
+};
+
+struct JmpLabel {
+ size_t endOfJmp; /* offset from top to the end address of jmp */
+ int jmpSize;
+ inner::LabelMode mode;
+ size_t disp; // disp for [rip + disp]
+ explicit JmpLabel(size_t endOfJmp = 0, int jmpSize = 0, inner::LabelMode mode = inner::LasIs, size_t disp = 0)
+ : endOfJmp(endOfJmp), jmpSize(jmpSize), mode(mode), disp(disp)
+ {
+ }
+};
+
+class LabelManager;
+
+class Label {
+ mutable LabelManager *mgr;
+ mutable int id;
+ friend class LabelManager;
+public:
+ Label() : mgr(0), id(0) {}
+ Label(const Label& rhs);
+ Label& operator=(const Label& rhs);
+ ~Label();
+ void clear() { mgr = 0; id = 0; }
+ int getId() const { return id; }
+ const uint8 *getAddress() const;
+
+ // backward compatibility
+ static inline std::string toStr(int num)
+ {
+ char buf[16];
+#if defined(_MSC_VER) && (_MSC_VER < 1900)
+ _snprintf_s
+#else
+ snprintf
+#endif
+ (buf, sizeof(buf), ".%08x", num);
+ return buf;
+ }
+};
+
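+// Tracks string labels ("@@", ".local", user-defined) and Label objects: stores the offset of
+// each defined label and patches pending jumps once their target becomes defined.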
+class LabelManager {
+ // for string label
+ struct SlabelVal {
+ size_t offset;
+ SlabelVal(size_t offset) : offset(offset) {}
+ };
+ typedef XBYAK_STD_UNORDERED_MAP<std::string, SlabelVal> SlabelDefList;
+ typedef XBYAK_STD_UNORDERED_MULTIMAP<std::string, const JmpLabel> SlabelUndefList;
+ struct SlabelState {
+ SlabelDefList defList;
+ SlabelUndefList undefList;
+ };
+ typedef std::list<SlabelState> StateList;
+ // for Label class
+ struct ClabelVal {
+ ClabelVal(size_t offset = 0) : offset(offset), refCount(1) {}
+ size_t offset;
+ int refCount;
+ };
+ typedef XBYAK_STD_UNORDERED_MAP<int, ClabelVal> ClabelDefList;
+ typedef XBYAK_STD_UNORDERED_MULTIMAP<int, const JmpLabel> ClabelUndefList;
+ typedef XBYAK_STD_UNORDERED_SET<Label*> LabelPtrList;
+
+ CodeArray *base_;
+ // global : stateList_.front(), local : stateList_.back()
+ StateList stateList_;
+ mutable int labelId_;
+ ClabelDefList clabelDefList_;
+ ClabelUndefList clabelUndefList_;
+ LabelPtrList labelPtrList_;
+
+ int getId(const Label& label) const
+ {
+ if (label.id == 0) label.id = labelId_++;
+ return label.id;
+ }
+ template<class DefList, class UndefList, class T>
+ void define_inner(DefList& defList, UndefList& undefList, const T& labelId, size_t addrOffset)
+ {
+ // add label
+ typename DefList::value_type item(labelId, addrOffset);
+ std::pair<typename DefList::iterator, bool> ret = defList.insert(item);
+ if (!ret.second) throw Error(ERR_LABEL_IS_REDEFINED);
+ // search undefined label
+ for (;;) {
+ typename UndefList::iterator itr = undefList.find(labelId);
+ if (itr == undefList.end()) break;
+ const JmpLabel *jmp = &itr->second;
+ const size_t offset = jmp->endOfJmp - jmp->jmpSize;
+ size_t disp;
+ if (jmp->mode == inner::LaddTop) {
+ disp = addrOffset;
+ } else if (jmp->mode == inner::Labs) {
+ disp = size_t(base_->getCurr());
+ } else {
+ disp = addrOffset - jmp->endOfJmp + jmp->disp;
+#ifdef XBYAK64
+ if (jmp->jmpSize <= 4 && !inner::IsInInt32(disp)) throw Error(ERR_OFFSET_IS_TOO_BIG);
+#endif
+ if (jmp->jmpSize == 1 && !inner::IsInDisp8((uint32)disp)) throw Error(ERR_LABEL_IS_TOO_FAR);
+ }
+ if (base_->isAutoGrow()) {
+ base_->save(offset, disp, jmp->jmpSize, jmp->mode);
+ } else {
+ base_->rewrite(offset, disp, jmp->jmpSize);
+ }
+ undefList.erase(itr);
+ }
+ }
+ template<class DefList, class T>
+ bool getOffset_inner(const DefList& defList, size_t *offset, const T& label) const
+ {
+ typename DefList::const_iterator i = defList.find(label);
+ if (i == defList.end()) return false;
+ *offset = i->second.offset;
+ return true;
+ }
+ friend class Label;
+ void incRefCount(int id, Label *label)
+ {
+ clabelDefList_[id].refCount++;
+ labelPtrList_.insert(label);
+ }
+ void decRefCount(int id, Label *label)
+ {
+ labelPtrList_.erase(label);
+ ClabelDefList::iterator i = clabelDefList_.find(id);
+ if (i == clabelDefList_.end()) return;
+ if (i->second.refCount == 1) {
+ clabelDefList_.erase(id);
+ } else {
+ --i->second.refCount;
+ }
+ }
+ template<class T>
+ bool hasUndefinedLabel_inner(const T& list) const
+ {
+#ifndef NDEBUG
+ for (typename T::const_iterator i = list.begin(); i != list.end(); ++i) {
+ std::cerr << "undefined label:" << i->first << std::endl;
+ }
+#endif
+ return !list.empty();
+ }
+ // detach all labels linked to LabelManager
+ void resetLabelPtrList()
+ {
+ for (LabelPtrList::iterator i = labelPtrList_.begin(), ie = labelPtrList_.end(); i != ie; ++i) {
+ (*i)->clear();
+ }
+ labelPtrList_.clear();
+ }
+public:
+ LabelManager()
+ {
+ reset();
+ }
+ ~LabelManager()
+ {
+ resetLabelPtrList();
+ }
+ void reset()
+ {
+ base_ = 0;
+ labelId_ = 1;
+ stateList_.clear();
+ stateList_.push_back(SlabelState());
+ stateList_.push_back(SlabelState());
+ clabelDefList_.clear();
+ clabelUndefList_.clear();
+ resetLabelPtrList();
+ }
+ void enterLocal()
+ {
+ stateList_.push_back(SlabelState());
+ }
+ void leaveLocal()
+ {
+ if (stateList_.size() <= 2) throw Error(ERR_UNDER_LOCAL_LABEL);
+ if (hasUndefinedLabel_inner(stateList_.back().undefList)) throw Error(ERR_LABEL_IS_NOT_FOUND);
+ stateList_.pop_back();
+ }
+ void set(CodeArray *base) { base_ = base; }
+ void defineSlabel(std::string label)
+ {
+ if (label == "@b" || label == "@f") throw Error(ERR_BAD_LABEL_STR);
+ if (label == "@@") {
+ SlabelDefList& defList = stateList_.front().defList;
+ SlabelDefList::iterator i = defList.find("@f");
+ if (i != defList.end()) {
+ defList.erase(i);
+ label = "@b";
+ } else {
+ i = defList.find("@b");
+ if (i != defList.end()) {
+ defList.erase(i);
+ }
+ label = "@f";
+ }
+ }
+ SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front();
+ define_inner(st.defList, st.undefList, label, base_->getSize());
+ }
+ void defineClabel(Label& label)
+ {
+ define_inner(clabelDefList_, clabelUndefList_, getId(label), base_->getSize());
+ label.mgr = this;
+ labelPtrList_.insert(&label);
+ }
+ void assign(Label& dst, const Label& src)
+ {
+ ClabelDefList::const_iterator i = clabelDefList_.find(src.id);
+ if (i == clabelDefList_.end()) throw Error(ERR_LABEL_ISNOT_SET_BY_L);
+ define_inner(clabelDefList_, clabelUndefList_, dst.id, i->second.offset);
+ dst.mgr = this;
+ labelPtrList_.insert(&dst);
+ }
+ bool getOffset(size_t *offset, std::string& label) const
+ {
+ const SlabelDefList& defList = stateList_.front().defList;
+ if (label == "@b") {
+ if (defList.find("@f") != defList.end()) {
+ label = "@f";
+ } else if (defList.find("@b") == defList.end()) {
+ throw Error(ERR_LABEL_IS_NOT_FOUND);
+ }
+ } else if (label == "@f") {
+ if (defList.find("@f") != defList.end()) {
+ label = "@b";
+ }
+ }
+ const SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front();
+ return getOffset_inner(st.defList, offset, label);
+ }
+ bool getOffset(size_t *offset, const Label& label) const
+ {
+ return getOffset_inner(clabelDefList_, offset, getId(label));
+ }
+ void addUndefinedLabel(const std::string& label, const JmpLabel& jmp)
+ {
+ SlabelState& st = *label.c_str() == '.' ? stateList_.back() : stateList_.front();
+ st.undefList.insert(SlabelUndefList::value_type(label, jmp));
+ }
+ void addUndefinedLabel(const Label& label, const JmpLabel& jmp)
+ {
+ clabelUndefList_.insert(ClabelUndefList::value_type(label.id, jmp));
+ }
+ bool hasUndefSlabel() const
+ {
+ for (StateList::const_iterator i = stateList_.begin(), ie = stateList_.end(); i != ie; ++i) {
+ if (hasUndefinedLabel_inner(i->undefList)) return true;
+ }
+ return false;
+ }
+ bool hasUndefClabel() const { return hasUndefinedLabel_inner(clabelUndefList_); }
+ const uint8 *getCode() const { return base_->getCode(); }
+ bool isReady() const { return !base_->isAutoGrow() || base_->isCalledCalcJmpAddress(); }
+};
+
+inline Label::Label(const Label& rhs)
+{
+ id = rhs.id;
+ mgr = rhs.mgr;
+ if (mgr) mgr->incRefCount(id, this);
+}
+inline Label& Label::operator=(const Label& rhs)
+{
+ if (id) throw Error(ERR_LABEL_IS_ALREADY_SET_BY_L);
+ id = rhs.id;
+ mgr = rhs.mgr;
+ if (mgr) mgr->incRefCount(id, this);
+ return *this;
+}
+inline Label::~Label()
+{
+ if (id && mgr) mgr->decRefCount(id, this);
+}
+inline const uint8* Label::getAddress() const
+{
+ if (mgr == 0 || !mgr->isReady()) return 0;
+ size_t offset;
+ if (!mgr->getOffset(&offset, *this)) return 0;
+ return mgr->getCode() + offset;
+}
+
+class CodeGenerator : public CodeArray {
+public:
+ enum LabelType {
+ T_SHORT,
+ T_NEAR,
+ T_AUTO // T_SHORT if possible
+ };
+private:
+ CodeGenerator operator=(const CodeGenerator&); // don't call
+#ifdef XBYAK64
+ enum { i32e = 32 | 64, BIT = 64 };
+ static const size_t dummyAddr = (size_t(0x11223344) << 32) | 55667788;
+ typedef Reg64 NativeReg;
+#else
+ enum { i32e = 32, BIT = 32 };
+ static const size_t dummyAddr = 0x12345678;
+ typedef Reg32 NativeReg;
+#endif
+ // (XMM, XMM|MEM)
+ static inline bool isXMM_XMMorMEM(const Operand& op1, const Operand& op2)
+ {
+ return op1.isXMM() && (op2.isXMM() || op2.isMEM());
+ }
+ // (MMX, MMX|MEM) or (XMM, XMM|MEM)
+ static inline bool isXMMorMMX_MEM(const Operand& op1, const Operand& op2)
+ {
+ return (op1.isMMX() && (op2.isMMX() || op2.isMEM())) || isXMM_XMMorMEM(op1, op2);
+ }
+ // (XMM, MMX|MEM)
+ static inline bool isXMM_MMXorMEM(const Operand& op1, const Operand& op2)
+ {
+ return op1.isXMM() && (op2.isMMX() || op2.isMEM());
+ }
+ // (MMX, XMM|MEM)
+ static inline bool isMMX_XMMorMEM(const Operand& op1, const Operand& op2)
+ {
+ return op1.isMMX() && (op2.isXMM() || op2.isMEM());
+ }
+ // (XMM, REG32|MEM)
+ static inline bool isXMM_REG32orMEM(const Operand& op1, const Operand& op2)
+ {
+ return op1.isXMM() && (op2.isREG(i32e) || op2.isMEM());
+ }
+ // (REG32, XMM|MEM)
+ static inline bool isREG32_XMMorMEM(const Operand& op1, const Operand& op2)
+ {
+ return op1.isREG(i32e) && (op2.isXMM() || op2.isMEM());
+ }
+ // (REG32, REG32|MEM)
+ static inline bool isREG32_REG32orMEM(const Operand& op1, const Operand& op2)
+ {
+ return op1.isREG(i32e) && ((op2.isREG(i32e) && op1.getBit() == op2.getBit()) || op2.isMEM());
+ }
+ void rex(const Operand& op1, const Operand& op2 = Operand())
+ {
+ uint8 rex = 0;
+ const Operand *p1 = &op1, *p2 = &op2;
+ if (p1->isMEM()) std::swap(p1, p2);
+ if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION);
+ if (p2->isMEM()) {
+ const Address& addr = p2->getAddress();
+ if (BIT == 64 && addr.is32bit()) db(0x67);
+ rex = addr.getRex() | p1->getReg().getRex();
+ } else {
+ // ModRM(reg, base);
+ rex = op2.getReg().getRex(op1.getReg());
+ }
+ // except movsx(16bit, 32/64bit)
+ if ((op1.isBit(16) && !op2.isBit(i32e)) || (op2.isBit(16) && !op1.isBit(i32e))) db(0x66);
+ if (rex) db(rex);
+ }
+ enum AVXtype {
+ // low 3 bit
+ T_N1 = 1,
+ T_N2 = 2,
+ T_N4 = 3,
+ T_N8 = 4,
+ T_N16 = 5,
+ T_N32 = 6,
+ T_NX_MASK = 7,
+ //
+ T_N_VL = 1 << 3, // N * (1, 2, 4) for VL
+ T_DUP = 1 << 4, // N = (8, 32, 64)
+ T_66 = 1 << 5,
+ T_F3 = 1 << 6,
+ T_F2 = 1 << 7,
+ T_0F = 1 << 8,
+ T_0F38 = 1 << 9,
+ T_0F3A = 1 << 10,
+ T_L0 = 1 << 11,
+ T_L1 = 1 << 12,
+ T_W0 = 1 << 13,
+ T_W1 = 1 << 14,
+ T_EW0 = 1 << 15,
+ T_EW1 = 1 << 16,
+ T_YMM = 1 << 17, // support YMM, ZMM
+ T_EVEX = 1 << 18,
+ T_ER_X = 1 << 19, // xmm{er}
+ T_ER_Y = 1 << 20, // ymm{er}
+ T_ER_Z = 1 << 21, // zmm{er}
+ T_SAE_X = 1 << 22, // xmm{sae}
+ T_SAE_Y = 1 << 23, // ymm{sae}
+ T_SAE_Z = 1 << 24, // zmm{sae}
+ T_MUST_EVEX = 1 << 25, // contains T_EVEX
+ T_B32 = 1 << 26, // m32bcst
+ T_B64 = 1 << 27, // m64bcst
+ T_M_K = 1 << 28, // mem{k}
+ T_VSIB = 1 << 29,
+ T_MEM_EVEX = 1 << 30, // use evex if mem
+ T_XXX
+ };
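+ // emit a VEX prefix (2-byte 0xC5 form when possible, otherwise 3-byte 0xC4) followed by the opcode byte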
+ void vex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false)
+ {
+ int w = (type & T_W1) ? 1 : 0;
+ bool is256 = (type & T_L1) ? true : (type & T_L0) ? false : reg.isYMM();
+ bool r = reg.isExtIdx();
+ bool b = base.isExtIdx();
+ int idx = v ? v->getIdx() : 0;
+ if ((idx | reg.getIdx() | base.getIdx()) >= 16) throw Error(ERR_BAD_COMBINATION);
+ uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0;
+ uint32 vvvv = (((~idx) & 15) << 3) | (is256 ? 4 : 0) | pp;
+ if (!b && !x && !w && (type & T_0F)) {
+ db(0xC5); db((r ? 0 : 0x80) | vvvv);
+ } else {
+ uint32 mmmm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0;
+ db(0xC4); db((r ? 0 : 0x80) | (x ? 0 : 0x40) | (b ? 0 : 0x20) | mmmm); db((w << 7) | vvvv);
+ }
+ db(code);
+ }
+ void verifySAE(const Reg& r, int type) const
+ {
+ if (((type & T_SAE_X) && r.isXMM()) || ((type & T_SAE_Y) && r.isYMM()) || ((type & T_SAE_Z) && r.isZMM())) return;
+ throw Error(ERR_SAE_IS_INVALID);
+ }
+ void verifyER(const Reg& r, int type) const
+ {
+ if (((type & T_ER_X) && r.isXMM()) || ((type & T_ER_Y) && r.isYMM()) || ((type & T_ER_Z) && r.isZMM())) return;
+ throw Error(ERR_ER_IS_INVALID);
+ }
+ // return err if (a, b, c) contains two or more distinct non-zero values
+ int verifyDuplicate(int a, int b, int c, int err)
+ {
+ int v = a | b | c;
+ if ((a > 0 && a != v) + (b > 0 && b != v) + (c > 0 && c != v) > 0) return Error(err);
+ return v;
+ }
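+ // emit the 4-byte EVEX prefix (0x62 plus three payload bytes) and the opcode byte;
+ // returns the disp8*N scaling factor used for compressed displacements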
+ int evex(const Reg& reg, const Reg& base, const Operand *v, int type, int code, bool x = false, bool b = false, int aaa = 0, uint32 VL = 0, bool Hi16Vidx = false)
+ {
+ if (!(type & (T_EVEX | T_MUST_EVEX))) throw Error(ERR_EVEX_IS_INVALID);
+ int w = (type & T_EW1) ? 1 : 0;
+ uint32 mm = (type & T_0F) ? 1 : (type & T_0F38) ? 2 : (type & T_0F3A) ? 3 : 0;
+ uint32 pp = (type & T_66) ? 1 : (type & T_F3) ? 2 : (type & T_F2) ? 3 : 0;
+
+ int idx = v ? v->getIdx() : 0;
+ uint32 vvvv = ~idx;
+
+ bool R = !reg.isExtIdx();
+ bool X = x ? false : !base.isExtIdx2();
+ bool B = !base.isExtIdx();
+ bool Rp = !reg.isExtIdx2();
+ int LL;
+ int rounding = verifyDuplicate(reg.getRounding(), base.getRounding(), v ? v->getRounding() : 0, ERR_ROUNDING_IS_ALREADY_SET);
+ int disp8N = 1;
+ if (rounding) {
+ if (rounding == EvexModifierRounding::T_SAE) {
+ verifySAE(base, type); LL = 0;
+ } else {
+ verifyER(base, type); LL = rounding - 1;
+ }
+ b = true;
+ } else {
+ if (v) VL = (std::max)(VL, v->getBit());
+ VL = (std::max)((std::max)(reg.getBit(), base.getBit()), VL);
+ LL = (VL == 512) ? 2 : (VL == 256) ? 1 : 0;
+ if (b) {
+ disp8N = (type & T_B32) ? 4 : 8;
+ } else if (type & T_DUP) {
+ disp8N = VL == 128 ? 8 : VL == 256 ? 32 : 64;
+ } else {
+ if ((type & (T_NX_MASK | T_N_VL)) == 0) {
+ type |= T_N16 | T_N_VL; // default
+ }
+ int low = type & T_NX_MASK;
+ if (low > 0) {
+ disp8N = 1 << (low - 1);
+ if (type & T_N_VL) disp8N *= (VL == 512 ? 4 : VL == 256 ? 2 : 1);
+ }
+ }
+ }
+ bool Vp = !((v ? v->isExtIdx2() : 0) | Hi16Vidx);
+ bool z = reg.hasZero() || base.hasZero() || (v ? v->hasZero() : false);
+ if (aaa == 0) aaa = verifyDuplicate(base.getOpmaskIdx(), reg.getOpmaskIdx(), (v ? v->getOpmaskIdx() : 0), ERR_OPMASK_IS_ALREADY_SET);
+ db(0x62);
+ db((R ? 0x80 : 0) | (X ? 0x40 : 0) | (B ? 0x20 : 0) | (Rp ? 0x10 : 0) | (mm & 3));
+ db((w == 1 ? 0x80 : 0) | ((vvvv & 15) << 3) | 4 | (pp & 3));
+ db((z ? 0x80 : 0) | ((LL & 3) << 5) | (b ? 0x10 : 0) | (Vp ? 8 : 0) | (aaa & 7));
+ db(code);
+ return disp8N;
+ }
+ void setModRM(int mod, int r1, int r2)
+ {
+ db(static_cast<uint8>((mod << 6) | ((r1 & 7) << 3) | (r2 & 7)));
+ }
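+ // encode the ModR/M byte, an SIB byte if required, and the displacement for the memory expression e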
+ void setSIB(const RegExp& e, int reg, int disp8N = 0)
+ {
+ size_t disp64 = e.getDisp();
+#ifdef XBYAK64
+ size_t high = disp64 >> 32;
+ if (high != 0 && high != 0xFFFFFFFF) throw Error(ERR_OFFSET_IS_TOO_BIG);
+#endif
+ uint32 disp = static_cast<uint32>(disp64);
+ const Reg& base = e.getBase();
+ const Reg& index = e.getIndex();
+ const int baseIdx = base.getIdx();
+ const int baseBit = base.getBit();
+ const int indexBit = index.getBit();
+ enum {
+ mod00 = 0, mod01 = 1, mod10 = 2
+ };
+ int mod = mod10; // disp32
+ if (!baseBit || ((baseIdx & 7) != Operand::EBP && disp == 0)) {
+ mod = mod00;
+ } else {
+ if (disp8N == 0) {
+ if (inner::IsInDisp8(disp)) {
+ mod = mod01;
+ }
+ } else {
+ // disp must be cast to signed
+ uint32 t = static_cast<uint32>(static_cast<int>(disp) / disp8N);
+ if ((disp % disp8N) == 0 && inner::IsInDisp8(t)) {
+ disp = t;
+ mod = mod01;
+ }
+ }
+ }
+ const int newBaseIdx = baseBit ? (baseIdx & 7) : Operand::EBP;
+ /* ModR/M = [2:3:3] = [Mod:reg/code:R/M] */
+ bool hasSIB = indexBit || (baseIdx & 7) == Operand::ESP;
+#ifdef XBYAK64
+ if (!baseBit && !indexBit) hasSIB = true;
+#endif
+ if (hasSIB) {
+ setModRM(mod, reg, Operand::ESP);
+ /* SIB = [2:3:3] = [SS:index:base(=rm)] */
+ const int idx = indexBit ? (index.getIdx() & 7) : Operand::ESP;
+ const int scale = e.getScale();
+ const int SS = (scale == 8) ? 3 : (scale == 4) ? 2 : (scale == 2) ? 1 : 0;
+ setModRM(SS, idx, newBaseIdx);
+ } else {
+ setModRM(mod, reg, newBaseIdx);
+ }
+ if (mod == mod01) {
+ db(disp);
+ } else if (mod == mod10 || (mod == mod00 && !baseBit)) {
+ dd(disp);
+ }
+ }
+ LabelManager labelMgr_;
+ bool isInDisp16(uint32 x) const { return 0xFFFF8000 <= x || x <= 0x7FFF; }
+ void opModR(const Reg& reg1, const Reg& reg2, int code0, int code1 = NONE, int code2 = NONE)
+ {
+ rex(reg2, reg1);
+ db(code0 | (reg1.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2);
+ setModRM(3, reg1.getIdx(), reg2.getIdx());
+ }
+ void opModM(const Address& addr, const Reg& reg, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0)
+ {
+ if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP);
+ rex(addr, reg);
+ db(code0 | (reg.isBit(8) ? 0 : 1)); if (code1 != NONE) db(code1); if (code2 != NONE) db(code2);
+ opAddr(addr, reg.getIdx(), immSize);
+ }
+ void opMIB(const Address& addr, const Reg& reg, int code0, int code1)
+ {
+ if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP);
+ if (addr.getMode() != Address::M_ModRM) throw Error(ERR_INVALID_MIB_ADDRESS);
+ if (BIT == 64 && addr.is32bit()) db(0x67);
+ const RegExp& regExp = addr.getRegExp(false);
+ uint8 rex = regExp.getRex();
+ if (rex) db(rex);
+ db(code0); db(code1);
+ setSIB(regExp, reg.getIdx());
+ }
+ void makeJmp(uint32 disp, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref)
+ {
+ const int shortJmpSize = 2;
+ const int longHeaderSize = longPref ? 2 : 1;
+ const int longJmpSize = longHeaderSize + 4;
+ if (type != T_NEAR && inner::IsInDisp8(disp - shortJmpSize)) {
+ db(shortCode); db(disp - shortJmpSize);
+ } else {
+ if (type == T_SHORT) throw Error(ERR_LABEL_IS_TOO_FAR);
+ if (longPref) db(longPref);
+ db(longCode); dd(disp - longJmpSize);
+ }
+ }
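+ // emit a jump to a label; if the label is not yet defined, emit a placeholder and record it for later patching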
+ template<class T>
+ void opJmp(T& label, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref)
+ {
+ if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory(); /* avoid splitting code of jmp */
+ size_t offset = 0;
+ if (labelMgr_.getOffset(&offset, label)) { /* label exists */
+ makeJmp(inner::VerifyInInt32(offset - size_), type, shortCode, longCode, longPref);
+ } else {
+ int jmpSize = 0;
+ if (type == T_NEAR) {
+ jmpSize = 4;
+ if (longPref) db(longPref);
+ db(longCode); dd(0);
+ } else {
+ jmpSize = 1;
+ db(shortCode); db(0);
+ }
+ JmpLabel jmp(size_, jmpSize, inner::LasIs);
+ labelMgr_.addUndefinedLabel(label, jmp);
+ }
+ }
+ void opJmpAbs(const void *addr, LabelType type, uint8 shortCode, uint8 longCode, uint8 longPref = 0)
+ {
+ if (isAutoGrow()) {
+ if (type != T_NEAR) throw Error(ERR_ONLY_T_NEAR_IS_SUPPORTED_IN_AUTO_GROW);
+ if (size_ + 16 >= maxSize_) growMemory();
+ if (longPref) db(longPref);
+ db(longCode);
+ dd(0);
+ save(size_ - 4, size_t(addr) - size_, 4, inner::Labs);
+ } else {
+ makeJmp(inner::VerifyInInt32(reinterpret_cast<const uint8*>(addr) - getCurr()), type, shortCode, longCode, longPref);
+ }
+
+ }
+ // reg is reg field of ModRM
+ // immSize is the size for immediate value
+ // disp8N = 0(normal), disp8N = 1(force disp32), disp8N = {2, 4, 8} ; compressed displacement
+ void opAddr(const Address &addr, int reg, int immSize = 0, int disp8N = 0, bool permitVisb = false)
+ {
+ if (!permitVisb && addr.isVsib()) throw Error(ERR_BAD_VSIB_ADDRESSING);
+ if (addr.getMode() == Address::M_ModRM) {
+ setSIB(addr.getRegExp(), reg, disp8N);
+ } else if (addr.getMode() == Address::M_rip || addr.getMode() == Address::M_ripAddr) {
+ setModRM(0, reg, 5);
+ if (addr.getLabel()) { // [rip + Label]
+ putL_inner(*addr.getLabel(), true, addr.getDisp() - immSize);
+ } else {
+ size_t disp = addr.getDisp();
+ if (addr.getMode() == Address::M_ripAddr) {
+ if (isAutoGrow()) throw Error(ERR_INVALID_RIP_IN_AUTO_GROW);
+ disp -= (size_t)getCurr() + 4 + immSize;
+ }
+ dd(inner::VerifyInInt32(disp));
+ }
+ }
+ }
+ /* preCode is for SSSE3/SSE4 */
+ void opGen(const Operand& reg, const Operand& op, int code, int pref, bool isValid(const Operand&, const Operand&), int imm8 = NONE, int preCode = NONE)
+ {
+ if (isValid && !isValid(reg, op)) throw Error(ERR_BAD_COMBINATION);
+ if (pref != NONE) db(pref);
+ if (op.isMEM()) {
+ opModM(op.getAddress(), reg.getReg(), 0x0F, preCode, code, (imm8 != NONE) ? 1 : 0);
+ } else {
+ opModR(reg.getReg(), op.getReg(), 0x0F, preCode, code);
+ }
+ if (imm8 != NONE) db(imm8);
+ }
+ void opMMX_IMM(const Mmx& mmx, int imm8, int code, int ext)
+ {
+ if (mmx.isXMM()) db(0x66);
+ opModR(Reg32(ext), mmx, 0x0F, code);
+ db(imm8);
+ }
+ void opMMX(const Mmx& mmx, const Operand& op, int code, int pref = 0x66, int imm8 = NONE, int preCode = NONE)
+ {
+ opGen(mmx, op, code, mmx.isXMM() ? pref : NONE, isXMMorMMX_MEM, imm8, preCode);
+ }
+ void opMovXMM(const Operand& op1, const Operand& op2, int code, int pref)
+ {
+ if (pref != NONE) db(pref);
+ if (op1.isXMM() && op2.isMEM()) {
+ opModM(op2.getAddress(), op1.getReg(), 0x0F, code);
+ } else if (op1.isMEM() && op2.isXMM()) {
+ opModM(op1.getAddress(), op2.getReg(), 0x0F, code | 1);
+ } else {
+ throw Error(ERR_BAD_COMBINATION);
+ }
+ }
+ void opExt(const Operand& op, const Mmx& mmx, int code, int imm, bool hasMMX2 = false)
+ {
+ if (hasMMX2 && op.isREG(i32e)) { /* pextrw is special */
+ if (mmx.isXMM()) db(0x66);
+ opModR(op.getReg(), mmx, 0x0F, 0xC5); db(imm);
+ } else {
+ opGen(mmx, op, code, 0x66, isXMM_REG32orMEM, imm, 0x3A);
+ }
+ }
+ void opR_ModM(const Operand& op, int bit, int ext, int code0, int code1 = NONE, int code2 = NONE, bool disableRex = false, int immSize = 0)
+ {
+ int opBit = op.getBit();
+ if (disableRex && opBit == 64) opBit = 32;
+ if (op.isREG(bit)) {
+ opModR(Reg(ext, Operand::REG, opBit), op.getReg().changeBit(opBit), code0, code1, code2);
+ } else if (op.isMEM()) {
+ opModM(op.getAddress(), Reg(ext, Operand::REG, opBit), code0, code1, code2, immSize);
+ } else {
+ throw Error(ERR_BAD_COMBINATION);
+ }
+ }
+ void opShift(const Operand& op, int imm, int ext)
+ {
+ verifyMemHasSize(op);
+ opR_ModM(op, 0, ext, (0xC0 | ((imm == 1 ? 1 : 0) << 4)), NONE, NONE, false, (imm != 1) ? 1 : 0);
+ if (imm != 1) db(imm);
+ }
+ void opShift(const Operand& op, const Reg8& _cl, int ext)
+ {
+ if (_cl.getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION);
+ opR_ModM(op, 0, ext, 0xD2);
+ }
+ void opModRM(const Operand& op1, const Operand& op2, bool condR, bool condM, int code0, int code1 = NONE, int code2 = NONE, int immSize = 0)
+ {
+ if (condR) {
+ opModR(op1.getReg(), op2.getReg(), code0, code1, code2);
+ } else if (condM) {
+ opModM(op2.getAddress(), op1.getReg(), code0, code1, code2, immSize);
+ } else {
+ throw Error(ERR_BAD_COMBINATION);
+ }
+ }
+ void opShxd(const Operand& op, const Reg& reg, uint8 imm, int code, const Reg8 *_cl = 0)
+ {
+ if (_cl && _cl->getIdx() != Operand::CL) throw Error(ERR_BAD_COMBINATION);
+ opModRM(reg, op, (op.isREG(16 | i32e) && op.getBit() == reg.getBit()), op.isMEM() && (reg.isREG(16 | i32e)), 0x0F, code | (_cl ? 1 : 0), NONE, _cl ? 0 : 1);
+ if (!_cl) db(imm);
+ }
+ // (REG, REG|MEM), (MEM, REG)
+ void opRM_RM(const Operand& op1, const Operand& op2, int code)
+ {
+ if (op1.isREG() && op2.isMEM()) {
+ opModM(op2.getAddress(), op1.getReg(), code | 2);
+ } else {
+ opModRM(op2, op1, op1.isREG() && op1.getKind() == op2.getKind(), op1.isMEM() && op2.isREG(), code);
+ }
+ }
+ // (REG|MEM, IMM)
+ void opRM_I(const Operand& op, uint32 imm, int code, int ext)
+ {
+ verifyMemHasSize(op);
+ uint32 immBit = inner::IsInDisp8(imm) ? 8 : isInDisp16(imm) ? 16 : 32;
+ if (op.isBit(8)) immBit = 8;
+ if (op.getBit() < immBit) throw Error(ERR_IMM_IS_TOO_BIG);
+ if (op.isBit(32|64) && immBit == 16) immBit = 32; /* don't use MEM16 if 32/64bit mode */
+ if (op.isREG() && op.getIdx() == 0 && (op.getBit() == immBit || (op.isBit(64) && immBit == 32))) { // rax, eax, ax, al
+ rex(op);
+ db(code | 4 | (immBit == 8 ? 0 : 1));
+ } else {
+ int tmp = immBit < (std::min)(op.getBit(), 32U) ? 2 : 0;
+ opR_ModM(op, 0, ext, 0x80 | tmp, NONE, NONE, false, immBit / 8);
+ }
+ db(imm, immBit / 8);
+ }
+ void opIncDec(const Operand& op, int code, int ext)
+ {
+ verifyMemHasSize(op);
+#ifndef XBYAK64
+ if (op.isREG() && !op.isBit(8)) {
+ rex(op); db(code | op.getIdx());
+ return;
+ }
+#endif
+ code = 0xFE;
+ if (op.isREG()) {
+ opModR(Reg(ext, Operand::REG, op.getBit()), op.getReg(), code);
+ } else {
+ opModM(op.getAddress(), Reg(ext, Operand::REG, op.getBit()), code);
+ }
+ }
+ void opPushPop(const Operand& op, int code, int ext, int alt)
+ {
+ int bit = op.getBit();
+ if (bit == 16 || bit == BIT) {
+ if (bit == 16) db(0x66);
+ if (op.isREG()) {
+ if (op.getReg().getIdx() >= 8) db(0x41);
+ db(alt | (op.getIdx() & 7));
+ return;
+ }
+ if (op.isMEM()) {
+ opModM(op.getAddress(), Reg(ext, Operand::REG, 32), code);
+ return;
+ }
+ }
+ throw Error(ERR_BAD_COMBINATION);
+ }
+ void verifyMemHasSize(const Operand& op) const
+ {
+ if (op.isMEM() && op.getBit() == 0) throw Error(ERR_MEM_SIZE_IS_NOT_SPECIFIED);
+ }
+ /*
+ mov(r, imm) = db(imm, mov_imm(r, imm))
+ */
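+ // emit the prefix/opcode bytes for mov(reg, imm) and return the byte count of the immediate that must follow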
+ int mov_imm(const Reg& reg, size_t imm)
+ {
+ int bit = reg.getBit();
+ const int idx = reg.getIdx();
+ int code = 0xB0 | ((bit == 8 ? 0 : 1) << 3);
+ if (bit == 64 && (imm & ~size_t(0xffffffffu)) == 0) {
+ rex(Reg32(idx));
+ bit = 32;
+ } else {
+ rex(reg);
+ if (bit == 64 && inner::IsInInt32(imm)) {
+ db(0xC7);
+ code = 0xC0;
+ bit = 32;
+ }
+ }
+ db(code | (idx & 7));
+ return bit / 8;
+ }
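+ // emit a label's absolute address (or a 32-bit relative offset when relative = true);
+ // if the label is not yet defined, emit zeros and record the position for later patching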
+ template<class T>
+ void putL_inner(T& label, bool relative = false, size_t disp = 0)
+ {
+ const int jmpSize = relative ? 4 : (int)sizeof(size_t);
+ if (isAutoGrow() && size_ + 16 >= maxSize_) growMemory();
+ size_t offset = 0;
+ if (labelMgr_.getOffset(&offset, label)) {
+ if (relative) {
+ db(inner::VerifyInInt32(offset + disp - size_ - jmpSize), jmpSize);
+ } else if (isAutoGrow()) {
+ db(uint64(0), jmpSize);
+ save(size_ - jmpSize, offset, jmpSize, inner::LaddTop);
+ } else {
+ db(size_t(top_) + offset, jmpSize);
+ }
+ return;
+ }
+ db(uint64(0), jmpSize);
+ JmpLabel jmp(size_, jmpSize, (relative ? inner::LasIs : isAutoGrow() ? inner::LaddTop : inner::Labs), disp);
+ labelMgr_.addUndefinedLabel(label, jmp);
+ }
+ void opMovxx(const Reg& reg, const Operand& op, uint8 code)
+ {
+ if (op.isBit(32)) throw Error(ERR_BAD_COMBINATION);
+ int w = op.isBit(16);
+#ifdef XBYAK64
+ if (op.isHigh8bit()) throw Error(ERR_BAD_COMBINATION);
+#endif
+ bool cond = reg.isREG() && (reg.getBit() > op.getBit());
+ opModRM(reg, op, cond && op.isREG(), cond && op.isMEM(), 0x0F, code | w);
+ }
+ void opFpuMem(const Address& addr, uint8 m16, uint8 m32, uint8 m64, uint8 ext, uint8 m64ext)
+ {
+ if (addr.is64bitDisp()) throw Error(ERR_CANT_USE_64BIT_DISP);
+ uint8 code = addr.isBit(16) ? m16 : addr.isBit(32) ? m32 : addr.isBit(64) ? m64 : 0;
+ if (!code) throw Error(ERR_BAD_MEM_SIZE);
+ if (m64ext && addr.isBit(64)) ext = m64ext;
+
+ rex(addr, st0);
+ db(code);
+ opAddr(addr, ext);
+ }
+ // use code1 if reg1 == st0
+ // use code2 if reg1 != st0 && reg2 == st0
+ void opFpuFpu(const Fpu& reg1, const Fpu& reg2, uint32 code1, uint32 code2)
+ {
+ uint32 code = reg1.getIdx() == 0 ? code1 : reg2.getIdx() == 0 ? code2 : 0;
+ if (!code) throw Error(ERR_BAD_ST_COMBINATION);
+ db(uint8(code >> 8));
+ db(uint8(code | (reg1.getIdx() | reg2.getIdx())));
+ }
+ void opFpu(const Fpu& reg, uint8 code1, uint8 code2)
+ {
+ db(code1); db(code2 | reg.getIdx());
+ }
+ void opVex(const Reg& r, const Operand *p1, const Operand& op2, int type, int code, int imm8 = NONE)
+ {
+ if (op2.isMEM()) {
+ const Address& addr = op2.getAddress();
+ const RegExp& regExp = addr.getRegExp();
+ const Reg& base = regExp.getBase();
+ const Reg& index = regExp.getIndex();
+ if (BIT == 64 && addr.is32bit()) db(0x67);
+ int disp8N = 0;
+ bool x = index.isExtIdx();
+ if ((type & (T_MUST_EVEX|T_MEM_EVEX)) || r.hasEvex() || (p1 && p1->hasEvex()) || addr.isBroadcast() || addr.getOpmaskIdx()) {
+ int aaa = addr.getOpmaskIdx();
+ if (aaa && !(type & T_M_K)) throw Error(ERR_INVALID_OPMASK_WITH_MEMORY);
+ bool b = false;
+ if (addr.isBroadcast()) {
+ if (!(type & (T_B32 | T_B64))) throw Error(ERR_INVALID_BROADCAST);
+ b = true;
+ }
+ int VL = regExp.isVsib() ? index.getBit() : 0;
+ disp8N = evex(r, base, p1, type, code, x, b, aaa, VL, index.isExtIdx2());
+ } else {
+ vex(r, base, p1, type, code, x);
+ }
+ opAddr(addr, r.getIdx(), (imm8 != NONE) ? 1 : 0, disp8N, (type & T_VSIB) != 0);
+ } else {
+ const Reg& base = op2.getReg();
+ if ((type & T_MUST_EVEX) || r.hasEvex() || (p1 && p1->hasEvex()) || base.hasEvex()) {
+ evex(r, base, p1, type, code);
+ } else {
+ vex(r, base, p1, type, code);
+ }
+ setModRM(3, r.getIdx(), base.getIdx());
+ }
+ if (imm8 != NONE) db(imm8);
+ }
+ // (r, r, r/m) if isR_R_RM
+ // (r, r/m, r)
+ void opGpr(const Reg32e& r, const Operand& op1, const Operand& op2, int type, uint8 code, bool isR_R_RM, int imm8 = NONE)
+ {
+ const Operand *p1 = &op1;
+ const Operand *p2 = &op2;
+ if (!isR_R_RM) std::swap(p1, p2);
+ const unsigned int bit = r.getBit();
+ if (p1->getBit() != bit || (p2->isREG() && p2->getBit() != bit)) throw Error(ERR_BAD_COMBINATION);
+ type |= (bit == 64) ? T_W1 : T_W0;
+ opVex(r, p1, *p2, type, code, imm8);
+ }
+ void opAVX_X_X_XM(const Xmm& x1, const Operand& op1, const Operand& op2, int type, int code0, int imm8 = NONE)
+ {
+ const Xmm *x2 = static_cast<const Xmm*>(&op1);
+ const Operand *op = &op2;
+ if (op2.isNone()) { // (x1, op1) -> (x1, x1, op1)
+ x2 = &x1;
+ op = &op1;
+ }
+ // (x1, x2, op)
+ if (!((x1.isXMM() && x2->isXMM()) || ((type & T_YMM) && ((x1.isYMM() && x2->isYMM()) || (x1.isZMM() && x2->isZMM()))))) throw Error(ERR_BAD_COMBINATION);
+ opVex(x1, x2, *op, type, code0, imm8);
+ }
+ void opAVX_K_X_XM(const Opmask& k, const Xmm& x2, const Operand& op3, int type, int code0, int imm8 = NONE)
+ {
+ if (!op3.isMEM() && (x2.getKind() != op3.getKind())) throw Error(ERR_BAD_COMBINATION);
+ opVex(k, &x2, op3, type, code0, imm8);
+ }
+ // (x, x/m), (y, x/m256), (z, y/m)
+ void checkCvt1(const Operand& x, const Operand& op) const
+ {
+ if (!op.isMEM() && !(x.is(Operand::XMM | Operand::YMM) && op.isXMM()) && !(x.isZMM() && op.isYMM())) throw Error(ERR_BAD_COMBINATION);
+ }
+ // (x, x/m), (x, y/m256), (y, z/m)
+ void checkCvt2(const Xmm& x, const Operand& op) const
+ {
+ if (!(x.isXMM() && op.is(Operand::XMM | Operand::YMM | Operand::MEM)) && !(x.isYMM() && op.is(Operand::ZMM | Operand::MEM))) throw Error(ERR_BAD_COMBINATION);
+ }
+ void opCvt2(const Xmm& x, const Operand& op, int type, int code)
+ {
+ checkCvt2(x, op);
+ Operand::Kind kind = x.isXMM() ? (op.isBit(256) ? Operand::YMM : Operand::XMM) : Operand::ZMM;
+ opVex(x.copyAndSetKind(kind), &xm0, op, type, code);
+ }
+ void opCvt3(const Xmm& x1, const Xmm& x2, const Operand& op, int type, int type64, int type32, uint8 code)
+ {
+ if (!(x1.isXMM() && x2.isXMM() && (op.isREG(i32e) || op.isMEM()))) throw Error(ERR_BAD_SIZE_OF_REGISTER);
+ Xmm x(op.getIdx());
+ const Operand *p = op.isREG() ? &x : &op;
+ opVex(x1, &x2, *p, type | (op.isBit(64) ? type64 : type32), code);
+ }
+ const Xmm& cvtIdx0(const Operand& x) const
+ {
+ return x.isZMM() ? zm0 : x.isYMM() ? ym0 : xm0;
+ }
+ // support (x, x/m, imm), (y, y/m, imm)
+ void opAVX_X_XM_IMM(const Xmm& x, const Operand& op, int type, int code, int imm8 = NONE)
+ {
+ opAVX_X_X_XM(x, cvtIdx0(x), op, type, code, imm8);
+ }
+ // QQQ:need to refactor
+ void opSp1(const Reg& reg, const Operand& op, uint8 pref, uint8 code0, uint8 code1)
+ {
+ if (reg.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER);
+ bool is16bit = reg.isREG(16) && (op.isREG(16) || op.isMEM());
+ if (!is16bit && !(reg.isREG(i32e) && (op.isREG(reg.getBit()) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION);
+ if (is16bit) db(0x66);
+ db(pref); opModRM(reg.changeBit(i32e == 32 ? 32 : reg.getBit()), op, op.isREG(), true, code0, code1);
+ }
+ void opGather(const Xmm& x1, const Address& addr, const Xmm& x2, int type, uint8 code, int mode)
+ {
+ const RegExp& regExp = addr.getRegExp();
+ if (!regExp.isVsib(128 | 256)) throw Error(ERR_BAD_VSIB_ADDRESSING);
+ const int y_vx_y = 0;
+ const int y_vy_y = 1;
+// const int x_vy_x = 2;
+ const bool isAddrYMM = regExp.getIndex().getBit() == 256;
+ if (!x1.isXMM() || isAddrYMM || !x2.isXMM()) {
+ bool isOK = false;
+ if (mode == y_vx_y) {
+ isOK = x1.isYMM() && !isAddrYMM && x2.isYMM();
+ } else if (mode == y_vy_y) {
+ isOK = x1.isYMM() && isAddrYMM && x2.isYMM();
+ } else { // x_vy_x
+ isOK = !x1.isYMM() && isAddrYMM && !x2.isYMM();
+ }
+ if (!isOK) throw Error(ERR_BAD_VSIB_ADDRESSING);
+ }
+ opAVX_X_X_XM(isAddrYMM ? Ymm(x1.getIdx()) : x1, isAddrYMM ? Ymm(x2.getIdx()) : x2, addr, type, code);
+ }
+ enum {
+ xx_yy_zz = 0,
+ xx_yx_zy = 1,
+ xx_xy_yz = 2
+ };
+ void checkGather2(const Xmm& x1, const Reg& x2, int mode) const
+ {
+ if (x1.isXMM() && x2.isXMM()) return;
+ switch (mode) {
+ case xx_yy_zz: if ((x1.isYMM() && x2.isYMM()) || (x1.isZMM() && x2.isZMM())) return;
+ break;
+ case xx_yx_zy: if ((x1.isYMM() && x2.isXMM()) || (x1.isZMM() && x2.isYMM())) return;
+ break;
+ case xx_xy_yz: if ((x1.isXMM() && x2.isYMM()) || (x1.isYMM() && x2.isZMM())) return;
+ break;
+ }
+ throw Error(ERR_BAD_VSIB_ADDRESSING);
+ }
+ void opGather2(const Xmm& x, const Address& addr, int type, uint8 code, int mode)
+ {
+ if (x.hasZero()) throw Error(ERR_INVALID_ZERO);
+ checkGather2(x, addr.getRegExp().getIndex(), mode);
+ opVex(x, 0, addr, type, code);
+ }
+ /*
+ xx_xy_yz ; mode = true
+ xx_xy_xz ; mode = false
+ */
+ void opVmov(const Operand& op, const Xmm& x, int type, uint8 code, bool mode)
+ {
+ if (mode) {
+ if (!op.isMEM() && !((op.isXMM() && x.isXMM()) || (op.isXMM() && x.isYMM()) || (op.isYMM() && x.isZMM()))) throw Error(ERR_BAD_COMBINATION);
+ } else {
+ if (!op.isMEM() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION);
+ }
+ opVex(x, 0, op, type, code);
+ }
+ void opGatherFetch(const Address& addr, const Xmm& x, int type, uint8 code, Operand::Kind kind)
+ {
+ if (addr.hasZero()) throw Error(ERR_INVALID_ZERO);
+ if (addr.getRegExp().getIndex().getKind() != kind) throw Error(ERR_BAD_VSIB_ADDRESSING);
+ opVex(x, 0, addr, type, code);
+ }
+public:
+ unsigned int getVersion() const { return VERSION; }
+ using CodeArray::db;
+ const Mmx mm0, mm1, mm2, mm3, mm4, mm5, mm6, mm7;
+ const Xmm xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7;
+ const Ymm ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7;
+ const Zmm zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7;
+ const Xmm &xm0, &xm1, &xm2, &xm3, &xm4, &xm5, &xm6, &xm7;
+ const Ymm &ym0, &ym1, &ym2, &ym3, &ym4, &ym5, &ym6, &ym7;
+ const Ymm &zm0, &zm1, &zm2, &zm3, &zm4, &zm5, &zm6, &zm7;
+ const Reg32 eax, ecx, edx, ebx, esp, ebp, esi, edi;
+ const Reg16 ax, cx, dx, bx, sp, bp, si, di;
+ const Reg8 al, cl, dl, bl, ah, ch, dh, bh;
+	const AddressFrame ptr, byte, word, dword, qword, xword, yword, zword; // xword is the same as oword in NASM
+ const AddressFrame ptr_b, xword_b, yword_b, zword_b; // broadcast such as {1to2}, {1to4}, {1to8}, {1to16}, {b}
+ const Fpu st0, st1, st2, st3, st4, st5, st6, st7;
+ const Opmask k0, k1, k2, k3, k4, k5, k6, k7;
+ const BoundsReg bnd0, bnd1, bnd2, bnd3;
+ const EvexModifierRounding T_sae, T_rn_sae, T_rd_sae, T_ru_sae, T_rz_sae; // {sae}, {rn-sae}, {rd-sae}, {ru-sae}, {rz-sae}
+ const EvexModifierZero T_z; // {z}
+#ifdef XBYAK64
+ const Reg64 rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi, r8, r9, r10, r11, r12, r13, r14, r15;
+ const Reg32 r8d, r9d, r10d, r11d, r12d, r13d, r14d, r15d;
+ const Reg16 r8w, r9w, r10w, r11w, r12w, r13w, r14w, r15w;
+ const Reg8 r8b, r9b, r10b, r11b, r12b, r13b, r14b, r15b;
+ const Reg8 spl, bpl, sil, dil;
+ const Xmm xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15;
+ const Xmm xmm16, xmm17, xmm18, xmm19, xmm20, xmm21, xmm22, xmm23;
+ const Xmm xmm24, xmm25, xmm26, xmm27, xmm28, xmm29, xmm30, xmm31;
+ const Ymm ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15;
+ const Ymm ymm16, ymm17, ymm18, ymm19, ymm20, ymm21, ymm22, ymm23;
+ const Ymm ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31;
+ const Zmm zmm8, zmm9, zmm10, zmm11, zmm12, zmm13, zmm14, zmm15;
+ const Zmm zmm16, zmm17, zmm18, zmm19, zmm20, zmm21, zmm22, zmm23;
+ const Zmm zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31;
+ const Xmm &xm8, &xm9, &xm10, &xm11, &xm12, &xm13, &xm14, &xm15; // for my convenience
+ const Xmm &xm16, &xm17, &xm18, &xm19, &xm20, &xm21, &xm22, &xm23;
+ const Xmm &xm24, &xm25, &xm26, &xm27, &xm28, &xm29, &xm30, &xm31;
+ const Ymm &ym8, &ym9, &ym10, &ym11, &ym12, &ym13, &ym14, &ym15;
+ const Ymm &ym16, &ym17, &ym18, &ym19, &ym20, &ym21, &ym22, &ym23;
+ const Ymm &ym24, &ym25, &ym26, &ym27, &ym28, &ym29, &ym30, &ym31;
+ const Zmm &zm8, &zm9, &zm10, &zm11, &zm12, &zm13, &zm14, &zm15;
+ const Zmm &zm16, &zm17, &zm18, &zm19, &zm20, &zm21, &zm22, &zm23;
+ const Zmm &zm24, &zm25, &zm26, &zm27, &zm28, &zm29, &zm30, &zm31;
+ const RegRip rip;
+#endif
+#ifndef XBYAK_DISABLE_SEGMENT
+ const Segment es, cs, ss, ds, fs, gs;
+#endif
+ void L(const std::string& label) { labelMgr_.defineSlabel(label); }
+ void L(Label& label) { labelMgr_.defineClabel(label); }
+ Label L() { Label label; L(label); return label; }
+ void inLocalLabel() { labelMgr_.enterLocal(); }
+ void outLocalLabel() { labelMgr_.leaveLocal(); }
+	/*
+		assign src to dst
+		requirements:
+			dst : must not yet have been used by L()
+			src : must already have been used by L()
+	*/
+ void assignL(Label& dst, const Label& src) { labelMgr_.assign(dst, src); }
+	/*
+		put the address of the label into the code buffer
+		@note the stored size is 4 bytes (32-bit) or 8 bytes (64-bit)
+	*/
+ void putL(std::string label) { putL_inner(label); }
+ void putL(const Label& label) { putL_inner(label); }
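+	/*
+		Illustrative sketch, assuming code inside a class derived from
+		CodeGenerator; the label name `lp` is hypothetical:
+			Label lp;
+			mov(ecx, 3);
+		L(lp);           // define lp at the current position
+			dec(ecx);
+			jnz(lp);     // backward jump to lp (jnz is defined in xbyak_mnemonic.h)
+			putL(lp);    // store the absolute address of lp (4 or 8 bytes)
+	*/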
+
+ void jmp(const Operand& op) { opR_ModM(op, BIT, 4, 0xFF, NONE, NONE, true); }
+ void jmp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); }
+ void jmp(const char *label, LabelType type = T_AUTO) { jmp(std::string(label), type); }
+ void jmp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0xEB, 0xE9, 0); }
+ void jmp(const void *addr, LabelType type = T_AUTO) { opJmpAbs(addr, type, 0xEB, 0xE9); }
+
+ void call(const Operand& op) { opR_ModM(op, 16 | i32e, 2, 0xFF, NONE, NONE, true); }
+ // call(string label), not const std::string&
+ void call(std::string label) { opJmp(label, T_NEAR, 0, 0xE8, 0); }
+ void call(const char *label) { call(std::string(label)); }
+ void call(const Label& label) { opJmp(label, T_NEAR, 0, 0xE8, 0); }
+ // call(function pointer)
+#ifdef XBYAK_VARIADIC_TEMPLATE
+ template<class Ret, class... Params>
+ void call(Ret(*func)(Params...)) { call(reinterpret_cast<const void*>(func)); }
+#endif
+ void call(const void *addr) { opJmpAbs(addr, T_NEAR, 0, 0xE8); }
+
+ void test(const Operand& op, const Reg& reg)
+ {
+ opModRM(reg, op, op.isREG() && (op.getKind() == reg.getKind()), op.isMEM(), 0x84);
+ }
+ void test(const Operand& op, uint32 imm)
+ {
+ verifyMemHasSize(op);
+ int immSize = (std::min)(op.getBit() / 8, 4U);
+ if (op.isREG() && op.getIdx() == 0) { // al, ax, eax
+ rex(op);
+ db(0xA8 | (op.isBit(8) ? 0 : 1));
+ } else {
+ opR_ModM(op, 0, 0, 0xF6, NONE, NONE, false, immSize);
+ }
+ db(imm, immSize);
+ }
+ void imul(const Reg& reg, const Operand& op)
+ {
+ opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x0F, 0xAF);
+ }
+ void imul(const Reg& reg, const Operand& op, int imm)
+ {
+ int s = inner::IsInDisp8(imm) ? 1 : 0;
+ int immSize = s ? 1 : reg.isREG(16) ? 2 : 4;
+ opModRM(reg, op, op.isREG() && (reg.getKind() == op.getKind()), op.isMEM(), 0x69 | (s << 1), NONE, NONE, immSize);
+ db(imm, immSize);
+ }
+ void push(const Operand& op) { opPushPop(op, 0xFF, 6, 0x50); }
+ void pop(const Operand& op) { opPushPop(op, 0x8F, 0, 0x58); }
+ void push(const AddressFrame& af, uint32 imm)
+ {
+ if (af.bit_ == 8 && inner::IsInDisp8(imm)) {
+ db(0x6A); db(imm);
+ } else if (af.bit_ == 16 && isInDisp16(imm)) {
+ db(0x66); db(0x68); dw(imm);
+ } else {
+ db(0x68); dd(imm);
+ }
+ }
+ /* use "push(word, 4)" if you want "push word 4" */
+ void push(uint32 imm)
+ {
+ if (inner::IsInDisp8(imm)) {
+ push(byte, imm);
+ } else {
+ push(dword, imm);
+ }
+ }
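+	/*
+		For illustration: push(4) fits in a signed byte and assembles to 6A 04,
+		while push(dword, 4) forces the 32-bit form 68 04 00 00 00 (cf. the
+		"push(word, 4)" note above).
+	*/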
+ void mov(const Operand& reg1, const Operand& reg2)
+ {
+ const Reg *reg = 0;
+ const Address *addr = 0;
+ uint8 code = 0;
+ if (reg1.isREG() && reg1.getIdx() == 0 && reg2.isMEM()) { // mov eax|ax|al, [disp]
+ reg = &reg1.getReg();
+ addr= &reg2.getAddress();
+ code = 0xA0;
+ } else
+ if (reg1.isMEM() && reg2.isREG() && reg2.getIdx() == 0) { // mov [disp], eax|ax|al
+ reg = &reg2.getReg();
+ addr= &reg1.getAddress();
+ code = 0xA2;
+ }
+#ifdef XBYAK64
+ if (addr && addr->is64bitDisp()) {
+ if (code) {
+ rex(*reg);
+ db(reg1.isREG(8) ? 0xA0 : reg1.isREG() ? 0xA1 : reg2.isREG(8) ? 0xA2 : 0xA3);
+ db(addr->getDisp(), 8);
+ } else {
+ throw Error(ERR_BAD_COMBINATION);
+ }
+ } else
+#else
+ if (code && addr->isOnlyDisp()) {
+ rex(*reg, *addr);
+ db(code | (reg->isBit(8) ? 0 : 1));
+ dd(static_cast<uint32>(addr->getDisp()));
+ } else
+#endif
+ {
+ opRM_RM(reg1, reg2, 0x88);
+ }
+ }
+ void mov(const Operand& op, size_t imm)
+ {
+ if (op.isREG()) {
+ const int size = mov_imm(op.getReg(), imm);
+ db(imm, size);
+ } else if (op.isMEM()) {
+ verifyMemHasSize(op);
+ int immSize = op.getBit() / 8;
+ if (immSize <= 4) {
+ sint64 s = sint64(imm) >> (immSize * 8);
+ if (s != 0 && s != -1) throw Error(ERR_IMM_IS_TOO_BIG);
+ } else {
+ if (!inner::IsInInt32(imm)) throw Error(ERR_IMM_IS_TOO_BIG);
+ immSize = 4;
+ }
+ opModM(op.getAddress(), Reg(0, Operand::REG, op.getBit()), 0xC6, NONE, NONE, immSize);
+ db(static_cast<uint32>(imm), immSize);
+ } else {
+ throw Error(ERR_BAD_COMBINATION);
+ }
+ }
+ void mov(const NativeReg& reg, const char *label) // can't use std::string
+ {
+ if (label == 0) {
+ mov(static_cast<const Operand&>(reg), 0); // call imm
+ return;
+ }
+ mov_imm(reg, dummyAddr);
+ putL(label);
+ }
+ void mov(const NativeReg& reg, const Label& label)
+ {
+ mov_imm(reg, dummyAddr);
+ putL(label);
+ }
+ void xchg(const Operand& op1, const Operand& op2)
+ {
+ const Operand *p1 = &op1, *p2 = &op2;
+ if (p1->isMEM() || (p2->isREG(16 | i32e) && p2->getIdx() == 0)) {
+ p1 = &op2; p2 = &op1;
+ }
+ if (p1->isMEM()) throw Error(ERR_BAD_COMBINATION);
+ if (p2->isREG() && (p1->isREG(16 | i32e) && p1->getIdx() == 0)
+#ifdef XBYAK64
+ && (p2->getIdx() != 0 || !p1->isREG(32))
+#endif
+ ) {
+ rex(*p2, *p1); db(0x90 | (p2->getIdx() & 7));
+ return;
+ }
+ opModRM(*p1, *p2, (p1->isREG() && p2->isREG() && (p1->getBit() == p2->getBit())), p2->isMEM(), 0x86 | (p1->isBit(8) ? 0 : 1));
+ }
+
+#ifndef XBYAK_DISABLE_SEGMENT
+ void push(const Segment& seg)
+ {
+ switch (seg.getIdx()) {
+ case Segment::es: db(0x06); break;
+ case Segment::cs: db(0x0E); break;
+ case Segment::ss: db(0x16); break;
+ case Segment::ds: db(0x1E); break;
+ case Segment::fs: db(0x0F); db(0xA0); break;
+ case Segment::gs: db(0x0F); db(0xA8); break;
+ default:
+ assert(0);
+ }
+ }
+ void pop(const Segment& seg)
+ {
+ switch (seg.getIdx()) {
+ case Segment::es: db(0x07); break;
+ case Segment::cs: throw Error(ERR_BAD_COMBINATION);
+ case Segment::ss: db(0x17); break;
+ case Segment::ds: db(0x1F); break;
+ case Segment::fs: db(0x0F); db(0xA1); break;
+ case Segment::gs: db(0x0F); db(0xA9); break;
+ default:
+ assert(0);
+ }
+ }
+ void putSeg(const Segment& seg)
+ {
+ switch (seg.getIdx()) {
+ case Segment::es: db(0x2E); break;
+ case Segment::cs: db(0x36); break;
+ case Segment::ss: db(0x3E); break;
+ case Segment::ds: db(0x26); break;
+ case Segment::fs: db(0x64); break;
+ case Segment::gs: db(0x65); break;
+ default:
+ assert(0);
+ }
+ }
+ void mov(const Operand& op, const Segment& seg)
+ {
+ opModRM(Reg8(seg.getIdx()), op, op.isREG(16|i32e), op.isMEM(), 0x8C);
+ }
+ void mov(const Segment& seg, const Operand& op)
+ {
+ opModRM(Reg8(seg.getIdx()), op.isREG(16|i32e) ? static_cast<const Operand&>(op.getReg().cvt32()) : op, op.isREG(16|i32e), op.isMEM(), 0x8E);
+ }
+#endif
+
+ enum { NONE = 256 };
+ // constructor
+ CodeGenerator(size_t maxSize = DEFAULT_MAX_CODE_SIZE, void *userPtr = 0, Allocator *allocator = 0)
+ : CodeArray(maxSize, userPtr, allocator)
+ , mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7)
+ , xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7)
+ , ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7)
+ , zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7)
+ // for my convenience
+ , xm0(xmm0), xm1(xmm1), xm2(xmm2), xm3(xmm3), xm4(xmm4), xm5(xmm5), xm6(xmm6), xm7(xmm7)
+ , ym0(ymm0), ym1(ymm1), ym2(ymm2), ym3(ymm3), ym4(ymm4), ym5(ymm5), ym6(ymm6), ym7(ymm7)
+ , zm0(zmm0), zm1(zmm1), zm2(zmm2), zm3(zmm3), zm4(zmm4), zm5(zmm5), zm6(zmm6), zm7(zmm7)
+
+ , eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI)
+ , ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI)
+ , al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH)
+ , ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512)
+ , ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true)
+ , st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7)
+ , k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7)
+ , bnd0(0), bnd1(1), bnd2(2), bnd3(3)
+ , T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE)
+ , T_z()
+#ifdef XBYAK64
+ , rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15)
+ , r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15)
+ , r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15)
+ , r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15)
+ , spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true)
+ , xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15)
+ , xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23)
+ , xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31)
+ , ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15)
+ , ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23)
+ , ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31)
+ , zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15)
+ , zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23)
+ , zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31)
+ // for my convenience
+ , xm8(xmm8), xm9(xmm9), xm10(xmm10), xm11(xmm11), xm12(xmm12), xm13(xmm13), xm14(xmm14), xm15(xmm15)
+ , xm16(xmm16), xm17(xmm17), xm18(xmm18), xm19(xmm19), xm20(xmm20), xm21(xmm21), xm22(xmm22), xm23(xmm23)
+ , xm24(xmm24), xm25(xmm25), xm26(xmm26), xm27(xmm27), xm28(xmm28), xm29(xmm29), xm30(xmm30), xm31(xmm31)
+ , ym8(ymm8), ym9(ymm9), ym10(ymm10), ym11(ymm11), ym12(ymm12), ym13(ymm13), ym14(ymm14), ym15(ymm15)
+ , ym16(ymm16), ym17(ymm17), ym18(ymm18), ym19(ymm19), ym20(ymm20), ym21(ymm21), ym22(ymm22), ym23(ymm23)
+ , ym24(ymm24), ym25(ymm25), ym26(ymm26), ym27(ymm27), ym28(ymm28), ym29(ymm29), ym30(ymm30), ym31(ymm31)
+ , zm8(zmm8), zm9(zmm9), zm10(zmm10), zm11(zmm11), zm12(zmm12), zm13(zmm13), zm14(zmm14), zm15(zmm15)
+ , zm16(zmm16), zm17(zmm17), zm18(zmm18), zm19(zmm19), zm20(zmm20), zm21(zmm21), zm22(zmm22), zm23(zmm23)
+ , zm24(zmm24), zm25(zmm25), zm26(zmm26), zm27(zmm27), zm28(zmm28), zm29(zmm29), zm30(zmm30), zm31(zmm31)
+ , rip()
+#endif
+#ifndef XBYAK_DISABLE_SEGMENT
+ , es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs)
+#endif
+ {
+ labelMgr_.set(this);
+ }
+ void reset()
+ {
+ resetSize();
+ labelMgr_.reset();
+ labelMgr_.set(this);
+ }
+ bool hasUndefinedLabel() const { return labelMgr_.hasUndefSlabel() || labelMgr_.hasUndefClabel(); }
+	/*
+		You MUST call ready() to finish generating code when using AutoGrow mode.
+		For the other modes it only verifies that no label is left undefined.
+	*/
+ void ready(ProtectMode mode = PROTECT_RWE)
+ {
+ if (hasUndefinedLabel()) throw Error(ERR_LABEL_IS_NOT_FOUND);
+ if (isAutoGrow()) {
+ calcJmpAddress();
+ if (useProtect()) setProtectMode(mode);
+ }
+ }
+ // set read/exec
+ void readyRE() { return ready(PROTECT_RE); }
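+	/*
+		Rough sketch of AutoGrow usage (the name `gen` and the emitted body are
+		hypothetical): the buffer may be reallocated while code is emitted, so
+		jump addresses are fixed up by ready() before execution.
+			Xbyak::CodeGenerator gen(4096, Xbyak::AutoGrow);
+			// ... emit code ...
+			gen.ready();                        // calcJmpAddress + RWE protection
+			auto f = gen.getCode<int (*)()>();  // fetch the pointer only after ready()
+	*/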
+#ifdef XBYAK_TEST
+ void dump(bool doClear = true)
+ {
+ CodeArray::dump();
+ if (doClear) size_ = 0;
+ }
+#endif
+
+#ifdef XBYAK_UNDEF_JNL
+ #undef jnl
+#endif
+
+	/*
+		emit only single-byte NOPs (0x90) if useMultiByteNop = false
+	*/
+ void nop(size_t size = 1, bool useMultiByteNop = true)
+ {
+ if (!useMultiByteNop) {
+ for (size_t i = 0; i < size; i++) {
+ db(0x90);
+ }
+ return;
+ }
+		/*
+			Multi-byte NOP sequences recommended by the Intel Architectures
+			Software Developer's Manual, Volume 2.
+			AMD and Intel seem to agree on the same sequences for up to 9 bytes:
+			https://support.amd.com/TechDocs/55723_SOG_Fam_17h_Processors_3.00.pdf
+		*/
+ static const uint8 nopTbl[9][9] = {
+ {0x90},
+ {0x66, 0x90},
+ {0x0F, 0x1F, 0x00},
+ {0x0F, 0x1F, 0x40, 0x00},
+ {0x0F, 0x1F, 0x44, 0x00, 0x00},
+ {0x66, 0x0F, 0x1F, 0x44, 0x00, 0x00},
+ {0x0F, 0x1F, 0x80, 0x00, 0x00, 0x00, 0x00},
+ {0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00},
+ {0x66, 0x0F, 0x1F, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00},
+ };
+ const size_t n = sizeof(nopTbl) / sizeof(nopTbl[0]);
+ while (size > 0) {
+ size_t len = (std::min)(n, size);
+ const uint8 *seq = nopTbl[len - 1];
+ db(seq, len);
+ size -= len;
+ }
+ }
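+	/*
+		For illustration: sizes above 9 are split greedily, e.g. nop(15) emits
+		the 9-byte sequence followed by the 6-byte one.
+	*/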
+
+#ifndef XBYAK_DONT_READ_LIST
+#include "xbyak_mnemonic.h"
+	/*
+		pad with NOPs up to the next x-byte boundary;
+		emit only single-byte NOPs (0x90) if useMultiByteNop = false
+	*/
+ void align(size_t x = 16, bool useMultiByteNop = true)
+ {
+ if (x == 1) return;
+ if (x < 1 || (x & (x - 1))) throw Error(ERR_BAD_ALIGN);
+ if (isAutoGrow() && x > inner::ALIGN_PAGE_SIZE) fprintf(stderr, "warning:autoGrow mode does not support %d align\n", (int)x);
+ size_t remain = size_t(getCurr()) % x;
+ if (remain) {
+ nop(x - remain, useMultiByteNop);
+ }
+ }
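+	/*
+		For illustration: align(16) pads with the NOP sequences above until
+		getCurr() is 16-byte aligned; align(1) is a no-op.
+	*/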
+#endif
+};
+
+namespace util {
+static const Mmx mm0(0), mm1(1), mm2(2), mm3(3), mm4(4), mm5(5), mm6(6), mm7(7);
+static const Xmm xmm0(0), xmm1(1), xmm2(2), xmm3(3), xmm4(4), xmm5(5), xmm6(6), xmm7(7);
+static const Ymm ymm0(0), ymm1(1), ymm2(2), ymm3(3), ymm4(4), ymm5(5), ymm6(6), ymm7(7);
+static const Zmm zmm0(0), zmm1(1), zmm2(2), zmm3(3), zmm4(4), zmm5(5), zmm6(6), zmm7(7);
+static const Reg32 eax(Operand::EAX), ecx(Operand::ECX), edx(Operand::EDX), ebx(Operand::EBX), esp(Operand::ESP), ebp(Operand::EBP), esi(Operand::ESI), edi(Operand::EDI);
+static const Reg16 ax(Operand::AX), cx(Operand::CX), dx(Operand::DX), bx(Operand::BX), sp(Operand::SP), bp(Operand::BP), si(Operand::SI), di(Operand::DI);
+static const Reg8 al(Operand::AL), cl(Operand::CL), dl(Operand::DL), bl(Operand::BL), ah(Operand::AH), ch(Operand::CH), dh(Operand::DH), bh(Operand::BH);
+static const AddressFrame ptr(0), byte(8), word(16), dword(32), qword(64), xword(128), yword(256), zword(512);
+static const AddressFrame ptr_b(0, true), xword_b(128, true), yword_b(256, true), zword_b(512, true);
+static const Fpu st0(0), st1(1), st2(2), st3(3), st4(4), st5(5), st6(6), st7(7);
+static const Opmask k0(0), k1(1), k2(2), k3(3), k4(4), k5(5), k6(6), k7(7);
+static const BoundsReg bnd0(0), bnd1(1), bnd2(2), bnd3(3);
+static const EvexModifierRounding T_sae(EvexModifierRounding::T_SAE), T_rn_sae(EvexModifierRounding::T_RN_SAE), T_rd_sae(EvexModifierRounding::T_RD_SAE), T_ru_sae(EvexModifierRounding::T_RU_SAE), T_rz_sae(EvexModifierRounding::T_RZ_SAE);
+static const EvexModifierZero T_z;
+#ifdef XBYAK64
+static const Reg64 rax(Operand::RAX), rcx(Operand::RCX), rdx(Operand::RDX), rbx(Operand::RBX), rsp(Operand::RSP), rbp(Operand::RBP), rsi(Operand::RSI), rdi(Operand::RDI), r8(Operand::R8), r9(Operand::R9), r10(Operand::R10), r11(Operand::R11), r12(Operand::R12), r13(Operand::R13), r14(Operand::R14), r15(Operand::R15);
+static const Reg32 r8d(8), r9d(9), r10d(10), r11d(11), r12d(12), r13d(13), r14d(14), r15d(15);
+static const Reg16 r8w(8), r9w(9), r10w(10), r11w(11), r12w(12), r13w(13), r14w(14), r15w(15);
+static const Reg8 r8b(8), r9b(9), r10b(10), r11b(11), r12b(12), r13b(13), r14b(14), r15b(15), spl(Operand::SPL, true), bpl(Operand::BPL, true), sil(Operand::SIL, true), dil(Operand::DIL, true);
+static const Xmm xmm8(8), xmm9(9), xmm10(10), xmm11(11), xmm12(12), xmm13(13), xmm14(14), xmm15(15);
+static const Xmm xmm16(16), xmm17(17), xmm18(18), xmm19(19), xmm20(20), xmm21(21), xmm22(22), xmm23(23);
+static const Xmm xmm24(24), xmm25(25), xmm26(26), xmm27(27), xmm28(28), xmm29(29), xmm30(30), xmm31(31);
+static const Ymm ymm8(8), ymm9(9), ymm10(10), ymm11(11), ymm12(12), ymm13(13), ymm14(14), ymm15(15);
+static const Ymm ymm16(16), ymm17(17), ymm18(18), ymm19(19), ymm20(20), ymm21(21), ymm22(22), ymm23(23);
+static const Ymm ymm24(24), ymm25(25), ymm26(26), ymm27(27), ymm28(28), ymm29(29), ymm30(30), ymm31(31);
+static const Zmm zmm8(8), zmm9(9), zmm10(10), zmm11(11), zmm12(12), zmm13(13), zmm14(14), zmm15(15);
+static const Zmm zmm16(16), zmm17(17), zmm18(18), zmm19(19), zmm20(20), zmm21(21), zmm22(22), zmm23(23);
+static const Zmm zmm24(24), zmm25(25), zmm26(26), zmm27(27), zmm28(28), zmm29(29), zmm30(30), zmm31(31);
+static const RegRip rip;
+#endif
+#ifndef XBYAK_DISABLE_SEGMENT
+static const Segment es(Segment::es), cs(Segment::cs), ss(Segment::ss), ds(Segment::ds), fs(Segment::fs), gs(Segment::gs);
+#endif
+} // util
+
+#ifdef _MSC_VER
+ #pragma warning(pop)
+#endif
+
+} // end of namespace
+
+#endif // XBYAK_XBYAK_H_
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h
new file mode 100644
index 0000000000..a22e5224c3
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_bin2hex.h
@@ -0,0 +1,303 @@
+/*******************************************************************************
+* Copyright 2016-2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*******************************************************************************
+* Copyright (c) 2007 MITSUNARI Shigeo
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* Redistributions of source code must retain the above copyright notice, this
+* list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* Neither the name of the copyright owner nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+*******************************************************************************/
+
+enum {
+ B00000000= 0,
+ B00000001= 1,
+ B00000010= 2,
+ B00000011= 3,
+ B00000100= 4,
+ B00000101= 5,
+ B00000110= 6,
+ B00000111= 7,
+ B00001000= 8,
+ B00001001= 9,
+ B00001010= 10,
+ B00001011= 11,
+ B00001100= 12,
+ B00001101= 13,
+ B00001110= 14,
+ B00001111= 15,
+ B00010000= 16,
+ B00010001= 17,
+ B00010010= 18,
+ B00010011= 19,
+ B00010100= 20,
+ B00010101= 21,
+ B00010110= 22,
+ B00010111= 23,
+ B00011000= 24,
+ B00011001= 25,
+ B00011010= 26,
+ B00011011= 27,
+ B00011100= 28,
+ B00011101= 29,
+ B00011110= 30,
+ B00011111= 31,
+ B00100000= 32,
+ B00100001= 33,
+ B00100010= 34,
+ B00100011= 35,
+ B00100100= 36,
+ B00100101= 37,
+ B00100110= 38,
+ B00100111= 39,
+ B00101000= 40,
+ B00101001= 41,
+ B00101010= 42,
+ B00101011= 43,
+ B00101100= 44,
+ B00101101= 45,
+ B00101110= 46,
+ B00101111= 47,
+ B00110000= 48,
+ B00110001= 49,
+ B00110010= 50,
+ B00110011= 51,
+ B00110100= 52,
+ B00110101= 53,
+ B00110110= 54,
+ B00110111= 55,
+ B00111000= 56,
+ B00111001= 57,
+ B00111010= 58,
+ B00111011= 59,
+ B00111100= 60,
+ B00111101= 61,
+ B00111110= 62,
+ B00111111= 63,
+ B01000000= 64,
+ B01000001= 65,
+ B01000010= 66,
+ B01000011= 67,
+ B01000100= 68,
+ B01000101= 69,
+ B01000110= 70,
+ B01000111= 71,
+ B01001000= 72,
+ B01001001= 73,
+ B01001010= 74,
+ B01001011= 75,
+ B01001100= 76,
+ B01001101= 77,
+ B01001110= 78,
+ B01001111= 79,
+ B01010000= 80,
+ B01010001= 81,
+ B01010010= 82,
+ B01010011= 83,
+ B01010100= 84,
+ B01010101= 85,
+ B01010110= 86,
+ B01010111= 87,
+ B01011000= 88,
+ B01011001= 89,
+ B01011010= 90,
+ B01011011= 91,
+ B01011100= 92,
+ B01011101= 93,
+ B01011110= 94,
+ B01011111= 95,
+ B01100000= 96,
+ B01100001= 97,
+ B01100010= 98,
+ B01100011= 99,
+ B01100100= 100,
+ B01100101= 101,
+ B01100110= 102,
+ B01100111= 103,
+ B01101000= 104,
+ B01101001= 105,
+ B01101010= 106,
+ B01101011= 107,
+ B01101100= 108,
+ B01101101= 109,
+ B01101110= 110,
+ B01101111= 111,
+ B01110000= 112,
+ B01110001= 113,
+ B01110010= 114,
+ B01110011= 115,
+ B01110100= 116,
+ B01110101= 117,
+ B01110110= 118,
+ B01110111= 119,
+ B01111000= 120,
+ B01111001= 121,
+ B01111010= 122,
+ B01111011= 123,
+ B01111100= 124,
+ B01111101= 125,
+ B01111110= 126,
+ B01111111= 127,
+ B10000000= 128,
+ B10000001= 129,
+ B10000010= 130,
+ B10000011= 131,
+ B10000100= 132,
+ B10000101= 133,
+ B10000110= 134,
+ B10000111= 135,
+ B10001000= 136,
+ B10001001= 137,
+ B10001010= 138,
+ B10001011= 139,
+ B10001100= 140,
+ B10001101= 141,
+ B10001110= 142,
+ B10001111= 143,
+ B10010000= 144,
+ B10010001= 145,
+ B10010010= 146,
+ B10010011= 147,
+ B10010100= 148,
+ B10010101= 149,
+ B10010110= 150,
+ B10010111= 151,
+ B10011000= 152,
+ B10011001= 153,
+ B10011010= 154,
+ B10011011= 155,
+ B10011100= 156,
+ B10011101= 157,
+ B10011110= 158,
+ B10011111= 159,
+ B10100000= 160,
+ B10100001= 161,
+ B10100010= 162,
+ B10100011= 163,
+ B10100100= 164,
+ B10100101= 165,
+ B10100110= 166,
+ B10100111= 167,
+ B10101000= 168,
+ B10101001= 169,
+ B10101010= 170,
+ B10101011= 171,
+ B10101100= 172,
+ B10101101= 173,
+ B10101110= 174,
+ B10101111= 175,
+ B10110000= 176,
+ B10110001= 177,
+ B10110010= 178,
+ B10110011= 179,
+ B10110100= 180,
+ B10110101= 181,
+ B10110110= 182,
+ B10110111= 183,
+ B10111000= 184,
+ B10111001= 185,
+ B10111010= 186,
+ B10111011= 187,
+ B10111100= 188,
+ B10111101= 189,
+ B10111110= 190,
+ B10111111= 191,
+ B11000000= 192,
+ B11000001= 193,
+ B11000010= 194,
+ B11000011= 195,
+ B11000100= 196,
+ B11000101= 197,
+ B11000110= 198,
+ B11000111= 199,
+ B11001000= 200,
+ B11001001= 201,
+ B11001010= 202,
+ B11001011= 203,
+ B11001100= 204,
+ B11001101= 205,
+ B11001110= 206,
+ B11001111= 207,
+ B11010000= 208,
+ B11010001= 209,
+ B11010010= 210,
+ B11010011= 211,
+ B11010100= 212,
+ B11010101= 213,
+ B11010110= 214,
+ B11010111= 215,
+ B11011000= 216,
+ B11011001= 217,
+ B11011010= 218,
+ B11011011= 219,
+ B11011100= 220,
+ B11011101= 221,
+ B11011110= 222,
+ B11011111= 223,
+ B11100000= 224,
+ B11100001= 225,
+ B11100010= 226,
+ B11100011= 227,
+ B11100100= 228,
+ B11100101= 229,
+ B11100110= 230,
+ B11100111= 231,
+ B11101000= 232,
+ B11101001= 233,
+ B11101010= 234,
+ B11101011= 235,
+ B11101100= 236,
+ B11101101= 237,
+ B11101110= 238,
+ B11101111= 239,
+ B11110000= 240,
+ B11110001= 241,
+ B11110010= 242,
+ B11110011= 243,
+ B11110100= 244,
+ B11110101= 245,
+ B11110110= 246,
+ B11110111= 247,
+ B11111000= 248,
+ B11111001= 249,
+ B11111010= 250,
+ B11111011= 251,
+ B11111100= 252,
+ B11111101= 253,
+ B11111110= 254,
+ B11111111= 255
+};
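+// For illustration: each constant spells a byte value in binary, e.g.
+// B01001000 == 0x48, so byte tables can be written as readable bit patterns
+// (db(B01001000) would emit the single byte 0x48).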
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h
new file mode 100644
index 0000000000..28d2d222f9
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_mnemonic.h
@@ -0,0 +1,2017 @@
+/*******************************************************************************
+* Copyright 2016-2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*******************************************************************************
+* Copyright (c) 2007 MITSUNARI Shigeo
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* Redistributions of source code must retain the above copyright notice, this
+* list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* Neither the name of the copyright owner nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+*******************************************************************************/
+
+const char *getVersionString() const { return "5.76"; }
+void adc(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x10, 2); }
+void adc(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x10); }
+void adcx(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0x66, isREG32_REG32orMEM, NONE, 0x38); }
+void add(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x00, 0); }
+void add(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x00); }
+void addpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x66, isXMM_XMMorMEM); }
+void addps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0x100, isXMM_XMMorMEM); }
+void addsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF2, isXMM_XMMorMEM); }
+void addss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x58, 0xF3, isXMM_XMMorMEM); }
+void addsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0x66, isXMM_XMMorMEM); }
+void addsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xD0, 0xF2, isXMM_XMMorMEM); }
+void adox(const Reg32e& reg, const Operand& op) { opGen(reg, op, 0xF6, 0xF3, isREG32_REG32orMEM, NONE, 0x38); }
+void aesdec(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDE, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void aesdeclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void aesenc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDC, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void aesenclast(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDD, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void aesimc(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xDB, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void aeskeygenassist(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xDF, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
+void and_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x20, 4); }
+void and_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x20); }
+void andn(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_0F38, 0xf2, true); }
+void andnpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x66, isXMM_XMMorMEM); }
+void andnps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x55, 0x100, isXMM_XMMorMEM); }
+void andpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x66, isXMM_XMMorMEM); }
+void andps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x54, 0x100, isXMM_XMMorMEM); }
+void bextr(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf7, false); }
+void blendpd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0D, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void blendps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0C, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void blendvpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void blendvps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void blsi(const Reg32e& r, const Operand& op) { opGpr(Reg32e(3, r.getBit()), op, r, T_0F38, 0xf3, false); }
+void blsmsk(const Reg32e& r, const Operand& op) { opGpr(Reg32e(2, r.getBit()), op, r, T_0F38, 0xf3, false); }
+void blsr(const Reg32e& r, const Operand& op) { opGpr(Reg32e(1, r.getBit()), op, r, T_0F38, 0xf3, false); }
+void bnd() { db(0xF2); }
+void bndcl(const BoundsReg& bnd, const Operand& op) { db(0xF3); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); }
+void bndcn(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1B, NONE, !op.isMEM()); }
+void bndcu(const BoundsReg& bnd, const Operand& op) { db(0xF2); opR_ModM(op, i32e, bnd.getIdx(), 0x0F, 0x1A, NONE, !op.isMEM()); }
+void bndldx(const BoundsReg& bnd, const Address& addr) { opMIB(addr, bnd, 0x0F, 0x1A); }
+void bndmk(const BoundsReg& bnd, const Address& addr) { db(0xF3); opModM(addr, bnd, 0x0F, 0x1B); }
+void bndmov(const Address& addr, const BoundsReg& bnd) { db(0x66); opModM(addr, bnd, 0x0F, 0x1B); }
+void bndmov(const BoundsReg& bnd, const Operand& op) { db(0x66); opModRM(bnd, op, op.isBNDREG(), op.isMEM(), 0x0F, 0x1A); }
+void bndstx(const Address& addr, const BoundsReg& bnd) { opMIB(addr, bnd, 0x0F, 0x1B); }
+void bsf(const Reg&reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBC); }
+void bsr(const Reg&reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0xBD); }
+void bswap(const Reg32e& reg) { opModR(Reg32(1), reg, 0x0F); }
+void bt(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xA3); }
+void bt(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 4, 0x0f, 0xba, NONE, false, 1); db(imm); }
+void btc(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xBB); }
+void btc(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 7, 0x0f, 0xba, NONE, false, 1); db(imm); }
+void btr(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xB3); }
+void btr(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 6, 0x0f, 0xba, NONE, false, 1); db(imm); }
+void bts(const Operand& op, const Reg& reg) { opModRM(reg, op, op.isREG(16|32|64) && op.getBit() == reg.getBit(), op.isMEM(), 0x0f, 0xAB); }
+void bts(const Operand& op, uint8 imm) { opR_ModM(op, 16|32|64, 5, 0x0f, 0xba, NONE, false, 1); db(imm); }
+void bzhi(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_0F38, 0xf5, false); }
+void cbw() { db(0x66); db(0x98); }
+void cdq() { db(0x99); }
+void clc() { db(0xF8); }
+void cld() { db(0xFC); }
+void clflush(const Address& addr) { opModM(addr, Reg32(7), 0x0F, 0xAE); }
+void cli() { db(0xFA); }
+void cmc() { db(0xF5); }
+void cmova(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524
+void cmovae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524
+void cmovb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524
+void cmovbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524
+void cmovc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524
+void cmove(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524
+void cmovg(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524
+void cmovge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524
+void cmovl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524
+void cmovle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524
+void cmovna(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 6); }//-V524
+void cmovnae(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 2); }//-V524
+void cmovnb(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524
+void cmovnbe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 7); }//-V524
+void cmovnc(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 3); }//-V524
+void cmovne(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524
+void cmovng(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 14); }//-V524
+void cmovnge(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 12); }//-V524
+void cmovnl(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 13); }//-V524
+void cmovnle(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 15); }//-V524
+void cmovno(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 1); }//-V524
+void cmovnp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524
+void cmovns(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 9); }//-V524
+void cmovnz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 5); }//-V524
+void cmovo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 0); }//-V524
+void cmovp(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524
+void cmovpe(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 10); }//-V524
+void cmovpo(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 11); }//-V524
+void cmovs(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 8); }//-V524
+void cmovz(const Reg& reg, const Operand& op) { opModRM(reg, op, op.isREG(16 | i32e), op.isMEM(), 0x0F, 0x40 | 4); }//-V524
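+// Note (informal): the low nibble in 0x40 | cc is the standard x86 condition
+// code (0 = o, 2 = b/c/nae, 4 = e/z, 8 = s, 12 = l/nge, ...), which is why
+// aliases such as cmovb, cmovc and cmovnae encode identically.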
+void cmp(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x38, 7); }
+void cmp(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x38); }
+void cmpeqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 0); }
+void cmpeqps(const Xmm& x, const Operand& op) { cmpps(x, op, 0); }
+void cmpeqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 0); }
+void cmpeqss(const Xmm& x, const Operand& op) { cmpss(x, op, 0); }
+void cmplepd(const Xmm& x, const Operand& op) { cmppd(x, op, 2); }
+void cmpleps(const Xmm& x, const Operand& op) { cmpps(x, op, 2); }
+void cmplesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 2); }
+void cmpless(const Xmm& x, const Operand& op) { cmpss(x, op, 2); }
+void cmpltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 1); }
+void cmpltps(const Xmm& x, const Operand& op) { cmpps(x, op, 1); }
+void cmpltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 1); }
+void cmpltss(const Xmm& x, const Operand& op) { cmpss(x, op, 1); }
+void cmpneqpd(const Xmm& x, const Operand& op) { cmppd(x, op, 4); }
+void cmpneqps(const Xmm& x, const Operand& op) { cmpps(x, op, 4); }
+void cmpneqsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 4); }
+void cmpneqss(const Xmm& x, const Operand& op) { cmpss(x, op, 4); }
+void cmpnlepd(const Xmm& x, const Operand& op) { cmppd(x, op, 6); }
+void cmpnleps(const Xmm& x, const Operand& op) { cmpps(x, op, 6); }
+void cmpnlesd(const Xmm& x, const Operand& op) { cmpsd(x, op, 6); }
+void cmpnless(const Xmm& x, const Operand& op) { cmpss(x, op, 6); }
+void cmpnltpd(const Xmm& x, const Operand& op) { cmppd(x, op, 5); }
+void cmpnltps(const Xmm& x, const Operand& op) { cmpps(x, op, 5); }
+void cmpnltsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 5); }
+void cmpnltss(const Xmm& x, const Operand& op) { cmpss(x, op, 5); }
+void cmpordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 7); }
+void cmpordps(const Xmm& x, const Operand& op) { cmpps(x, op, 7); }
+void cmpordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 7); }
+void cmpordss(const Xmm& x, const Operand& op) { cmpss(x, op, 7); }
+void cmppd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x66, isXMM_XMMorMEM, imm8); }
+void cmpps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0x100, isXMM_XMMorMEM, imm8); }
+void cmpsb() { db(0xA6); }
+void cmpsd() { db(0xA7); }
+void cmpsd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF2, isXMM_XMMorMEM, imm8); }
+void cmpss(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC2, 0xF3, isXMM_XMMorMEM, imm8); }
+void cmpsw() { db(0x66); db(0xA7); }
+void cmpunordpd(const Xmm& x, const Operand& op) { cmppd(x, op, 3); }
+void cmpunordps(const Xmm& x, const Operand& op) { cmpps(x, op, 3); }
+void cmpunordsd(const Xmm& x, const Operand& op) { cmpsd(x, op, 3); }
+void cmpunordss(const Xmm& x, const Operand& op) { cmpss(x, op, 3); }
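+// Note (informal): the cmpXXps/pd/ss/sd forms above are shorthands for
+// cmpps/cmppd/cmpss/cmpsd with the SSE compare predicates
+// 0=eq, 1=lt, 2=le, 3=unord, 4=neq, 5=nlt, 6=nle, 7=ord.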
+void cmpxchg(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xB0 | (reg.isBit(8) ? 0 : 1)); }
+void cmpxchg8b(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0xC7); }
+void comisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x66, isXMM_XMMorMEM); }
+void comiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2F, 0x100, isXMM_XMMorMEM); }
+void cpuid() { db(0x0F); db(0xA2); }
+void crc32(const Reg32e& reg, const Operand& op) { if (reg.isBit(32) && op.isBit(16)) db(0x66); db(0xF2); opModRM(reg, op, op.isREG(), op.isMEM(), 0x0F, 0x38, 0xF0 | (op.isBit(8) ? 0 : 1)); }
+void cvtdq2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF3, isXMM_XMMorMEM); }
+void cvtdq2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x100, isXMM_XMMorMEM); }
+void cvtpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0xF2, isXMM_XMMorMEM); }
+void cvtpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x66, isMMX_XMMorMEM); }
+void cvtpd2ps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x66, isXMM_XMMorMEM); }
+void cvtpi2pd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x66, isXMM_MMXorMEM); }
+void cvtpi2ps(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0x100, isXMM_MMXorMEM); }
+void cvtps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0x66, isXMM_XMMorMEM); }
+void cvtps2pd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0x100, isXMM_XMMorMEM); }
+void cvtps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0x100, isMMX_XMMorMEM); }
+void cvtsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF2, isREG32_XMMorMEM); }
+void cvtsd2ss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF2, isXMM_XMMorMEM); }
+void cvtsi2sd(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF2, isXMM_REG32orMEM); }
+void cvtsi2ss(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2A, 0xF3, isXMM_REG32orMEM); }
+void cvtss2sd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5A, 0xF3, isXMM_XMMorMEM); }
+void cvtss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2D, 0xF3, isREG32_XMMorMEM); }
+void cvttpd2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xE6, 0x66, isXMM_XMMorMEM); }
+void cvttpd2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x66, isMMX_XMMorMEM); }
+void cvttps2dq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5B, 0xF3, isXMM_XMMorMEM); }
+void cvttps2pi(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0x100, isMMX_XMMorMEM); }
+void cvttsd2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF2, isREG32_XMMorMEM); }
+void cvttss2si(const Operand& reg, const Operand& op) { opGen(reg, op, 0x2C, 0xF3, isREG32_XMMorMEM); }
+void cwd() { db(0x66); db(0x99); }
+void cwde() { db(0x98); }
+void dec(const Operand& op) { opIncDec(op, 0x48, 1); }
+void div(const Operand& op) { opR_ModM(op, 0, 6, 0xF6); }
+void divpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x66, isXMM_XMMorMEM); }
+void divps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0x100, isXMM_XMMorMEM); }
+void divsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF2, isXMM_XMMorMEM); }
+void divss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5E, 0xF3, isXMM_XMMorMEM); }
+void dppd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void dpps(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void emms() { db(0x0F); db(0x77); }
+void extractps(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x17, imm); }
+void f2xm1() { db(0xD9); db(0xF0); }
+void fabs() { db(0xD9); db(0xE1); }
+void fadd(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 0, 0); }
+void fadd(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C0, 0xDCC0); }
+void fadd(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C0, 0xDCC0); }
+void faddp() { db(0xDE); db(0xC1); }
+void faddp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC0); }
+void faddp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC0); }
+void fchs() { db(0xD9); db(0xE0); }
+void fcmovb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC0, 0x00C0); }
+void fcmovb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC0, 0x00C0); }
+void fcmovbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD0, 0x00D0); }
+void fcmovbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD0, 0x00D0); }
+void fcmove(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAC8, 0x00C8); }
+void fcmove(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAC8, 0x00C8); }
+void fcmovnb(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC0, 0x00C0); }
+void fcmovnb(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC0, 0x00C0); }
+void fcmovnbe(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD0, 0x00D0); }
+void fcmovnbe(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD0, 0x00D0); }
+void fcmovne(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBC8, 0x00C8); }
+void fcmovne(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBC8, 0x00C8); }
+void fcmovnu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBD8, 0x00D8); }
+void fcmovnu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBD8, 0x00D8); }
+void fcmovu(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDAD8, 0x00D8); }
+void fcmovu(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDAD8, 0x00D8); }
+void fcom() { db(0xD8); db(0xD1); }
+void fcom(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 2, 0); }
+void fcom(const Fpu& reg) { opFpu(reg, 0xD8, 0xD0); }
+void fcomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBF0, 0x00F0); }
+void fcomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBF0, 0x00F0); }
+void fcomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFF0, 0x00F0); }
+void fcomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFF0, 0x00F0); }
+void fcomp() { db(0xD8); db(0xD9); }
+void fcomp(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 3, 0); }
+void fcomp(const Fpu& reg) { opFpu(reg, 0xD8, 0xD8); }
+void fcompp() { db(0xDE); db(0xD9); }
+void fcos() { db(0xD9); db(0xFF); }
+void fdecstp() { db(0xD9); db(0xF6); }
+void fdiv(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 6, 0); }
+void fdiv(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F0, 0xDCF8); }
+void fdiv(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F0, 0xDCF8); }
+void fdivp() { db(0xDE); db(0xF9); }
+void fdivp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF8); }
+void fdivp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF8); }
+void fdivr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 7, 0); }
+void fdivr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8F8, 0xDCF0); }
+void fdivr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8F8, 0xDCF0); }
+void fdivrp() { db(0xDE); db(0xF1); }
+void fdivrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEF0); }
+void fdivrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEF0); }
+void ffree(const Fpu& reg) { opFpu(reg, 0xDD, 0xC0); }
+void fiadd(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 0, 0); }
+void ficom(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 2, 0); }
+void ficomp(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 3, 0); }
+void fidiv(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 6, 0); }
+void fidivr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 7, 0); }
+void fild(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 0, 5); }
+void fimul(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 1, 0); }
+void fincstp() { db(0xD9); db(0xF7); }
+void finit() { db(0x9B); db(0xDB); db(0xE3); }
+void fist(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0x00, 2, 0); }
+void fistp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDF, 3, 7); }
+void fisttp(const Address& addr) { opFpuMem(addr, 0xDF, 0xDB, 0xDD, 1, 0); }
+void fisub(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 4, 0); }
+void fisubr(const Address& addr) { opFpuMem(addr, 0xDE, 0xDA, 0x00, 5, 0); }
+void fld(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 0, 0); }
+void fld(const Fpu& reg) { opFpu(reg, 0xD9, 0xC0); }
+void fld1() { db(0xD9); db(0xE8); }
+void fldcw(const Address& addr) { opModM(addr, Reg32(5), 0xD9, 0x100); }
+void fldl2e() { db(0xD9); db(0xEA); }
+void fldl2t() { db(0xD9); db(0xE9); }
+void fldlg2() { db(0xD9); db(0xEC); }
+void fldln2() { db(0xD9); db(0xED); }
+void fldpi() { db(0xD9); db(0xEB); }
+void fldz() { db(0xD9); db(0xEE); }
+void fmul(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 1, 0); }
+void fmul(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8C8, 0xDCC8); }
+void fmul(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8C8, 0xDCC8); }
+void fmulp() { db(0xDE); db(0xC9); }
+void fmulp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEC8); }
+void fmulp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEC8); }
+void fninit() { db(0xDB); db(0xE3); }
+void fnop() { db(0xD9); db(0xD0); }
+void fpatan() { db(0xD9); db(0xF3); }
+void fprem() { db(0xD9); db(0xF8); }
+void fprem1() { db(0xD9); db(0xF5); }
+void fptan() { db(0xD9); db(0xF2); }
+void frndint() { db(0xD9); db(0xFC); }
+void fscale() { db(0xD9); db(0xFD); }
+void fsin() { db(0xD9); db(0xFE); }
+void fsincos() { db(0xD9); db(0xFB); }
+void fsqrt() { db(0xD9); db(0xFA); }
+void fst(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 2, 0); }
+void fst(const Fpu& reg) { opFpu(reg, 0xDD, 0xD0); }
+void fstcw(const Address& addr) { db(0x9B); opModM(addr, Reg32(7), 0xD9, NONE); }
+void fstp(const Address& addr) { opFpuMem(addr, 0x00, 0xD9, 0xDD, 3, 0); }
+void fstp(const Fpu& reg) { opFpu(reg, 0xDD, 0xD8); }
+void fsub(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 4, 0); }
+void fsub(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E0, 0xDCE8); }
+void fsub(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E0, 0xDCE8); }
+void fsubp() { db(0xDE); db(0xE9); }
+void fsubp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE8); }
+void fsubp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE8); }
+void fsubr(const Address& addr) { opFpuMem(addr, 0x00, 0xD8, 0xDC, 5, 0); }
+void fsubr(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xD8E8, 0xDCE0); }
+void fsubr(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xD8E8, 0xDCE0); }
+void fsubrp() { db(0xDE); db(0xE1); }
+void fsubrp(const Fpu& reg1) { opFpuFpu(reg1, st0, 0x0000, 0xDEE0); }
+void fsubrp(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0x0000, 0xDEE0); }
+void ftst() { db(0xD9); db(0xE4); }
+void fucom() { db(0xDD); db(0xE1); }
+void fucom(const Fpu& reg) { opFpu(reg, 0xDD, 0xE0); }
+void fucomi(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDBE8, 0x00E8); }
+void fucomi(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDBE8, 0x00E8); }
+void fucomip(const Fpu& reg1) { opFpuFpu(st0, reg1, 0xDFE8, 0x00E8); }
+void fucomip(const Fpu& reg1, const Fpu& reg2) { opFpuFpu(reg1, reg2, 0xDFE8, 0x00E8); }
+void fucomp() { db(0xDD); db(0xE9); }
+void fucomp(const Fpu& reg) { opFpu(reg, 0xDD, 0xE8); }
+void fucompp() { db(0xDA); db(0xE9); }
+void fwait() { db(0x9B); }
+void fxam() { db(0xD9); db(0xE5); }
+void fxch() { db(0xD9); db(0xC9); }
+void fxch(const Fpu& reg) { opFpu(reg, 0xD9, 0xC8); }
+void fxtract() { db(0xD9); db(0xF4); }
+void fyl2x() { db(0xD9); db(0xF1); }
+void fyl2xp1() { db(0xD9); db(0xF9); }
+void gf2p8affineinvqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void gf2p8affineqb(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0xCE, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void gf2p8mulb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCF, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void haddpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0x66, isXMM_XMMorMEM); }
+void haddps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7C, 0xF2, isXMM_XMMorMEM); }
+void hsubpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0x66, isXMM_XMMorMEM); }
+void hsubps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x7D, 0xF2, isXMM_XMMorMEM); }
+void idiv(const Operand& op) { opR_ModM(op, 0, 7, 0xF6); }
+void imul(const Operand& op) { opR_ModM(op, 0, 5, 0xF6); }
+void inc(const Operand& op) { opIncDec(op, 0x40, 0); }
+void insertps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
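+// Jcc family: conditional jumps; each mnemonic has Label / const char* / std::string / absolute-address overloads, and T_AUTO lets the assembler choose the short (rel8) or near (rel32) form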
+void ja(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524
+void ja(const char *label, LabelType type = T_AUTO) { ja(std::string(label), type); }//-V524
+void ja(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524
+void ja(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524
+void jae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
+void jae(const char *label, LabelType type = T_AUTO) { jae(std::string(label), type); }//-V524
+void jae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524
+void jae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
+void jb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
+void jb(const char *label, LabelType type = T_AUTO) { jb(std::string(label), type); }//-V524
+void jb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524
+void jb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
+void jbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524
+void jbe(const char *label, LabelType type = T_AUTO) { jbe(std::string(label), type); }//-V524
+void jbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524
+void jbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524
+void jc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
+void jc(const char *label, LabelType type = T_AUTO) { jc(std::string(label), type); }//-V524
+void jc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524
+void jc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
+void je(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524
+void je(const char *label, LabelType type = T_AUTO) { je(std::string(label), type); }//-V524
+void je(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524
+void je(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524
+void jg(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524
+void jg(const char *label, LabelType type = T_AUTO) { jg(std::string(label), type); }//-V524
+void jg(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524
+void jg(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524
+void jge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524
+void jge(const char *label, LabelType type = T_AUTO) { jge(std::string(label), type); }//-V524
+void jge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524
+void jge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524
+void jl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524
+void jl(const char *label, LabelType type = T_AUTO) { jl(std::string(label), type); }//-V524
+void jl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524
+void jl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524
+void jle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524
+void jle(const char *label, LabelType type = T_AUTO) { jle(std::string(label), type); }//-V524
+void jle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524
+void jle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524
+void jna(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524
+void jna(const char *label, LabelType type = T_AUTO) { jna(std::string(label), type); }//-V524
+void jna(const void *addr) { opJmpAbs(addr, T_NEAR, 0x76, 0x86, 0x0F); }//-V524
+void jna(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x76, 0x86, 0x0F); }//-V524
+void jnae(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
+void jnae(const char *label, LabelType type = T_AUTO) { jnae(std::string(label), type); }//-V524
+void jnae(const void *addr) { opJmpAbs(addr, T_NEAR, 0x72, 0x82, 0x0F); }//-V524
+void jnae(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x72, 0x82, 0x0F); }//-V524
+void jnb(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
+void jnb(const char *label, LabelType type = T_AUTO) { jnb(std::string(label), type); }//-V524
+void jnb(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524
+void jnb(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
+void jnbe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524
+void jnbe(const char *label, LabelType type = T_AUTO) { jnbe(std::string(label), type); }//-V524
+void jnbe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x77, 0x87, 0x0F); }//-V524
+void jnbe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x77, 0x87, 0x0F); }//-V524
+void jnc(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
+void jnc(const char *label, LabelType type = T_AUTO) { jnc(std::string(label), type); }//-V524
+void jnc(const void *addr) { opJmpAbs(addr, T_NEAR, 0x73, 0x83, 0x0F); }//-V524
+void jnc(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x73, 0x83, 0x0F); }//-V524
+void jne(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524
+void jne(const char *label, LabelType type = T_AUTO) { jne(std::string(label), type); }//-V524
+void jne(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524
+void jne(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524
+void jng(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524
+void jng(const char *label, LabelType type = T_AUTO) { jng(std::string(label), type); }//-V524
+void jng(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7E, 0x8E, 0x0F); }//-V524
+void jng(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7E, 0x8E, 0x0F); }//-V524
+void jnge(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524
+void jnge(const char *label, LabelType type = T_AUTO) { jnge(std::string(label), type); }//-V524
+void jnge(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7C, 0x8C, 0x0F); }//-V524
+void jnge(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7C, 0x8C, 0x0F); }//-V524
+void jnl(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524
+void jnl(const char *label, LabelType type = T_AUTO) { jnl(std::string(label), type); }//-V524
+void jnl(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7D, 0x8D, 0x0F); }//-V524
+void jnl(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7D, 0x8D, 0x0F); }//-V524
+void jnle(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524
+void jnle(const char *label, LabelType type = T_AUTO) { jnle(std::string(label), type); }//-V524
+void jnle(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7F, 0x8F, 0x0F); }//-V524
+void jnle(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7F, 0x8F, 0x0F); }//-V524
+void jno(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524
+void jno(const char *label, LabelType type = T_AUTO) { jno(std::string(label), type); }//-V524
+void jno(const void *addr) { opJmpAbs(addr, T_NEAR, 0x71, 0x81, 0x0F); }//-V524
+void jno(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x71, 0x81, 0x0F); }//-V524
+void jnp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524
+void jnp(const char *label, LabelType type = T_AUTO) { jnp(std::string(label), type); }//-V524
+void jnp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524
+void jnp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524
+void jns(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524
+void jns(const char *label, LabelType type = T_AUTO) { jns(std::string(label), type); }//-V524
+void jns(const void *addr) { opJmpAbs(addr, T_NEAR, 0x79, 0x89, 0x0F); }//-V524
+void jns(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x79, 0x89, 0x0F); }//-V524
+void jnz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524
+void jnz(const char *label, LabelType type = T_AUTO) { jnz(std::string(label), type); }//-V524
+void jnz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x75, 0x85, 0x0F); }//-V524
+void jnz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x75, 0x85, 0x0F); }//-V524
+void jo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524
+void jo(const char *label, LabelType type = T_AUTO) { jo(std::string(label), type); }//-V524
+void jo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x70, 0x80, 0x0F); }//-V524
+void jo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x70, 0x80, 0x0F); }//-V524
+void jp(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524
+void jp(const char *label, LabelType type = T_AUTO) { jp(std::string(label), type); }//-V524
+void jp(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524
+void jp(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524
+void jpe(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524
+void jpe(const char *label, LabelType type = T_AUTO) { jpe(std::string(label), type); }//-V524
+void jpe(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7A, 0x8A, 0x0F); }//-V524
+void jpe(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7A, 0x8A, 0x0F); }//-V524
+void jpo(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524
+void jpo(const char *label, LabelType type = T_AUTO) { jpo(std::string(label), type); }//-V524
+void jpo(const void *addr) { opJmpAbs(addr, T_NEAR, 0x7B, 0x8B, 0x0F); }//-V524
+void jpo(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x7B, 0x8B, 0x0F); }//-V524
+void js(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524
+void js(const char *label, LabelType type = T_AUTO) { js(std::string(label), type); }//-V524
+void js(const void *addr) { opJmpAbs(addr, T_NEAR, 0x78, 0x88, 0x0F); }//-V524
+void js(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x78, 0x88, 0x0F); }//-V524
+void jz(const Label& label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524
+void jz(const char *label, LabelType type = T_AUTO) { jz(std::string(label), type); }//-V524
+void jz(const void *addr) { opJmpAbs(addr, T_NEAR, 0x74, 0x84, 0x0F); }//-V524
+void jz(std::string label, LabelType type = T_AUTO) { opJmp(label, type, 0x74, 0x84, 0x0F); }//-V524
+void lahf() { db(0x9F); }
+void lddqu(const Xmm& xmm, const Address& addr) { db(0xF2); opModM(addr, xmm, 0x0F, 0xF0); }
+void ldmxcsr(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0xAE); }
+void lea(const Reg& reg, const Address& addr) { if (!reg.isBit(16 | i32e)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModM(addr, reg, 0x8D); }
+void lfence() { db(0x0F); db(0xAE); db(0xE8); }
+void lock() { db(0xF0); }
+void lzcnt(const Reg& reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBD); }
+void maskmovdqu(const Xmm& reg1, const Xmm& reg2) { db(0x66); opModR(reg1, reg2, 0x0F, 0xF7); }
+void maskmovq(const Mmx& reg1, const Mmx& reg2) { if (!reg1.isMMX() || !reg2.isMMX()) throw Error(ERR_BAD_COMBINATION); opModR(reg1, reg2, 0x0F, 0xF7); }
+void maxpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x66, isXMM_XMMorMEM); }
+void maxps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0x100, isXMM_XMMorMEM); }
+void maxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF2, isXMM_XMMorMEM); }
+void maxss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5F, 0xF3, isXMM_XMMorMEM); }
+void mfence() { db(0x0F); db(0xAE); db(0xF0); }
+void minpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x66, isXMM_XMMorMEM); }
+void minps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0x100, isXMM_XMMorMEM); }
+void minsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF2, isXMM_XMMorMEM); }
+void minss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5D, 0xF3, isXMM_XMMorMEM); }
+void monitor() { db(0x0F); db(0x01); db(0xC8); }
+void movapd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x29); }
+void movapd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x66); }
+void movaps(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x29); }
+void movaps(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x28, 0x100); }
+void movbe(const Address& addr, const Reg& reg) { opModM(addr, reg, 0x0F, 0x38, 0xF1); }
+void movbe(const Reg& reg, const Address& addr) { opModM(addr, reg, 0x0F, 0x38, 0xF0); }
+void movd(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x7E); }
+void movd(const Mmx& mmx, const Address& addr) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, 0x6E); }
+void movd(const Mmx& mmx, const Reg32& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); }
+void movd(const Reg32& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); }
+void movddup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF2, isXMM_XMMorMEM, NONE, NONE); }
+void movdq2q(const Mmx& mmx, const Xmm& xmm) { db(0xF2); opModR(mmx, xmm, 0x0F, 0xD6); }
+void movdqa(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x7F); }
+void movdqa(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0x66); }
+void movdqu(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x7F); }
+void movdqu(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x6F, 0xF3); }
+void movhlps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x12); }
+void movhpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x66); }
+void movhps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x16, 0x100); }
+void movlhps(const Xmm& reg1, const Xmm& reg2) { opModR(reg1, reg2, 0x0F, 0x16); }
+void movlpd(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x66); }
+void movlps(const Operand& op1, const Operand& op2) { opMovXMM(op1, op2, 0x12, 0x100); }
+void movmskpd(const Reg32e& reg, const Xmm& xmm) { db(0x66); movmskps(reg, xmm); }
+void movmskps(const Reg32e& reg, const Xmm& xmm) { opModR(reg, xmm, 0x0F, 0x50); }
+void movntdq(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0xE7); }
+void movntdqa(const Xmm& xmm, const Address& addr) { db(0x66); opModM(addr, xmm, 0x0F, 0x38, 0x2A); }
+void movnti(const Address& addr, const Reg32e& reg) { opModM(addr, reg, 0x0F, 0xC3); }
+void movntpd(const Address& addr, const Xmm& reg) { opModM(addr, Reg16(reg.getIdx()), 0x0F, 0x2B); }
+void movntps(const Address& addr, const Xmm& xmm) { opModM(addr, Mmx(xmm.getIdx()), 0x0F, 0x2B); }
+void movntq(const Address& addr, const Mmx& mmx) { if (!mmx.isMMX()) throw Error(ERR_BAD_COMBINATION); opModM(addr, mmx, 0x0F, 0xE7); }
+void movq(const Address& addr, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModM(addr, mmx, 0x0F, mmx.isXMM() ? 0xD6 : 0x7F); }
+void movq(const Mmx& mmx, const Operand& op) { if (mmx.isXMM()) db(0xF3); opModRM(mmx, op, (mmx.getKind() == op.getKind()), op.isMEM(), 0x0F, mmx.isXMM() ? 0x7E : 0x6F); }
+void movq2dq(const Xmm& xmm, const Mmx& mmx) { db(0xF3); opModR(xmm, mmx, 0x0F, 0xD6); }
+void movsb() { db(0xA4); }
+void movsd() { db(0xA5); }
+void movsd(const Address& addr, const Xmm& xmm) { db(0xF2); opModM(addr, xmm, 0x0F, 0x11); }
+void movsd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF2); }
+void movshdup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x16, 0xF3, isXMM_XMMorMEM, NONE, NONE); }
+void movsldup(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x12, 0xF3, isXMM_XMMorMEM, NONE, NONE); }
+void movss(const Address& addr, const Xmm& xmm) { db(0xF3); opModM(addr, xmm, 0x0F, 0x11); }
+void movss(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0xF3); }
+void movsw() { db(0x66); db(0xA5); }
+void movsx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xBE); }
+void movupd(const Address& addr, const Xmm& xmm) { db(0x66); opModM(addr, xmm, 0x0F, 0x11); }
+void movupd(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x66); }
+void movups(const Address& addr, const Xmm& xmm) { opModM(addr, xmm, 0x0F, 0x11); }
+void movups(const Xmm& xmm, const Operand& op) { opMMX(xmm, op, 0x10, 0x100); }
+void movzx(const Reg& reg, const Operand& op) { opMovxx(reg, op, 0xB6); }
+void mpsadbw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x42, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void mul(const Operand& op) { opR_ModM(op, 0, 4, 0xF6); }
+void mulpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x66, isXMM_XMMorMEM); }
+void mulps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0x100, isXMM_XMMorMEM); }
+void mulsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF2, isXMM_XMMorMEM); }
+void mulss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x59, 0xF3, isXMM_XMMorMEM); }
+void mulx(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf6, true); }
+void mwait() { db(0x0F); db(0x01); db(0xC9); }
+void neg(const Operand& op) { opR_ModM(op, 0, 3, 0xF6); }
+void not_(const Operand& op) { opR_ModM(op, 0, 2, 0xF6); }
+void or_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x08, 1); }
+void or_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x08); }
+void orpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x66, isXMM_XMMorMEM); }
+void orps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x56, 0x100, isXMM_XMMorMEM); }
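+// Packed-integer instructions (MMX/SSE2/SSSE3/SSE4.x); opMMX accepts either an MMX or an XMM destination, since Xmm derives from Mmx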
+void pabsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1C, 0x66, NONE, 0x38); }
+void pabsd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1E, 0x66, NONE, 0x38); }
+void pabsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x1D, 0x66, NONE, 0x38); }
+void packssdw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6B); }
+void packsswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x63); }
+void packusdw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2B, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void packuswb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x67); }
+void paddb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFC); }
+void paddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFE); }
+void paddq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD4); }
+void paddsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEC); }
+void paddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xED); }
+void paddusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDC); }
+void paddusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDD); }
+void paddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFD); }
+void palignr(const Mmx& mmx, const Operand& op, int imm) { opMMX(mmx, op, 0x0f, 0x66, static_cast<uint8>(imm), 0x3a); }
+void pand(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDB); }
+void pandn(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDF); }
+void pause() { db(0xF3); db(0x90); }
+void pavgb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE0); }
+void pavgw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE3); }
+void pblendvb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x10, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pblendw(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0E, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void pclmulhqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x11); }
+void pclmulhqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x01); }
+void pclmullqhdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x10); }
+void pclmullqlqdq(const Xmm& xmm, const Operand& op) { pclmulqdq(xmm, op, 0x00); }
+void pclmulqdq(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x44, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void pcmpeqb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x74); }
+void pcmpeqd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x76); }
+void pcmpeqq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x29, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pcmpeqw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x75); }
+void pcmpestri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x61, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
+void pcmpestrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x60, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
+void pcmpgtb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x64); }
+void pcmpgtd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x66); }
+void pcmpgtq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x37, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pcmpgtw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x65); }
+void pcmpistri(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x63, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
+void pcmpistrm(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x62, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
+void pdep(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F2 | T_0F38, 0xf5, true); }
+void pext(const Reg32e& r1, const Reg32e& r2, const Operand& op) { opGpr(r1, r2, op, T_F3 | T_0F38, 0xf5, true); }
+void pextrb(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x14, imm); }
+void pextrd(const Operand& op, const Xmm& xmm, uint8 imm) { opExt(op, xmm, 0x16, imm); }
+void pextrw(const Operand& op, const Mmx& xmm, uint8 imm) { opExt(op, xmm, 0x15, imm, true); }
+void phaddd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x02, 0x66, NONE, 0x38); }
+void phaddsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x03, 0x66, NONE, 0x38); }
+void phaddw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x01, 0x66, NONE, 0x38); }
+void phminposuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x41, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void phsubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x06, 0x66, NONE, 0x38); }
+void phsubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x07, 0x66, NONE, 0x38); }
+void phsubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x05, 0x66, NONE, 0x38); }
+void pinsrb(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x20, 0x66, isXMM_REG32orMEM, imm, 0x3A); }
+void pinsrd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x22, 0x66, isXMM_REG32orMEM, imm, 0x3A); }
+void pinsrw(const Mmx& mmx, const Operand& op, int imm) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(mmx, op, 0xC4, mmx.isXMM() ? 0x66 : NONE, 0, imm); }
+void pmaddubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x04, 0x66, NONE, 0x38); }
+void pmaddwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF5); }
+void pmaxsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3C, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmaxsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3D, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmaxsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEE); }
+void pmaxub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDE); }
+void pmaxud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3F, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmaxuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3E, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pminsb(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x38, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pminsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x39, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pminsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEA); }
+void pminub(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xDA); }
+void pminud(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3B, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pminuw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x3A, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovmskb(const Reg32e& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(reg, mmx, 0x0F, 0xD7); }
+void pmovsxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x21, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovsxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x22, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovsxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x20, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovsxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x25, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovsxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x23, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovsxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x24, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovzxbd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x31, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovzxbq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x32, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovzxbw(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x30, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovzxdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x35, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovzxwd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x33, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmovzxwq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x34, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmuldq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x28, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmulhrsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0B, 0x66, NONE, 0x38); }
+void pmulhuw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE4); }
+void pmulhw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE5); }
+void pmulld(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x40, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void pmullw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD5); }
+void pmuludq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF4); }
+void popcnt(const Reg& reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xB8); }
+void popf() { db(0x9D); }
+void por(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEB); }
+void prefetchnta(const Address& addr) { opModM(addr, Reg32(0), 0x0F, 0x18); }
+void prefetcht0(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x18); }
+void prefetcht1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x18); }
+void prefetcht2(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0x18); }
+void prefetchw(const Address& addr) { opModM(addr, Reg32(1), 0x0F, 0x0D); }
+void prefetchwt1(const Address& addr) { opModM(addr, Reg32(2), 0x0F, 0x0D); }
+void psadbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF6); }
+void pshufb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x00, 0x66, NONE, 0x38); }
+void pshufd(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x66, imm8); }
+void pshufhw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF3, imm8); }
+void pshuflw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0xF2, imm8); }
+void pshufw(const Mmx& mmx, const Operand& op, uint8 imm8) { opMMX(mmx, op, 0x70, 0x00, imm8); }
+void psignb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x08, 0x66, NONE, 0x38); }
+void psignd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x0A, 0x66, NONE, 0x38); }
+void psignw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x09, 0x66, NONE, 0x38); }
+void pslld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF2); }
+void pslld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 6); }
+void pslldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 7); }
+void psllq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF3); }
+void psllq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 6); }
+void psllw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF1); }
+void psllw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 6); }
+void psrad(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE2); }
+void psrad(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 4); }
+void psraw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE1); }
+void psraw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 4); }
+void psrld(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD2); }
+void psrld(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x72, 2); }
+void psrldq(const Xmm& xmm, int imm8) { opMMX_IMM(xmm, imm8, 0x73, 3); }
+void psrlq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD3); }
+void psrlq(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x73, 2); }
+void psrlw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD1); }
+void psrlw(const Mmx& mmx, int imm8) { opMMX_IMM(mmx, imm8, 0x71, 2); }
+void psubb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF8); }
+void psubd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFA); }
+void psubq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xFB); }
+void psubsb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE8); }
+void psubsw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xE9); }
+void psubusb(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD8); }
+void psubusw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xD9); }
+void psubw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xF9); }
+void ptest(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x17, 0x66, isXMM_XMMorMEM, NONE, 0x38); }
+void punpckhbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x68); }
+void punpckhdq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x6A); }
+void punpckhqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6D, 0x66, isXMM_XMMorMEM); }
+void punpckhwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x69); }
+void punpcklbw(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x60); }
+void punpckldq(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x62); }
+void punpcklqdq(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x6C, 0x66, isXMM_XMMorMEM); }
+void punpcklwd(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0x61); }
+void pushf() { db(0x9C); }
+void pxor(const Mmx& mmx, const Operand& op) { opMMX(mmx, op, 0xEF); }
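+// Rotate/shift group: opShift encodes the ModRM /digit extension (rol=/0, ror=/1, rcl=/2, rcr=/3, sal/shl=/4, shr=/5, sar=/7)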
+void rcl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 2); }
+void rcl(const Operand& op, int imm) { opShift(op, imm, 2); }
+void rcpps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0x100, isXMM_XMMorMEM); }
+void rcpss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x53, 0xF3, isXMM_XMMorMEM); }
+void rcr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 3); }
+void rcr(const Operand& op, int imm) { opShift(op, imm, 3); }
+void rdmsr() { db(0x0F); db(0x32); }
+void rdpmc() { db(0x0F); db(0x33); }
+void rdrand(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(6, Operand::REG, r.getBit()), r, 0x0F, 0xC7); }
+void rdseed(const Reg& r) { if (r.isBit(8)) throw Error(ERR_BAD_SIZE_OF_REGISTER); opModR(Reg(7, Operand::REG, r.getBit()), r, 0x0F, 0xC7); }
+void rdtsc() { db(0x0F); db(0x31); }
+void rdtscp() { db(0x0F); db(0x01); db(0xF9); }
+void rep() { db(0xF3); }
+void ret(int imm = 0) { if (imm) { db(0xC2); dw(imm); } else { db(0xC3); } }
+void rol(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 0); }
+void rol(const Operand& op, int imm) { opShift(op, imm, 0); }
+void ror(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 1); }
+void ror(const Operand& op, int imm) { opShift(op, imm, 1); }
+void rorx(const Reg32e& r, const Operand& op, uint8 imm) { opGpr(r, op, Reg32e(0, r.getBit()), T_0F3A | T_F2, 0xF0, false, imm); }
+void roundpd(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x09, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
+void roundps(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0x08, 0x66, isXMM_XMMorMEM, imm, 0x3A); }
+void roundsd(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0B, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void roundss(const Xmm& xmm, const Operand& op, int imm) { opGen(xmm, op, 0x0A, 0x66, isXMM_XMMorMEM, static_cast<uint8>(imm), 0x3A); }
+void rsqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0x100, isXMM_XMMorMEM); }
+void rsqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x52, 0xF3, isXMM_XMMorMEM); }
+void sahf() { db(0x9E); }
+void sal(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); }
+void sal(const Operand& op, int imm) { opShift(op, imm, 4); }
+void sar(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 7); }
+void sar(const Operand& op, int imm) { opShift(op, imm, 7); }
+void sarx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F3 | T_0F38, 0xf7, false); }
+void sbb(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x18, 3); }
+void sbb(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x18); }
+void scasb() { db(0xAE); }
+void scasd() { db(0xAF); }
+void scasw() { db(0x66); db(0xAF); }
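+// SETcc family: sets an 8-bit register/memory operand to 0 or 1 from a condition code (opcodes 0x0F 0x90+cc)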
+void seta(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524
+void setae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524
+void setb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524
+void setbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524
+void setc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524
+void sete(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524
+void setg(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524
+void setge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524
+void setl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524
+void setle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524
+void setna(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 6); }//-V524
+void setnae(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 2); }//-V524
+void setnb(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524
+void setnbe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 7); }//-V524
+void setnc(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 3); }//-V524
+void setne(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524
+void setng(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 14); }//-V524
+void setnge(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 12); }//-V524
+void setnl(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 13); }//-V524
+void setnle(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 15); }//-V524
+void setno(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 1); }//-V524
+void setnp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524
+void setns(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 9); }//-V524
+void setnz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 5); }//-V524
+void seto(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 0); }//-V524
+void setp(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524
+void setpe(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 10); }//-V524
+void setpo(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 11); }//-V524
+void sets(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 8); }//-V524
+void setz(const Operand& op) { opR_ModM(op, 8, 0, 0x0F, 0x90 | 4); }//-V524
+void sfence() { db(0x0F); db(0xAE); db(0xF8); }
+void sha1msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC9, NONE, isXMM_XMMorMEM, NONE, 0x38); }
+void sha1msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCA, NONE, isXMM_XMMorMEM, NONE, 0x38); }
+void sha1nexte(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xC8, NONE, isXMM_XMMorMEM, NONE, 0x38); }
+void sha1rnds4(const Xmm& xmm, const Operand& op, uint8 imm) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, imm, 0x3A); }
+void sha256msg1(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCC, NONE, isXMM_XMMorMEM, NONE, 0x38); }
+void sha256msg2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCD, NONE, isXMM_XMMorMEM, NONE, 0x38); }
+void sha256rnds2(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0xCB, NONE, isXMM_XMMorMEM, NONE, 0x38); }
+void shl(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 4); }
+void shl(const Operand& op, int imm) { opShift(op, imm, 4); }
+void shld(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xA4, &_cl); }
+void shld(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xA4); }
+void shlx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_66 | T_0F38, 0xf7, false); }
+void shr(const Operand& op, const Reg8& _cl) { opShift(op, _cl, 5); }
+void shr(const Operand& op, int imm) { opShift(op, imm, 5); }
+void shrd(const Operand& op, const Reg& reg, const Reg8& _cl) { opShxd(op, reg, 0, 0xAC, &_cl); }
+void shrd(const Operand& op, const Reg& reg, uint8 imm) { opShxd(op, reg, imm, 0xAC); }
+void shrx(const Reg32e& r1, const Operand& op, const Reg32e& r2) { opGpr(r1, op, r2, T_F2 | T_0F38, 0xf7, false); }
+void shufpd(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x66, isXMM_XMMorMEM, imm8); }
+void shufps(const Xmm& xmm, const Operand& op, uint8 imm8) { opGen(xmm, op, 0xC6, 0x100, isXMM_XMMorMEM, imm8); }
+void sqrtpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x66, isXMM_XMMorMEM); }
+void sqrtps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0x100, isXMM_XMMorMEM); }
+void sqrtsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF2, isXMM_XMMorMEM); }
+void sqrtss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x51, 0xF3, isXMM_XMMorMEM); }
+void stac() { db(0x0F); db(0x01); db(0xCB); }
+void stc() { db(0xF9); }
+void std() { db(0xFD); }
+void sti() { db(0xFB); }
+void stmxcsr(const Address& addr) { opModM(addr, Reg32(3), 0x0F, 0xAE); }
+void stosb() { db(0xAA); }
+void stosd() { db(0xAB); }
+void stosw() { db(0x66); db(0xAB); }
+void sub(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x28, 5); }
+void sub(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x28); }
+void subpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x66, isXMM_XMMorMEM); }
+void subps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0x100, isXMM_XMMorMEM); }
+void subsd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF2, isXMM_XMMorMEM); }
+void subss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x5C, 0xF3, isXMM_XMMorMEM); }
+void tzcnt(const Reg& reg, const Operand& op) { opSp1(reg, op, 0xF3, 0x0F, 0xBC); }
+void ucomisd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x66, isXMM_XMMorMEM); }
+void ucomiss(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x2E, 0x100, isXMM_XMMorMEM); }
+void ud2() { db(0x0F); db(0x0B); }
+void unpckhpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x66, isXMM_XMMorMEM); }
+void unpckhps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x15, 0x100, isXMM_XMMorMEM); }
+void unpcklpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x66, isXMM_XMMorMEM); }
+void unpcklps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x14, 0x100, isXMM_XMMorMEM); }
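+// VEX/EVEX-encoded AVX instructions start here; the T_* flags select the mandatory prefix and opcode map, plus EVEX features such as element width (T_EW*), broadcast (T_B32/T_B64), and rounding/SAE control (T_ER_*/T_SAE_*)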
+void vaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x58); }
+void vaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x58); }
+void vaddsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x58); }
+void vaddss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x58); }
+void vaddsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0xD0); }
+void vaddsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0xD0); }
+void vaesdec(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDE); }
+void vaesdeclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDF); }
+void vaesenc(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDC); }
+void vaesenclast(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F38 | T_YMM | T_EVEX, 0xDD); }
+void vaesimc(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_W0, 0xDB); }
+void vaeskeygenassist(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0xDF, imm); }
+void vandnpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x55); }
+void vandnps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x55); }
+void vandpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x54); }
+void vandps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x54); }
+void vblendpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0D, imm); }
+void vblendps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0C, imm); }
+void vblendvpd(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4B, x4.getIdx() << 4); }
+void vblendvps(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4A, x4.getIdx() << 4); }
+void vbroadcastf128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x1A); }
+void vbroadcasti128(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x5A); }
+void vbroadcastsd(const Ymm& y, const Operand& op) { if (!op.isMEM() && !(y.isYMM() && op.isXMM()) && !(y.isZMM() && op.isXMM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(y, op, T_0F38 | T_66 | T_W0 | T_YMM | T_EVEX | T_EW1 | T_N8, 0x19); }
+void vbroadcastss(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x18); }
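+// vcmp*pd/ps/sd/ss helpers are pseudo-mnemonics: each forwards to vcmppd/vcmpps/vcmpsd/vcmpss with the fixed comparison predicate (0-31) as the immediate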
+void vcmpeq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 16); }
+void vcmpeq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 16); }
+void vcmpeq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 16); }
+void vcmpeq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 16); }
+void vcmpeq_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 8); }
+void vcmpeq_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 8); }
+void vcmpeq_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 8); }
+void vcmpeq_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 8); }
+void vcmpeq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 24); }
+void vcmpeq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 24); }
+void vcmpeq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 24); }
+void vcmpeq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 24); }
+void vcmpeqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 0); }
+void vcmpeqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 0); }
+void vcmpeqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 0); }
+void vcmpeqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 0); }
+void vcmpfalse_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 27); }
+void vcmpfalse_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 27); }
+void vcmpfalse_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 27); }
+void vcmpfalse_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 27); }
+void vcmpfalsepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 11); }
+void vcmpfalseps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 11); }
+void vcmpfalsesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 11); }
+void vcmpfalsess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 11); }
+void vcmpge_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 29); }
+void vcmpge_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 29); }
+void vcmpge_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 29); }
+void vcmpge_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 29); }
+void vcmpgepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 13); }
+void vcmpgeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 13); }
+void vcmpgesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 13); }
+void vcmpgess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 13); }
+void vcmpgt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 30); }
+void vcmpgt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 30); }
+void vcmpgt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 30); }
+void vcmpgt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 30); }
+void vcmpgtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 14); }
+void vcmpgtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 14); }
+void vcmpgtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 14); }
+void vcmpgtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 14); }
+void vcmple_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 18); }
+void vcmple_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 18); }
+void vcmple_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 18); }
+void vcmple_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 18); }
+void vcmplepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 2); }
+void vcmpleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 2); }
+void vcmplesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 2); }
+void vcmpless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 2); }
+void vcmplt_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 17); }
+void vcmplt_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 17); }
+void vcmplt_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 17); }
+void vcmplt_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 17); }
+void vcmpltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 1); }
+void vcmpltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 1); }
+void vcmpltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 1); }
+void vcmpltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 1); }
+void vcmpneq_oqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 12); }
+void vcmpneq_oqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 12); }
+void vcmpneq_oqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 12); }
+void vcmpneq_oqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 12); }
+void vcmpneq_ospd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 28); }
+void vcmpneq_osps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 28); }
+void vcmpneq_ossd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 28); }
+void vcmpneq_osss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 28); }
+void vcmpneq_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 20); }
+void vcmpneq_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 20); }
+void vcmpneq_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 20); }
+void vcmpneq_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 20); }
+void vcmpneqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 4); }
+void vcmpneqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 4); }
+void vcmpneqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 4); }
+void vcmpneqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 4); }
+void vcmpnge_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 25); }
+void vcmpnge_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 25); }
+void vcmpnge_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 25); }
+void vcmpnge_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 25); }
+void vcmpngepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 9); }
+void vcmpngeps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 9); }
+void vcmpngesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 9); }
+void vcmpngess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 9); }
+void vcmpngt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 26); }
+void vcmpngt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 26); }
+void vcmpngt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 26); }
+void vcmpngt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 26); }
+void vcmpngtpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 10); }
+void vcmpngtps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 10); }
+void vcmpngtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 10); }
+void vcmpngtss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 10); }
+void vcmpnle_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 22); }
+void vcmpnle_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 22); }
+void vcmpnle_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 22); }
+void vcmpnle_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 22); }
+void vcmpnlepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 6); }
+void vcmpnleps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 6); }
+void vcmpnlesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 6); }
+void vcmpnless(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 6); }
+void vcmpnlt_uqpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 21); }
+void vcmpnlt_uqps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 21); }
+void vcmpnlt_uqsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 21); }
+void vcmpnlt_uqss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 21); }
+void vcmpnltpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 5); }
+void vcmpnltps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 5); }
+void vcmpnltsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 5); }
+void vcmpnltss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 5); }
+void vcmpord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 23); }
+void vcmpord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 23); }
+void vcmpord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 23); }
+void vcmpord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 23); }
+void vcmpordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 7); }
+void vcmpordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 7); }
+void vcmpordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 7); }
+void vcmpordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 7); }
+void vcmppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xC2, imm); }
+void vcmpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_YMM, 0xC2, imm); }
+void vcmpsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F2 | T_0F, 0xC2, imm); }
+void vcmpss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0xC2, imm); }
+void vcmptrue_uspd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 31); }
+void vcmptrue_usps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 31); }
+void vcmptrue_ussd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 31); }
+void vcmptrue_usss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 31); }
+void vcmptruepd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 15); }
+void vcmptrueps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 15); }
+void vcmptruesd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 15); }
+void vcmptruess(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 15); }
+void vcmpunord_spd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 19); }
+void vcmpunord_sps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 19); }
+void vcmpunord_ssd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 19); }
+void vcmpunord_sss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 19); }
+void vcmpunordpd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmppd(x1, x2, op, 3); }
+void vcmpunordps(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpps(x1, x2, op, 3); }
+void vcmpunordsd(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpsd(x1, x2, op, 3); }
+void vcmpunordss(const Xmm& x1, const Xmm& x2, const Operand& op) { vcmpss(x1, x2, op, 3); }
+void vcomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2F); }
+void vcomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2F); }
+void vcvtdq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_F3 | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0xE6); }
+void vcvtdq2ps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); }
+void vcvtpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_F2 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0xE6); }
+void vcvtpd2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_66 | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5A); }
+void vcvtph2ps(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F38 | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x13); }
+void vcvtps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5B); }
+void vcvtps2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_0F | T_YMM | T_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x5A); }
+void vcvtps2ph(const Operand& op, const Xmm& x, uint8 imm) { checkCvt1(x, op); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N8 | T_N_VL | T_SAE_Y, 0x1D, imm); }
+void vcvtsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_ER_X, 0x2D); }
+void vcvtsd2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x5A); }
+void vcvtsi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F2 | T_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); }
+void vcvtsi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_0F | T_F3 | T_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x2A); }
+void vcvtss2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x5A); }
+void vcvtss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_ER_X | T_N8, 0x2D); }
+void vcvttpd2dq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_66 | T_0F | T_YMM | T_EVEX | T_EW1 | T_B64 | T_ER_Z, 0xE6); }
+void vcvttps2dq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX | T_SAE_Z | T_B32, 0x5B); }
+void vcvttsd2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W0 | T_EVEX | T_EW0 | T_N4 | T_SAE_X, 0x2C); }
+void vcvttss2si(const Reg32& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W0 | T_EVEX | T_EW0 | T_SAE_X | T_N8, 0x2C); }
+void vdivpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5E); }
+void vdivps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5E); }
+void vdivsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5E); }
+void vdivss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5E); }
+void vdppd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x41, imm); }
+void vdpps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x40, imm); }
+void vextractf128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x19, imm); }
+void vextracti128(const Operand& op, const Ymm& y, uint8 imm) { if (!(op.isXMEM() && y.isYMM())) throw Error(ERR_BAD_COMBINATION); opVex(y, 0, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x39, imm); }
+void vextractps(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_N4, 0x17, imm); }
+void vfmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x98); }
+void vfmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x98); }
+void vfmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x99); }
+void vfmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x99); }
+void vfmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA8); }
+void vfmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA8); }
+void vfmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xA9); }
+void vfmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xA9); }
+void vfmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB8); }
+void vfmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB8); }
+void vfmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xB9); }
+void vfmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xB9); }
+void vfmaddsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x96); }
+void vfmaddsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x96); }
+void vfmaddsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA6); }
+void vfmaddsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA6); }
+void vfmaddsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB6); }
+void vfmaddsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB6); }
+void vfmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9A); }
+void vfmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9A); }
+void vfmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9B); }
+void vfmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9B); }
+void vfmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAA); }
+void vfmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAA); }
+void vfmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAB); }
+void vfmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAB); }
+void vfmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBA); }
+void vfmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBA); }
+void vfmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBB); }
+void vfmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBB); }
+void vfmsubadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x97); }
+void vfmsubadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x97); }
+void vfmsubadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xA7); }
+void vfmsubadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xA7); }
+void vfmsubadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xB7); }
+void vfmsubadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xB7); }
+void vfnmadd132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9C); }
+void vfnmadd132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9C); }
+void vfnmadd132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9D); }
+void vfnmadd132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9D); }
+void vfnmadd213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAC); }
+void vfnmadd213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAC); }
+void vfnmadd213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAD); }
+void vfnmadd213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAD); }
+void vfnmadd231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBC); }
+void vfnmadd231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBC); }
+void vfnmadd231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBD); }
+void vfnmadd231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBD); }
+void vfnmsub132pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x9E); }
+void vfnmsub132ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x9E); }
+void vfnmsub132sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0x9F); }
+void vfnmsub132ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0x9F); }
+void vfnmsub213pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xAE); }
+void vfnmsub213ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xAE); }
+void vfnmsub213sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xAF); }
+void vfnmsub213ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xAF); }
+void vfnmsub231pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0xBE); }
+void vfnmsub231ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0xBE); }
+void vfnmsub231sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_W1 | T_EW1 | T_EVEX | T_ER_X, 0xBF); }
+void vfnmsub231ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_W0 | T_EW0 | T_EVEX | T_ER_X, 0xBF); }
+void vgatherdpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x92, 0); }
+void vgatherdps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x92, 1); }
+void vgatherqpd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x93, 1); }
+void vgatherqps(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x93, 2); }
+void vgf2p8affineinvqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCF, imm); }
+void vgf2p8affineqb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_SAE_Z | T_B64, 0xCE, imm); }
+void vgf2p8mulb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_SAE_Z, 0xCF); }
+void vhaddpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7C); }
+void vhaddps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7C); }
+void vhsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_66 | T_0F | T_YMM, 0x7D); }
+void vhsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_F2 | T_0F | T_YMM, 0x7D); }
+void vinsertf128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x18, imm); }
+void vinserti128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isXMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x38, imm); }
+void vinsertps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_W0 | T_EW0 | T_EVEX, 0x21, imm); }
+void vlddqu(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, cvtIdx0(x), addr, T_0F | T_F2 | T_W0 | T_YMM, 0xF0); }
+void vldmxcsr(const Address& addr) { opAVX_X_X_XM(xm2, xm0, addr, T_0F, 0xAE); }
+void vmaskmovdqu(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_66, 0xF7); }
+void vmaskmovpd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2F); }
+void vmaskmovpd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2D); }
+void vmaskmovps(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2E); }
+void vmaskmovps(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x2C); }
+void vmaxpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5F); }
+void vmaxps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5F); }
+void vmaxsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5F); }
+void vmaxss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5F); }
+void vminpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5D); }
+void vminps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5D); }
+void vminsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5D); }
+void vminss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5D); }
+void vmovapd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x29); }
+void vmovapd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x28); }
+void vmovaps(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x29); }
+void vmovaps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x28); }
+void vmovd(const Operand& op, const Xmm& x) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x7E); }
+void vmovd(const Xmm& x, const Operand& op) { if (!op.isREG(32) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, xm0, op, T_0F | T_66 | T_W0 | T_EVEX | T_N4, 0x6E); }
+void vmovddup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_DUP | T_F2 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_X | T_ER_Y | T_ER_Z, 0x12); }
+void vmovdqa(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_YMM, 0x7F); }
+void vmovdqa(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_YMM, 0x6F); }
+void vmovdqu(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_F3 | T_0F | T_YMM, 0x7F); }
+void vmovdqu(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM, 0x6F); }
+void vmovhlps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x12); }
+void vmovhpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x17); }
+void vmovhpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x16); }
+void vmovhps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x17); }
+void vmovhps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x16); }
+void vmovlhps(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_0F | T_EVEX | T_EW0, 0x16); }
+void vmovlpd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x13); }
+void vmovlpd(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, 0x12); }
+void vmovlps(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_EVEX | T_EW0 | T_N8, 0x13); }
+void vmovlps(const Xmm& x, const Operand& op1, const Operand& op2 = Operand()) { if (!op2.isNone() && !op2.isMEM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x, op1, op2, T_0F | T_EVEX | T_EW0 | T_N8, 0x12); }
+void vmovmskpd(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_66 | T_W0 | T_YMM, 0x50); }
+void vmovmskps(const Reg& r, const Xmm& x) { if (!r.isBit(i32e)) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x.isXMM() ? Xmm(r.getIdx()) : Ymm(r.getIdx()), cvtIdx0(x), x, T_0F | T_W0 | T_YMM, 0x50); }
+void vmovntdq(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW0, 0xE7); }
+void vmovntdqa(const Xmm& x, const Address& addr) { opVex(x, 0, addr, T_0F38 | T_66 | T_YMM | T_EVEX | T_EW0, 0x2A); }
+void vmovntpd(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_66 | T_YMM | T_EVEX | T_EW1, 0x2B); }
+void vmovntps(const Address& addr, const Xmm& x) { opVex(x, 0, addr, T_0F | T_YMM | T_EVEX | T_EW0, 0x2B); }
+void vmovq(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_0F | T_66 | T_EVEX | T_EW1 | T_N8, x.getIdx() < 16 ? 0xD6 : 0x7E); }
+void vmovq(const Xmm& x, const Address& addr) { int type, code; if (x.getIdx() < 16) { type = T_0F | T_F3; code = 0x7E; } else { type = T_0F | T_66 | T_EVEX | T_EW1 | T_N8; code = 0x6E; } opAVX_X_X_XM(x, xm0, addr, type, code); }
+void vmovq(const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x1, xm0, x2, T_0F | T_F3 | T_EVEX | T_EW1 | T_N8, 0x7E); }
+void vmovsd(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_M_K, 0x11); }
+void vmovsd(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); }
+void vmovsd(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX, 0x10); }
+void vmovshdup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x16); }
+void vmovsldup(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_EW0 | T_YMM | T_EVEX, 0x12); }
+void vmovss(const Address& addr, const Xmm& x) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_M_K, 0x11); }
+void vmovss(const Xmm& x, const Address& addr) { opAVX_X_X_XM(x, xm0, addr, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); }
+void vmovss(const Xmm& x1, const Xmm& x2, const Operand& op = Operand()) { if (!op.isNone() && !op.isXMM()) throw Error(ERR_BAD_COMBINATION); opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX, 0x10); }
+void vmovupd(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_M_K, 0x11); }
+void vmovupd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0x10); }
+void vmovups(const Address& addr, const Xmm& xmm) { opAVX_X_XM_IMM(xmm, addr, T_0F | T_EW0 | T_YMM | T_EVEX | T_M_K, 0x11); }
+void vmovups(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX, 0x10); }
+void vmpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x42, imm); }
+void vmulpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x59); }
+void vmulps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x59); }
+void vmulsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x59); }
+void vmulss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x59); }
+void vorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x56); }
+void vorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x56); }
+void vpabsb(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1C); }
+void vpabsd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x1E); }
+void vpabsw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x1D); }
+void vpackssdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6B); }
+void vpacksswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x63); }
+void vpackusdw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x2B); }
+void vpackuswb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x67); }
+void vpaddb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFC); }
+void vpaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFE); }
+void vpaddq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xD4); }
+void vpaddsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEC); }
+void vpaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xED); }
+void vpaddusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDC); }
+void vpaddusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDD); }
+void vpaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xFD); }
+void vpalignr(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_YMM | T_EVEX, 0x0F, imm); }
+void vpand(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDB); }
+void vpandn(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xDF); }
+void vpavgb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE0); }
+void vpavgw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE3); }
+void vpblendd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x02, imm); }
+void vpblendvb(const Xmm& x1, const Xmm& x2, const Operand& op, const Xmm& x4) { opAVX_X_X_XM(x1, x2, op, T_0F3A | T_66 | T_YMM, 0x4C, x4.getIdx() << 4); }
+void vpblendw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM, 0x0E, imm); }
+void vpbroadcastb(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x78); }
+void vpbroadcastd(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x58); }
+void vpbroadcastq(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX, 0x59); }
+void vpbroadcastw(const Xmm& x, const Operand& op) { if (!(op.isXMM() || op.isMEM())) throw Error(ERR_BAD_COMBINATION); opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_W0 | T_YMM | T_EVEX, 0x79); }
+void vpclmulqdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0 | T_YMM | T_EVEX, 0x44, imm); }
+void vpcmpeqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x74); }
+void vpcmpeqd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x76); }
+void vpcmpeqq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x29); }
+void vpcmpeqw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x75); }
+void vpcmpestri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x61, imm); }
+void vpcmpestrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x60, imm); }
+void vpcmpgtb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x64); }
+void vpcmpgtd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x66); }
+void vpcmpgtq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x37); }
+void vpcmpgtw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0x65); }
+void vpcmpistri(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x63, imm); }
+void vpcmpistrm(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A, 0x62, imm); }
+void vperm2f128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x06, imm); }
+void vperm2i128(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { if (!(y1.isYMM() && y2.isYMM() && op.isYMEM())) throw Error(ERR_BAD_COMBINATION); opVex(y1, &y2, op, T_0F3A | T_66 | T_W0 | T_YMM, 0x46, imm); }
+void vpermd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x36); }
+void vpermilpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x0D); }
+void vpermilpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_EVEX | T_B64, 0x05, imm); }
+void vpermilps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x0C); }
+void vpermilps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_EVEX | T_B32, 0x04, imm); }
+void vpermpd(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x01, imm); }
+void vpermpd(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x16); }
+void vpermps(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x16); }
+void vpermq(const Ymm& y, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(y, op, T_66 | T_0F3A | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x00, imm); }
+void vpermq(const Ymm& y1, const Ymm& y2, const Operand& op) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F38 | T_W0 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x36); }
+void vpextrb(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(8|16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x14, imm); }
+void vpextrd(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(32) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x16, imm); }
+void vpextrq(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(64) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); opVex(x, 0, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x16, imm); }
+void vpextrw(const Operand& op, const Xmm& x, uint8 imm) { if (!((op.isREG(16|i32e) || op.isMEM()) && x.isXMM())) throw Error(ERR_BAD_COMBINATION); if (op.isREG() && x.getIdx() < 16) { opAVX_X_X_XM(Xmm(op.getIdx()), xm0, x, T_0F | T_66, 0xC5, imm); } else { opVex(x, 0, op, T_0F3A | T_66 | T_EVEX | T_N2, 0x15, imm); } }
+void vpgatherdd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x90, 1); }
+void vpgatherdq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x90, 0); }
+void vpgatherqd(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W0, 0x91, 2); }
+void vpgatherqq(const Xmm& x1, const Address& addr, const Xmm& x2) { opGather(x1, addr, x2, T_0F38 | T_66 | T_YMM | T_VSIB | T_W1, 0x91, 1); }
+void vphaddd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x02); }
+void vphaddsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x03); }
+void vphaddw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x01); }
+void vphminposuw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38, 0x41); }
+void vphsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x06); }
+void vphsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x07); }
+void vphsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x05); }
+void vpinsrb(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_EVEX | T_N1, 0x20, imm); }
+void vpinsrd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W0 | T_EVEX | T_EW0 | T_N4, 0x22, imm); }
+void vpinsrq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(64) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F3A | T_66 | T_W1 | T_EVEX | T_EW1 | T_N8, 0x22, imm); }
+void vpinsrw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { if (!(x1.isXMM() && x2.isXMM() && (op.isREG(32) || op.isMEM()))) throw Error(ERR_BAD_COMBINATION); opVex(x1, &x2, op, T_0F | T_66 | T_EVEX | T_N2, 0xC4, imm); }
+void vpmaddubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x04); }
+void vpmaddwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF5); }
+void vpmaskmovd(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8E); }
+void vpmaskmovd(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W0 | T_YMM, 0x8C); }
+void vpmaskmovq(const Address& addr, const Xmm& x1, const Xmm& x2) { opAVX_X_X_XM(x2, x1, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8E); }
+void vpmaskmovq(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_66 | T_W1 | T_YMM, 0x8C); }
+void vpmaxsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3C); }
+void vpmaxsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3D); }
+void vpmaxsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEE); }
+void vpmaxub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDE); }
+void vpmaxud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3F); }
+void vpmaxuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3E); }
+void vpminsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x38); }
+void vpminsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x39); }
+void vpminsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xEA); }
+void vpminub(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xDA); }
+void vpminud(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x3B); }
+void vpminuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x3A); }
+void vpmovmskb(const Reg32e& r, const Xmm& x) { if (!x.is(Operand::XMM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(x.isYMM() ? Ymm(r.getIdx()) : Xmm(r.getIdx()), 0, x, T_0F | T_66 | T_YMM, 0xD7); }
+void vpmovsxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x21); }
+void vpmovsxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x22); }
+void vpmovsxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x20); }
+void vpmovsxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x25); }
+void vpmovsxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x23); }
+void vpmovsxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x24); }
+void vpmovzxbd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x31); }
+void vpmovzxbq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N2 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x32); }
+void vpmovzxbw(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x30); }
+void vpmovzxdq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX, 0x35); }
+void vpmovzxwd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x33); }
+void vpmovzxwq(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_N_VL | T_66 | T_0F38 | T_YMM | T_EVEX, 0x34); }
+void vpmuldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x28); }
+void vpmulhrsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x0B); }
+void vpmulhuw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE4); }
+void vpmulhw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE5); }
+void vpmulld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x40); }
+void vpmullw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD5); }
+void vpmuludq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xF4); }
+void vpor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEB); }
+void vpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF6); }
+void vpshufb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM | T_EVEX, 0x00); }
+void vpshufd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x70, imm); }
+void vpshufhw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F3 | T_0F | T_YMM | T_EVEX, 0x70, imm); }
+void vpshuflw(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_F2 | T_0F | T_YMM | T_EVEX, 0x70, imm); }
+void vpsignb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x08); }
+void vpsignd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x0A); }
+void vpsignw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_YMM, 0x09); }
+void vpslld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); }
+void vpslld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xF2); }
+void vpslldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 7), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); }
+void vpsllq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); }
+void vpsllq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xF3); }
+void vpsllvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x47); }
+void vpsllvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x47); }
+void vpsllw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 6), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); }
+void vpsllw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xF1); }
+void vpsrad(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); }
+void vpsrad(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xE2); }
+void vpsravd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x46); }
+void vpsraw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); }
+void vpsraw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xE1); }
+void vpsrld(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32 | T_MEM_EVEX, 0x72, imm); }
+void vpsrld(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW0 | T_YMM | T_EVEX, 0xD2); }
+void vpsrldq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 3), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x73, imm); }
+void vpsrlq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64 | T_MEM_EVEX, 0x73, imm); }
+void vpsrlq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_EVEX, 0xD3); }
+void vpsrlvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W0 | T_EW0 | T_YMM | T_EVEX | T_B32, 0x45); }
+void vpsrlvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_W1 | T_EW1 | T_YMM | T_EVEX | T_B64, 0x45); }
+void vpsrlw(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 2), x, op, T_66 | T_0F | T_YMM | T_EVEX | T_MEM_EVEX, 0x71, imm); }
+void vpsrlw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_YMM | T_EVEX, 0xD1); }
+void vpsubb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF8); }
+void vpsubd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xFA); }
+void vpsubq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xFB); }
+void vpsubsb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE8); }
+void vpsubsw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xE9); }
+void vpsubusb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD8); }
+void vpsubusw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xD9); }
+void vpsubw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0xF9); }
+void vptest(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x17); }
+void vpunpckhbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x68); }
+void vpunpckhdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x6A); }
+void vpunpckhqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6D); }
+void vpunpckhwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x69); }
+void vpunpcklbw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x60); }
+void vpunpckldq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x62); }
+void vpunpcklqdq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x6C); }
+void vpunpcklwd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM | T_EVEX, 0x61); }
+void vpxor(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_YMM, 0xEF); }
+void vrcpps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x53); }
+void vrcpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x53); }
+void vroundpd(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x09, imm); }
+void vroundps(const Xmm& xm, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F3A | T_YMM, 0x08, imm); }
+void vroundsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0B, imm); }
+void vroundss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_W0, 0x0A, imm); }
+void vrsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_YMM, 0x52); }
+void vrsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_F3 | T_0F, 0x52); }
+void vshufpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0xC6, imm); }
+void vshufps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0xC6, imm); }
+void vsqrtpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x51); }
+void vsqrtps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x51); }
+void vsqrtsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_F2 | T_0F | T_EW1 | T_EVEX | T_ER_X, 0x51); }
+void vsqrtss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_F3 | T_0F | T_EW0 | T_EVEX | T_ER_X, 0x51); }
+void vstmxcsr(const Address& addr) { opAVX_X_X_XM(xm3, xm0, addr, T_0F, 0xAE); }
+void vsubpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x5C); }
+void vsubps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x5C); }
+void vsubsd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F2 | T_EW1 | T_EVEX | T_ER_Z | T_N8, 0x5C); }
+void vsubss(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_F3 | T_EW0 | T_EVEX | T_ER_Z | T_N4, 0x5C); }
+void vtestpd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0F); }
+void vtestps(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_66 | T_0F38 | T_YMM, 0x0E); }
+void vucomisd(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N8 | T_66 | T_0F | T_EW1 | T_EVEX | T_SAE_X, 0x2E); }
+void vucomiss(const Xmm& xm, const Operand& op) { opAVX_X_XM_IMM(xm, op, T_N4 | T_0F | T_EW0 | T_EVEX | T_SAE_X, 0x2E); }
+void vunpckhpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x15); }
+void vunpckhps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x15); }
+void vunpcklpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_EVEX | T_B64, 0x14); }
+void vunpcklps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_0F | T_EW0 | T_YMM | T_EVEX | T_B32, 0x14); }
+void vxorpd(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_66 | T_EW1 | T_YMM | T_EVEX | T_ER_Z | T_B64, 0x57); }
+void vxorps(const Xmm& xmm, const Operand& op1, const Operand& op2 = Operand()) { opAVX_X_X_XM(xmm, op1, op2, T_0F | T_EW0 | T_YMM | T_EVEX | T_ER_Z | T_B32, 0x57); }
+void vzeroall() { db(0xC5); db(0xFC); db(0x77); }
+void vzeroupper() { db(0xC5); db(0xF8); db(0x77); }
+void wait() { db(0x9B); }
+void wbinvd() { db(0x0F); db(0x09); }
+void wrmsr() { db(0x0F); db(0x30); }
+void xadd(const Operand& op, const Reg& reg) { opModRM(reg, op, (op.isREG() && reg.isREG() && op.getBit() == reg.getBit()), op.isMEM(), 0x0F, 0xC0 | (reg.isBit(8) ? 0 : 1)); }
+void xgetbv() { db(0x0F); db(0x01); db(0xD0); }
+void xlatb() { db(0xD7); }
+void xor_(const Operand& op, uint32 imm) { opRM_I(op, imm, 0x30, 6); }
+void xor_(const Operand& op1, const Operand& op2) { opRM_RM(op1, op2, 0x30); }
+void xorpd(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x66, isXMM_XMMorMEM); }
+void xorps(const Xmm& xmm, const Operand& op) { opGen(xmm, op, 0x57, 0x100, isXMM_XMMorMEM); }
+#ifdef XBYAK_ENABLE_OMITTED_OPERAND
+void vblendpd(const Xmm& x, const Operand& op, uint8 imm) { vblendpd(x, x, op, imm); }
+void vblendps(const Xmm& x, const Operand& op, uint8 imm) { vblendps(x, x, op, imm); }
+void vblendvpd(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvpd(x1, x1, op, x4); }
+void vblendvps(const Xmm& x1, const Operand& op, const Xmm& x4) { vblendvps(x1, x1, op, x4); }
+void vcmpeq_ospd(const Xmm& x, const Operand& op) { vcmpeq_ospd(x, x, op); }
+void vcmpeq_osps(const Xmm& x, const Operand& op) { vcmpeq_osps(x, x, op); }
+void vcmpeq_ossd(const Xmm& x, const Operand& op) { vcmpeq_ossd(x, x, op); }
+void vcmpeq_osss(const Xmm& x, const Operand& op) { vcmpeq_osss(x, x, op); }
+void vcmpeq_uqpd(const Xmm& x, const Operand& op) { vcmpeq_uqpd(x, x, op); }
+void vcmpeq_uqps(const Xmm& x, const Operand& op) { vcmpeq_uqps(x, x, op); }
+void vcmpeq_uqsd(const Xmm& x, const Operand& op) { vcmpeq_uqsd(x, x, op); }
+void vcmpeq_uqss(const Xmm& x, const Operand& op) { vcmpeq_uqss(x, x, op); }
+void vcmpeq_uspd(const Xmm& x, const Operand& op) { vcmpeq_uspd(x, x, op); }
+void vcmpeq_usps(const Xmm& x, const Operand& op) { vcmpeq_usps(x, x, op); }
+void vcmpeq_ussd(const Xmm& x, const Operand& op) { vcmpeq_ussd(x, x, op); }
+void vcmpeq_usss(const Xmm& x, const Operand& op) { vcmpeq_usss(x, x, op); }
+void vcmpeqpd(const Xmm& x, const Operand& op) { vcmpeqpd(x, x, op); }
+void vcmpeqps(const Xmm& x, const Operand& op) { vcmpeqps(x, x, op); }
+void vcmpeqsd(const Xmm& x, const Operand& op) { vcmpeqsd(x, x, op); }
+void vcmpeqss(const Xmm& x, const Operand& op) { vcmpeqss(x, x, op); }
+void vcmpfalse_ospd(const Xmm& x, const Operand& op) { vcmpfalse_ospd(x, x, op); }
+void vcmpfalse_osps(const Xmm& x, const Operand& op) { vcmpfalse_osps(x, x, op); }
+void vcmpfalse_ossd(const Xmm& x, const Operand& op) { vcmpfalse_ossd(x, x, op); }
+void vcmpfalse_osss(const Xmm& x, const Operand& op) { vcmpfalse_osss(x, x, op); }
+void vcmpfalsepd(const Xmm& x, const Operand& op) { vcmpfalsepd(x, x, op); }
+void vcmpfalseps(const Xmm& x, const Operand& op) { vcmpfalseps(x, x, op); }
+void vcmpfalsesd(const Xmm& x, const Operand& op) { vcmpfalsesd(x, x, op); }
+void vcmpfalsess(const Xmm& x, const Operand& op) { vcmpfalsess(x, x, op); }
+void vcmpge_oqpd(const Xmm& x, const Operand& op) { vcmpge_oqpd(x, x, op); }
+void vcmpge_oqps(const Xmm& x, const Operand& op) { vcmpge_oqps(x, x, op); }
+void vcmpge_oqsd(const Xmm& x, const Operand& op) { vcmpge_oqsd(x, x, op); }
+void vcmpge_oqss(const Xmm& x, const Operand& op) { vcmpge_oqss(x, x, op); }
+void vcmpgepd(const Xmm& x, const Operand& op) { vcmpgepd(x, x, op); }
+void vcmpgeps(const Xmm& x, const Operand& op) { vcmpgeps(x, x, op); }
+void vcmpgesd(const Xmm& x, const Operand& op) { vcmpgesd(x, x, op); }
+void vcmpgess(const Xmm& x, const Operand& op) { vcmpgess(x, x, op); }
+void vcmpgt_oqpd(const Xmm& x, const Operand& op) { vcmpgt_oqpd(x, x, op); }
+void vcmpgt_oqps(const Xmm& x, const Operand& op) { vcmpgt_oqps(x, x, op); }
+void vcmpgt_oqsd(const Xmm& x, const Operand& op) { vcmpgt_oqsd(x, x, op); }
+void vcmpgt_oqss(const Xmm& x, const Operand& op) { vcmpgt_oqss(x, x, op); }
+void vcmpgtpd(const Xmm& x, const Operand& op) { vcmpgtpd(x, x, op); }
+void vcmpgtps(const Xmm& x, const Operand& op) { vcmpgtps(x, x, op); }
+void vcmpgtsd(const Xmm& x, const Operand& op) { vcmpgtsd(x, x, op); }
+void vcmpgtss(const Xmm& x, const Operand& op) { vcmpgtss(x, x, op); }
+void vcmple_oqpd(const Xmm& x, const Operand& op) { vcmple_oqpd(x, x, op); }
+void vcmple_oqps(const Xmm& x, const Operand& op) { vcmple_oqps(x, x, op); }
+void vcmple_oqsd(const Xmm& x, const Operand& op) { vcmple_oqsd(x, x, op); }
+void vcmple_oqss(const Xmm& x, const Operand& op) { vcmple_oqss(x, x, op); }
+void vcmplepd(const Xmm& x, const Operand& op) { vcmplepd(x, x, op); }
+void vcmpleps(const Xmm& x, const Operand& op) { vcmpleps(x, x, op); }
+void vcmplesd(const Xmm& x, const Operand& op) { vcmplesd(x, x, op); }
+void vcmpless(const Xmm& x, const Operand& op) { vcmpless(x, x, op); }
+void vcmplt_oqpd(const Xmm& x, const Operand& op) { vcmplt_oqpd(x, x, op); }
+void vcmplt_oqps(const Xmm& x, const Operand& op) { vcmplt_oqps(x, x, op); }
+void vcmplt_oqsd(const Xmm& x, const Operand& op) { vcmplt_oqsd(x, x, op); }
+void vcmplt_oqss(const Xmm& x, const Operand& op) { vcmplt_oqss(x, x, op); }
+void vcmpltpd(const Xmm& x, const Operand& op) { vcmpltpd(x, x, op); }
+void vcmpltps(const Xmm& x, const Operand& op) { vcmpltps(x, x, op); }
+void vcmpltsd(const Xmm& x, const Operand& op) { vcmpltsd(x, x, op); }
+void vcmpltss(const Xmm& x, const Operand& op) { vcmpltss(x, x, op); }
+void vcmpneq_oqpd(const Xmm& x, const Operand& op) { vcmpneq_oqpd(x, x, op); }
+void vcmpneq_oqps(const Xmm& x, const Operand& op) { vcmpneq_oqps(x, x, op); }
+void vcmpneq_oqsd(const Xmm& x, const Operand& op) { vcmpneq_oqsd(x, x, op); }
+void vcmpneq_oqss(const Xmm& x, const Operand& op) { vcmpneq_oqss(x, x, op); }
+void vcmpneq_ospd(const Xmm& x, const Operand& op) { vcmpneq_ospd(x, x, op); }
+void vcmpneq_osps(const Xmm& x, const Operand& op) { vcmpneq_osps(x, x, op); }
+void vcmpneq_ossd(const Xmm& x, const Operand& op) { vcmpneq_ossd(x, x, op); }
+void vcmpneq_osss(const Xmm& x, const Operand& op) { vcmpneq_osss(x, x, op); }
+void vcmpneq_uspd(const Xmm& x, const Operand& op) { vcmpneq_uspd(x, x, op); }
+void vcmpneq_usps(const Xmm& x, const Operand& op) { vcmpneq_usps(x, x, op); }
+void vcmpneq_ussd(const Xmm& x, const Operand& op) { vcmpneq_ussd(x, x, op); }
+void vcmpneq_usss(const Xmm& x, const Operand& op) { vcmpneq_usss(x, x, op); }
+void vcmpneqpd(const Xmm& x, const Operand& op) { vcmpneqpd(x, x, op); }
+void vcmpneqps(const Xmm& x, const Operand& op) { vcmpneqps(x, x, op); }
+void vcmpneqsd(const Xmm& x, const Operand& op) { vcmpneqsd(x, x, op); }
+void vcmpneqss(const Xmm& x, const Operand& op) { vcmpneqss(x, x, op); }
+void vcmpnge_uqpd(const Xmm& x, const Operand& op) { vcmpnge_uqpd(x, x, op); }
+void vcmpnge_uqps(const Xmm& x, const Operand& op) { vcmpnge_uqps(x, x, op); }
+void vcmpnge_uqsd(const Xmm& x, const Operand& op) { vcmpnge_uqsd(x, x, op); }
+void vcmpnge_uqss(const Xmm& x, const Operand& op) { vcmpnge_uqss(x, x, op); }
+void vcmpngepd(const Xmm& x, const Operand& op) { vcmpngepd(x, x, op); }
+void vcmpngeps(const Xmm& x, const Operand& op) { vcmpngeps(x, x, op); }
+void vcmpngesd(const Xmm& x, const Operand& op) { vcmpngesd(x, x, op); }
+void vcmpngess(const Xmm& x, const Operand& op) { vcmpngess(x, x, op); }
+void vcmpngt_uqpd(const Xmm& x, const Operand& op) { vcmpngt_uqpd(x, x, op); }
+void vcmpngt_uqps(const Xmm& x, const Operand& op) { vcmpngt_uqps(x, x, op); }
+void vcmpngt_uqsd(const Xmm& x, const Operand& op) { vcmpngt_uqsd(x, x, op); }
+void vcmpngt_uqss(const Xmm& x, const Operand& op) { vcmpngt_uqss(x, x, op); }
+void vcmpngtpd(const Xmm& x, const Operand& op) { vcmpngtpd(x, x, op); }
+void vcmpngtps(const Xmm& x, const Operand& op) { vcmpngtps(x, x, op); }
+void vcmpngtsd(const Xmm& x, const Operand& op) { vcmpngtsd(x, x, op); }
+void vcmpngtss(const Xmm& x, const Operand& op) { vcmpngtss(x, x, op); }
+void vcmpnle_uqpd(const Xmm& x, const Operand& op) { vcmpnle_uqpd(x, x, op); }
+void vcmpnle_uqps(const Xmm& x, const Operand& op) { vcmpnle_uqps(x, x, op); }
+void vcmpnle_uqsd(const Xmm& x, const Operand& op) { vcmpnle_uqsd(x, x, op); }
+void vcmpnle_uqss(const Xmm& x, const Operand& op) { vcmpnle_uqss(x, x, op); }
+void vcmpnlepd(const Xmm& x, const Operand& op) { vcmpnlepd(x, x, op); }
+void vcmpnleps(const Xmm& x, const Operand& op) { vcmpnleps(x, x, op); }
+void vcmpnlesd(const Xmm& x, const Operand& op) { vcmpnlesd(x, x, op); }
+void vcmpnless(const Xmm& x, const Operand& op) { vcmpnless(x, x, op); }
+void vcmpnlt_uqpd(const Xmm& x, const Operand& op) { vcmpnlt_uqpd(x, x, op); }
+void vcmpnlt_uqps(const Xmm& x, const Operand& op) { vcmpnlt_uqps(x, x, op); }
+void vcmpnlt_uqsd(const Xmm& x, const Operand& op) { vcmpnlt_uqsd(x, x, op); }
+void vcmpnlt_uqss(const Xmm& x, const Operand& op) { vcmpnlt_uqss(x, x, op); }
+void vcmpnltpd(const Xmm& x, const Operand& op) { vcmpnltpd(x, x, op); }
+void vcmpnltps(const Xmm& x, const Operand& op) { vcmpnltps(x, x, op); }
+void vcmpnltsd(const Xmm& x, const Operand& op) { vcmpnltsd(x, x, op); }
+void vcmpnltss(const Xmm& x, const Operand& op) { vcmpnltss(x, x, op); }
+void vcmpord_spd(const Xmm& x, const Operand& op) { vcmpord_spd(x, x, op); }
+void vcmpord_sps(const Xmm& x, const Operand& op) { vcmpord_sps(x, x, op); }
+void vcmpord_ssd(const Xmm& x, const Operand& op) { vcmpord_ssd(x, x, op); }
+void vcmpord_sss(const Xmm& x, const Operand& op) { vcmpord_sss(x, x, op); }
+void vcmpordpd(const Xmm& x, const Operand& op) { vcmpordpd(x, x, op); }
+void vcmpordps(const Xmm& x, const Operand& op) { vcmpordps(x, x, op); }
+void vcmpordsd(const Xmm& x, const Operand& op) { vcmpordsd(x, x, op); }
+void vcmpordss(const Xmm& x, const Operand& op) { vcmpordss(x, x, op); }
+void vcmppd(const Xmm& x, const Operand& op, uint8 imm) { vcmppd(x, x, op, imm); }
+void vcmpps(const Xmm& x, const Operand& op, uint8 imm) { vcmpps(x, x, op, imm); }
+void vcmpsd(const Xmm& x, const Operand& op, uint8 imm) { vcmpsd(x, x, op, imm); }
+void vcmpss(const Xmm& x, const Operand& op, uint8 imm) { vcmpss(x, x, op, imm); }
+void vcmptrue_uspd(const Xmm& x, const Operand& op) { vcmptrue_uspd(x, x, op); }
+void vcmptrue_usps(const Xmm& x, const Operand& op) { vcmptrue_usps(x, x, op); }
+void vcmptrue_ussd(const Xmm& x, const Operand& op) { vcmptrue_ussd(x, x, op); }
+void vcmptrue_usss(const Xmm& x, const Operand& op) { vcmptrue_usss(x, x, op); }
+void vcmptruepd(const Xmm& x, const Operand& op) { vcmptruepd(x, x, op); }
+void vcmptrueps(const Xmm& x, const Operand& op) { vcmptrueps(x, x, op); }
+void vcmptruesd(const Xmm& x, const Operand& op) { vcmptruesd(x, x, op); }
+void vcmptruess(const Xmm& x, const Operand& op) { vcmptruess(x, x, op); }
+void vcmpunord_spd(const Xmm& x, const Operand& op) { vcmpunord_spd(x, x, op); }
+void vcmpunord_sps(const Xmm& x, const Operand& op) { vcmpunord_sps(x, x, op); }
+void vcmpunord_ssd(const Xmm& x, const Operand& op) { vcmpunord_ssd(x, x, op); }
+void vcmpunord_sss(const Xmm& x, const Operand& op) { vcmpunord_sss(x, x, op); }
+void vcmpunordpd(const Xmm& x, const Operand& op) { vcmpunordpd(x, x, op); }
+void vcmpunordps(const Xmm& x, const Operand& op) { vcmpunordps(x, x, op); }
+void vcmpunordsd(const Xmm& x, const Operand& op) { vcmpunordsd(x, x, op); }
+void vcmpunordss(const Xmm& x, const Operand& op) { vcmpunordss(x, x, op); }
+void vcvtsd2ss(const Xmm& x, const Operand& op) { vcvtsd2ss(x, x, op); }
+void vcvtsi2sd(const Xmm& x, const Operand& op) { vcvtsi2sd(x, x, op); }
+void vcvtsi2ss(const Xmm& x, const Operand& op) { vcvtsi2ss(x, x, op); }
+void vcvtss2sd(const Xmm& x, const Operand& op) { vcvtss2sd(x, x, op); }
+void vdppd(const Xmm& x, const Operand& op, uint8 imm) { vdppd(x, x, op, imm); }
+void vdpps(const Xmm& x, const Operand& op, uint8 imm) { vdpps(x, x, op, imm); }
+void vinsertps(const Xmm& x, const Operand& op, uint8 imm) { vinsertps(x, x, op, imm); }
+void vmpsadbw(const Xmm& x, const Operand& op, uint8 imm) { vmpsadbw(x, x, op, imm); }
+void vpackssdw(const Xmm& x, const Operand& op) { vpackssdw(x, x, op); }
+void vpacksswb(const Xmm& x, const Operand& op) { vpacksswb(x, x, op); }
+void vpackusdw(const Xmm& x, const Operand& op) { vpackusdw(x, x, op); }
+void vpackuswb(const Xmm& x, const Operand& op) { vpackuswb(x, x, op); }
+void vpaddb(const Xmm& x, const Operand& op) { vpaddb(x, x, op); }
+void vpaddd(const Xmm& x, const Operand& op) { vpaddd(x, x, op); }
+void vpaddq(const Xmm& x, const Operand& op) { vpaddq(x, x, op); }
+void vpaddsb(const Xmm& x, const Operand& op) { vpaddsb(x, x, op); }
+void vpaddsw(const Xmm& x, const Operand& op) { vpaddsw(x, x, op); }
+void vpaddusb(const Xmm& x, const Operand& op) { vpaddusb(x, x, op); }
+void vpaddusw(const Xmm& x, const Operand& op) { vpaddusw(x, x, op); }
+void vpaddw(const Xmm& x, const Operand& op) { vpaddw(x, x, op); }
+void vpalignr(const Xmm& x, const Operand& op, uint8 imm) { vpalignr(x, x, op, imm); }
+void vpand(const Xmm& x, const Operand& op) { vpand(x, x, op); }
+void vpandn(const Xmm& x, const Operand& op) { vpandn(x, x, op); }
+void vpavgb(const Xmm& x, const Operand& op) { vpavgb(x, x, op); }
+void vpavgw(const Xmm& x, const Operand& op) { vpavgw(x, x, op); }
+void vpblendd(const Xmm& x, const Operand& op, uint8 imm) { vpblendd(x, x, op, imm); }
+void vpblendvb(const Xmm& x1, const Operand& op, const Xmm& x4) { vpblendvb(x1, x1, op, x4); }
+void vpblendw(const Xmm& x, const Operand& op, uint8 imm) { vpblendw(x, x, op, imm); }
+void vpclmulqdq(const Xmm& x, const Operand& op, uint8 imm) { vpclmulqdq(x, x, op, imm); }
+void vpcmpeqb(const Xmm& x, const Operand& op) { vpcmpeqb(x, x, op); }
+void vpcmpeqd(const Xmm& x, const Operand& op) { vpcmpeqd(x, x, op); }
+void vpcmpeqq(const Xmm& x, const Operand& op) { vpcmpeqq(x, x, op); }
+void vpcmpeqw(const Xmm& x, const Operand& op) { vpcmpeqw(x, x, op); }
+void vpcmpgtb(const Xmm& x, const Operand& op) { vpcmpgtb(x, x, op); }
+void vpcmpgtd(const Xmm& x, const Operand& op) { vpcmpgtd(x, x, op); }
+void vpcmpgtq(const Xmm& x, const Operand& op) { vpcmpgtq(x, x, op); }
+void vpcmpgtw(const Xmm& x, const Operand& op) { vpcmpgtw(x, x, op); }
+void vphaddd(const Xmm& x, const Operand& op) { vphaddd(x, x, op); }
+void vphaddsw(const Xmm& x, const Operand& op) { vphaddsw(x, x, op); }
+void vphaddw(const Xmm& x, const Operand& op) { vphaddw(x, x, op); }
+void vphsubd(const Xmm& x, const Operand& op) { vphsubd(x, x, op); }
+void vphsubsw(const Xmm& x, const Operand& op) { vphsubsw(x, x, op); }
+void vphsubw(const Xmm& x, const Operand& op) { vphsubw(x, x, op); }
+void vpinsrb(const Xmm& x, const Operand& op, uint8 imm) { vpinsrb(x, x, op, imm); }
+void vpinsrd(const Xmm& x, const Operand& op, uint8 imm) { vpinsrd(x, x, op, imm); }
+void vpinsrq(const Xmm& x, const Operand& op, uint8 imm) { vpinsrq(x, x, op, imm); }
+void vpinsrw(const Xmm& x, const Operand& op, uint8 imm) { vpinsrw(x, x, op, imm); }
+void vpmaddubsw(const Xmm& x, const Operand& op) { vpmaddubsw(x, x, op); }
+void vpmaddwd(const Xmm& x, const Operand& op) { vpmaddwd(x, x, op); }
+void vpmaxsb(const Xmm& x, const Operand& op) { vpmaxsb(x, x, op); }
+void vpmaxsd(const Xmm& x, const Operand& op) { vpmaxsd(x, x, op); }
+void vpmaxsw(const Xmm& x, const Operand& op) { vpmaxsw(x, x, op); }
+void vpmaxub(const Xmm& x, const Operand& op) { vpmaxub(x, x, op); }
+void vpmaxud(const Xmm& x, const Operand& op) { vpmaxud(x, x, op); }
+void vpmaxuw(const Xmm& x, const Operand& op) { vpmaxuw(x, x, op); }
+void vpminsb(const Xmm& x, const Operand& op) { vpminsb(x, x, op); }
+void vpminsd(const Xmm& x, const Operand& op) { vpminsd(x, x, op); }
+void vpminsw(const Xmm& x, const Operand& op) { vpminsw(x, x, op); }
+void vpminub(const Xmm& x, const Operand& op) { vpminub(x, x, op); }
+void vpminud(const Xmm& x, const Operand& op) { vpminud(x, x, op); }
+void vpminuw(const Xmm& x, const Operand& op) { vpminuw(x, x, op); }
+void vpmuldq(const Xmm& x, const Operand& op) { vpmuldq(x, x, op); }
+void vpmulhrsw(const Xmm& x, const Operand& op) { vpmulhrsw(x, x, op); }
+void vpmulhuw(const Xmm& x, const Operand& op) { vpmulhuw(x, x, op); }
+void vpmulhw(const Xmm& x, const Operand& op) { vpmulhw(x, x, op); }
+void vpmulld(const Xmm& x, const Operand& op) { vpmulld(x, x, op); }
+void vpmullw(const Xmm& x, const Operand& op) { vpmullw(x, x, op); }
+void vpmuludq(const Xmm& x, const Operand& op) { vpmuludq(x, x, op); }
+void vpor(const Xmm& x, const Operand& op) { vpor(x, x, op); }
+void vpsadbw(const Xmm& x, const Operand& op) { vpsadbw(x, x, op); }
+void vpsignb(const Xmm& x, const Operand& op) { vpsignb(x, x, op); }
+void vpsignd(const Xmm& x, const Operand& op) { vpsignd(x, x, op); }
+void vpsignw(const Xmm& x, const Operand& op) { vpsignw(x, x, op); }
+void vpslld(const Xmm& x, const Operand& op) { vpslld(x, x, op); }
+void vpslld(const Xmm& x, uint8 imm) { vpslld(x, x, imm); }
+void vpslldq(const Xmm& x, uint8 imm) { vpslldq(x, x, imm); }
+void vpsllq(const Xmm& x, const Operand& op) { vpsllq(x, x, op); }
+void vpsllq(const Xmm& x, uint8 imm) { vpsllq(x, x, imm); }
+void vpsllw(const Xmm& x, const Operand& op) { vpsllw(x, x, op); }
+void vpsllw(const Xmm& x, uint8 imm) { vpsllw(x, x, imm); }
+void vpsrad(const Xmm& x, const Operand& op) { vpsrad(x, x, op); }
+void vpsrad(const Xmm& x, uint8 imm) { vpsrad(x, x, imm); }
+void vpsraw(const Xmm& x, const Operand& op) { vpsraw(x, x, op); }
+void vpsraw(const Xmm& x, uint8 imm) { vpsraw(x, x, imm); }
+void vpsrld(const Xmm& x, const Operand& op) { vpsrld(x, x, op); }
+void vpsrld(const Xmm& x, uint8 imm) { vpsrld(x, x, imm); }
+void vpsrldq(const Xmm& x, uint8 imm) { vpsrldq(x, x, imm); }
+void vpsrlq(const Xmm& x, const Operand& op) { vpsrlq(x, x, op); }
+void vpsrlq(const Xmm& x, uint8 imm) { vpsrlq(x, x, imm); }
+void vpsrlw(const Xmm& x, const Operand& op) { vpsrlw(x, x, op); }
+void vpsrlw(const Xmm& x, uint8 imm) { vpsrlw(x, x, imm); }
+void vpsubb(const Xmm& x, const Operand& op) { vpsubb(x, x, op); }
+void vpsubd(const Xmm& x, const Operand& op) { vpsubd(x, x, op); }
+void vpsubq(const Xmm& x, const Operand& op) { vpsubq(x, x, op); }
+void vpsubsb(const Xmm& x, const Operand& op) { vpsubsb(x, x, op); }
+void vpsubsw(const Xmm& x, const Operand& op) { vpsubsw(x, x, op); }
+void vpsubusb(const Xmm& x, const Operand& op) { vpsubusb(x, x, op); }
+void vpsubusw(const Xmm& x, const Operand& op) { vpsubusw(x, x, op); }
+void vpsubw(const Xmm& x, const Operand& op) { vpsubw(x, x, op); }
+void vpunpckhbw(const Xmm& x, const Operand& op) { vpunpckhbw(x, x, op); }
+void vpunpckhdq(const Xmm& x, const Operand& op) { vpunpckhdq(x, x, op); }
+void vpunpckhqdq(const Xmm& x, const Operand& op) { vpunpckhqdq(x, x, op); }
+void vpunpckhwd(const Xmm& x, const Operand& op) { vpunpckhwd(x, x, op); }
+void vpunpcklbw(const Xmm& x, const Operand& op) { vpunpcklbw(x, x, op); }
+void vpunpckldq(const Xmm& x, const Operand& op) { vpunpckldq(x, x, op); }
+void vpunpcklqdq(const Xmm& x, const Operand& op) { vpunpcklqdq(x, x, op); }
+void vpunpcklwd(const Xmm& x, const Operand& op) { vpunpcklwd(x, x, op); }
+void vpxor(const Xmm& x, const Operand& op) { vpxor(x, x, op); }
+void vrcpss(const Xmm& x, const Operand& op) { vrcpss(x, x, op); }
+void vroundsd(const Xmm& x, const Operand& op, uint8 imm) { vroundsd(x, x, op, imm); }
+void vroundss(const Xmm& x, const Operand& op, uint8 imm) { vroundss(x, x, op, imm); }
+void vrsqrtss(const Xmm& x, const Operand& op) { vrsqrtss(x, x, op); }
+void vshufpd(const Xmm& x, const Operand& op, uint8 imm) { vshufpd(x, x, op, imm); }
+void vshufps(const Xmm& x, const Operand& op, uint8 imm) { vshufps(x, x, op, imm); }
+void vsqrtsd(const Xmm& x, const Operand& op) { vsqrtsd(x, x, op); }
+void vsqrtss(const Xmm& x, const Operand& op) { vsqrtss(x, x, op); }
+void vunpckhpd(const Xmm& x, const Operand& op) { vunpckhpd(x, x, op); }
+void vunpckhps(const Xmm& x, const Operand& op) { vunpckhps(x, x, op); }
+void vunpcklpd(const Xmm& x, const Operand& op) { vunpcklpd(x, x, op); }
+void vunpcklps(const Xmm& x, const Operand& op) { vunpcklps(x, x, op); }
+#endif
+#ifdef XBYAK64
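+// Mnemonics only encodable in 64-bit mode: RCX/ECX short jumps, cdqe/cqo, the REX.W
+// string ops, cmpxchg16b, movsxd, and the 64-bit GPR <-> MMX/XMM moves and conversions.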
+void jecxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); }
+void jecxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); }
+void jrcxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); }
+void jrcxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); }
+void cdqe() { db(0x48); db(0x98); }
+void cqo() { db(0x48); db(0x99); }
+void cmpsq() { db(0x48); db(0xA7); }
+void movsq() { db(0x48); db(0xA5); }
+void scasq() { db(0x48); db(0xAF); }
+void stosq() { db(0x48); db(0xAB); }
+void cmpxchg16b(const Address& addr) { opModM(addr, Reg64(1), 0x0F, 0xC7); }
+void movq(const Reg64& reg, const Mmx& mmx) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x7E); }
+void movq(const Mmx& mmx, const Reg64& reg) { if (mmx.isXMM()) db(0x66); opModR(mmx, reg, 0x0F, 0x6E); }
+void movsxd(const Reg64& reg, const Operand& op) { if (!op.isBit(32)) throw Error(ERR_BAD_COMBINATION); opModRM(reg, op, op.isREG(), op.isMEM(), 0x63); }
+void pextrq(const Operand& op, const Xmm& xmm, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x16, 0x66, 0, imm, 0x3A); }
+void pinsrq(const Xmm& xmm, const Operand& op, uint8 imm) { if (!op.isREG(64) && !op.isMEM()) throw Error(ERR_BAD_COMBINATION); opGen(Reg64(xmm.getIdx()), op, 0x22, 0x66, 0, imm, 0x3A); }
+void vcvtss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_ER_X | T_N8, 0x2D); }
+void vcvttss2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F3 | T_W1 | T_EVEX | T_EW1 | T_SAE_X | T_N8, 0x2C); }
+void vcvtsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_ER_X, 0x2D); }
+void vcvttsd2si(const Reg64& r, const Operand& op) { opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, T_0F | T_F2 | T_W1 | T_EVEX | T_EW1 | T_N4 | T_SAE_X, 0x2C); }
+void vmovq(const Xmm& x, const Reg64& r) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x6E); }
+void vmovq(const Reg64& r, const Xmm& x) { opAVX_X_X_XM(x, xm0, Xmm(r.getIdx()), T_66 | T_0F | T_W1 | T_EVEX | T_EW1, 0x7E); }
+#else
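+// 32-bit-only mnemonics: jcxz/jecxz, the BCD adjust instructions (aaa/aad/aam/aas/daa/das),
+// and pusha(d)/popa(d)/pushfd/popfd, none of which are valid in 64-bit mode.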
+void jcxz(std::string label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); }
+void jcxz(const Label& label) { db(0x67); opJmp(label, T_SHORT, 0xe3, 0, 0); }
+void jecxz(std::string label) { opJmp(label, T_SHORT, 0xe3, 0, 0); }
+void jecxz(const Label& label) { opJmp(label, T_SHORT, 0xe3, 0, 0); }
+void aaa() { db(0x37); }
+void aad() { db(0xD5); db(0x0A); }
+void aam() { db(0xD4); db(0x0A); }
+void aas() { db(0x3F); }
+void daa() { db(0x27); }
+void das() { db(0x2F); }
+void popad() { db(0x61); }
+void popfd() { db(0x9D); }
+void pusha() { db(0x60); }
+void pushad() { db(0x60); }
+void pushfd() { db(0x9C); }
+void popa() { db(0x61); }
+#endif
+#ifndef XBYAK_NO_OP_NAMES
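+// Aliases named after the bare operators; they forward to the underscore-suffixed forms
+// (and_/or_/xor_/not_) and are compiled out when XBYAK_NO_OP_NAMES is defined, e.g. when
+// the compiler treats and/or/xor/not as alternative operator tokens.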
+void and(const Operand& op1, const Operand& op2) { and_(op1, op2); }
+void and(const Operand& op, uint32 imm) { and_(op, imm); }
+void or(const Operand& op1, const Operand& op2) { or_(op1, op2); }
+void or(const Operand& op, uint32 imm) { or_(op, imm); }
+void xor(const Operand& op1, const Operand& op2) { xor_(op1, op2); }
+void xor(const Operand& op, uint32 imm) { xor_(op, imm); }
+void not(const Operand& op) { not_(op); }
+#endif
+#ifndef XBYAK_DISABLE_AVX512
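+// AVX-512 section; the opmask (k-register) mnemonics come first: kadd/kand/kandn/kor/
+// kxor/kxnor/knot, the kmov{b,w,d,q} moves, the kshift immediates, and ktest/kortest.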
+void kaddb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4A); }
+void kaddd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x4A); }
+void kaddq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4A); }
+void kaddw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4A); }
+void kandb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x41); }
+void kandd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x41); }
+void kandnb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x42); }
+void kandnd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x42); }
+void kandnq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x42); }
+void kandnw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x42); }
+void kandq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x41); }
+void kandw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x41); }
+void kmovb(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W0, 0x91); }
+void kmovb(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W0, 0x90); }
+void kmovb(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_66 | T_W0, 0x92); }
+void kmovb(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_66 | T_W0, 0x93); }
+void kmovd(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_66 | T_W1, 0x91); }
+void kmovd(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_66 | T_W1, 0x90); }
+void kmovd(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W0, 0x92); }
+void kmovd(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W0, 0x93); }
+void kmovq(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W1, 0x91); }
+void kmovq(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W1, 0x90); }
+void kmovw(const Address& addr, const Opmask& k) { opVex(k, 0, addr, T_L0 | T_0F | T_W0, 0x91); }
+void kmovw(const Opmask& k, const Operand& op) { opVex(k, 0, op, T_L0 | T_0F | T_W0, 0x90); }
+void kmovw(const Opmask& k, const Reg32& r) { opVex(k, 0, r, T_L0 | T_0F | T_W0, 0x92); }
+void kmovw(const Reg32& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_W0, 0x93); }
+void knotb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x44); }
+void knotd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x44); }
+void knotq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x44); }
+void knotw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x44); }
+void korb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x45); }
+void kord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x45); }
+void korq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x45); }
+void kortestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x98); }
+void kortestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x98); }
+void kortestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x98); }
+void kortestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x98); }
+void korw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x45); }
+void kshiftlb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x32, imm); }
+void kshiftld(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x33, imm); }
+void kshiftlq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x33, imm); }
+void kshiftlw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x32, imm); }
+void kshiftrb(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x30, imm); }
+void kshiftrd(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W0, 0x31, imm); }
+void kshiftrq(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x31, imm); }
+void kshiftrw(const Opmask& r1, const Opmask& r2, uint8 imm) { opVex(r1, 0, r2, T_66 | T_0F3A | T_W1, 0x30, imm); }
+void ktestb(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W0, 0x99); }
+void ktestd(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_66 | T_W1, 0x99); }
+void ktestq(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W1, 0x99); }
+void ktestw(const Opmask& r1, const Opmask& r2) { opVex(r1, 0, r2, T_0F | T_W0, 0x99); }
+void kunpckbw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x4B); }
+void kunpckdq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x4B); }
+void kunpckwd(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x4B); }
+void kxnorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x46); }
+void kxnord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x46); }
+void kxnorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x46); }
+void kxnorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x46); }
+void kxorb(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W0, 0x47); }
+void kxord(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_66 | T_W1, 0x47); }
+void kxorq(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W1, 0x47); }
+void kxorw(const Opmask& r1, const Opmask& r2, const Opmask& r3) { opVex(r1, &r2, r3, T_L1 | T_0F | T_W0, 0x47); }
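+// End of the opmask group; the remaining mnemonics are EVEX-encoded vector instructions
+// (T_MUST_EVEX): AVX512_4FMAPS helpers, align/blend-with-mask, sub-vector broadcasts,
+// compare-into-opmask, compress/expand, and the unsigned/64-bit integer conversions.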
+void v4fmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x9A); }
+void v4fmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0x9B); }
+void v4fnmaddps(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0xAA); }
+void v4fnmaddss(const Xmm& x1, const Xmm& x2, const Address& addr) { opAVX_X_X_XM(x1, x2, addr, T_0F38 | T_F2 | T_EW0 | T_MUST_EVEX | T_N16, 0xAB); }
+void valignd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x03, imm); }
+void valignq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x03, imm); }
+void vblendmpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x65); }
+void vblendmps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x65); }
+void vbroadcastf32x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x19); }
+void vbroadcastf32x4(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x1A); }
+void vbroadcastf32x8(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x1B); }
+void vbroadcastf64x2(const Ymm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x1A); }
+void vbroadcastf64x4(const Zmm& y, const Address& addr) { opAVX_X_XM_IMM(y, addr, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x1B); }
+void vbroadcasti32x2(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N8, 0x59); }
+void vbroadcasti32x4(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N16, 0x5A); }
+void vbroadcasti32x8(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0 | T_N32, 0x5B); }
+void vbroadcasti64x2(const Ymm& y, const Operand& op) { opAVX_X_XM_IMM(y, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N16, 0x5A); }
+void vbroadcasti64x4(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1 | T_N32, 0x5B); }
+void vcmppd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); }
+void vcmpps(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); }
+void vcmpsd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N8 | T_F2 | T_0F | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); }
+void vcmpss(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_N4 | T_F3 | T_0F | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0xC2, imm); }
+void vcompressb(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x63); }
+void vcompresspd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8A); }
+void vcompressps(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8A); }
+void vcompressw(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x63); }
+void vcvtpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7B); }
+void vcvtpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x79); }
+void vcvtpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x79); }
+void vcvtps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x7B); }
+void vcvtps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x79); }
+void vcvtps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_ER_Y, 0x79); }
+void vcvtqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0xE6); }
+void vcvtqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x5B); }
+void vcvtsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); }
+void vcvtss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_ER_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x79); }
+void vcvttpd2qq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x7A); }
+void vcvttpd2udq(const Xmm& x, const Operand& op) { opCvt2(x, op, T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_SAE_Z, 0x78); }
+void vcvttpd2uqq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x78); }
+void vcvttps2qq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x7A); }
+void vcvttps2udq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_0F | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x78); }
+void vcvttps2uqq(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL | T_SAE_Y, 0x78); }
+void vcvttsd2usi(const Reg32e& r, const Operand& op) { int type = (T_F2 | T_0F | T_MUST_EVEX | T_N8 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); }
+void vcvttss2usi(const Reg32e& r, const Operand& op) { int type = (T_F3 | T_0F | T_MUST_EVEX | T_N4 | T_SAE_X) | (r.isREG(64) ? T_EW1 : T_EW0); opAVX_X_X_XM(Xmm(r.getIdx()), xm0, op, type, 0x78); }
+void vcvtudq2pd(const Xmm& x, const Operand& op) { checkCvt1(x, op); opVex(x, 0, op, T_F3 | T_0F | T_YMM | T_MUST_EVEX | T_EW0 | T_B32 | T_N8 | T_N_VL, 0x7A); }
+void vcvtudq2ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x7A); }
+void vcvtuqq2pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x7A); }
+void vcvtuqq2ps(const Xmm& x, const Operand& op) { opCvt2(x, op, T_F2 | T_0F | T_YMM | T_MUST_EVEX | T_EW1 | T_B64 | T_ER_Z, 0x7A); }
+void vcvtusi2sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F2 | T_0F | T_MUST_EVEX, T_W1 | T_EW1 | T_ER_X | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); }
+void vcvtusi2ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opCvt3(x1, x2, op, T_F3 | T_0F | T_MUST_EVEX | T_ER_X, T_W1 | T_EW1 | T_N8, T_W0 | T_EW0 | T_N4, 0x7B); }
+void vdbpsadbw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x42, imm); }
+void vexp2pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xC8); }
+void vexp2ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xC8); }
+void vexpandpd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x88); }
+void vexpandps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x88); }
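+// The fixed-width extract/insert forms (32x4, 32x8, 64x2, 64x4) validate the operand kind
+// up front and throw ERR_BAD_COMBINATION when the source/destination widths do not match.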
+void vextractf32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x19, imm); }
+void vextractf32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1B, imm); }
+void vextractf64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x19, imm); }
+void vextractf64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1B, imm); }
+void vextracti32x4(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x39, imm); }
+void vextracti32x8(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3B, imm); }
+void vextracti64x2(const Operand& op, const Ymm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::XMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x39, imm); }
+void vextracti64x4(const Operand& op, const Zmm& r, uint8 imm) { if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r, 0, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3B, imm); }
+void vfixupimmpd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x54, imm); }
+void vfixupimmps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x54, imm); }
+void vfixupimmsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); }
+void vfixupimmss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_Z | T_MUST_EVEX, 0x55, imm); }
+void vfpclasspd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW1 | T_B64, 0x66, imm); }
+void vfpclassps(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isBit(128|256|512)) throw Error(ERR_BAD_MEM_SIZE); Reg x = k; x.setBit(op.getBit()); opVex(x, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_YMM | T_EW0 | T_B32, 0x66, imm); }
+void vfpclasssd(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW1 | T_N8, 0x67, imm); }
+void vfpclassss(const Opmask& k, const Operand& op, uint8 imm) { if (!op.isXMEM()) throw Error(ERR_BAD_MEM_SIZE); opVex(k, 0, op, T_66 | T_0F3A | T_MUST_EVEX | T_EW0 | T_N4, 0x67, imm); }
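+// VSIB-addressed gathers and the gatherpf0*/gatherpf1* prefetch hints; both are EVEX-only
+// and take an opmask, and the prefetch forms additionally carry T_M_K for the masked
+// memory operand.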
+void vgatherdpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 1); }
+void vgatherdps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x92, 0); }
+void vgatherpf0dpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); }
+void vgatherpf0dps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); }
+void vgatherpf0qpd(const Address& addr) { opGatherFetch(addr, zm1, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
+void vgatherpf0qps(const Address& addr) { opGatherFetch(addr, zm1, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
+void vgatherpf1dpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); }
+void vgatherpf1dps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); }
+void vgatherpf1qpd(const Address& addr) { opGatherFetch(addr, zm2, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
+void vgatherpf1qps(const Address& addr) { opGatherFetch(addr, zm2, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
+void vgatherqpd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 0); }
+void vgatherqps(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x93, 2); }
+void vgetexppd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x42); }
+void vgetexpps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x42); }
+void vgetexpsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x43); }
+void vgetexpss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x43); }
+void vgetmantpd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x26, imm); }
+void vgetmantps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x26, imm); }
+void vgetmantsd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x27, imm); }
+void vgetmantss(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x27, imm); }
+void vinsertf32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x18, imm); }
+void vinsertf32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x1A, imm); }
+void vinsertf64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x18, imm); }
+void vinsertf64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x1A, imm); }
+void vinserti32x4(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x38, imm); }
+void vinserti32x8(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3A, imm); }
+void vinserti64x2(const Ymm& r1, const Ymm& r2, const Operand& op, uint8 imm) {if (!(r1.getKind() == r2.getKind() && op.is(Operand::MEM | Operand::XMM))) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N16 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x38, imm); }
+void vinserti64x4(const Zmm& r1, const Zmm& r2, const Operand& op, uint8 imm) {if (!op.is(Operand::MEM | Operand::YMM)) throw Error(ERR_BAD_COMBINATION); opVex(r1, &r2, op, T_N32 | T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3A, imm); }
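+// Element-typed full-register moves (vmovdqa32/64, vmovdqu8/16/32/64); the store overloads
+// carry T_M_K so the memory destination can take a {k} mask.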
+void vmovdqa32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
+void vmovdqa32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
+void vmovdqa64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
+void vmovdqa64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
+void vmovdqu16(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
+void vmovdqu16(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
+void vmovdqu32(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
+void vmovdqu32(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
+void vmovdqu64(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
+void vmovdqu64(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F3 | T_0F | T_EW1 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
+void vmovdqu8(const Address& addr, const Xmm& x) { opAVX_X_XM_IMM(x, addr, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX | T_M_K, 0x7F); }
+void vmovdqu8(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_F2 | T_0F | T_EW0 | T_YMM | T_ER_X | T_ER_Y | T_ER_Z | T_MUST_EVEX, 0x6F); }
+void vp4dpwssd(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x52); }
+void vp4dpwssds(const Zmm& z1, const Zmm& z2, const Address& addr) { opAVX_X_X_XM(z1, z2, addr, T_0F38 | T_F2 | T_EW0 | T_YMM | T_MUST_EVEX | T_N16, 0x53); }
+void vpabsq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_MUST_EVEX | T_EW1 | T_B64 | T_YMM, 0x1F); }
+void vpandd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDB); }
+void vpandnd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xDF); }
+void vpandnq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDF); }
+void vpandq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xDB); }
+void vpblendmb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x66); }
+void vpblendmd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x64); }
+void vpblendmq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x64); }
+void vpblendmw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x66); }
+void vpbroadcastb(const Xmm& x, const Reg8& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7A); }
+void vpbroadcastd(const Xmm& x, const Reg32& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7C); }
+void vpbroadcastmb2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW1, 0x2A); }
+void vpbroadcastmw2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_YMM | T_MUST_EVEX | T_EW0, 0x3A); }
+void vpbroadcastw(const Xmm& x, const Reg16& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7B); }
+void vpcmpb(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3F, imm); }
+void vpcmpd(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1F, imm); }
+void vpcmpeqb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x74); }
+void vpcmpeqd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX | T_B32, 0x76); }
+void vpcmpeqq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x29); }
+void vpcmpeqw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x75); }
+void vpcmpgtb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x64); }
+void vpcmpgtd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x66); }
+void vpcmpgtq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x37); }
+void vpcmpgtw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F | T_YMM | T_MUST_EVEX, 0x65); }
+void vpcmpq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1F, imm); }
+void vpcmpub(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX, 0x3E, imm); }
+void vpcmpud(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x1E, imm); }
+void vpcmpuq(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x1E, imm); }
+void vpcmpuw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3E, imm); }
+void vpcmpw(const Opmask& k, const Xmm& x, const Operand& op, uint8 imm) { opAVX_K_X_XM(k, x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX, 0x3F, imm); }
+void vpcompressd(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8B); }
+void vpcompressq(const Operand& op, const Xmm& x) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8B); }
+void vpconflictd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xC4); }
+void vpconflictq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xC4); }
+void vpdpbusd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50); }
+void vpdpbusds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x51); }
+void vpdpwssd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x52); }
+void vpdpwssds(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x53); }
+void vpermb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8D); }
+void vpermi2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x75); }
+void vpermi2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x76); }
+void vpermi2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x77); }
+void vpermi2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x77); }
+void vpermi2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x76); }
+void vpermi2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x75); }
+void vpermt2b(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x7D); }
+void vpermt2d(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7E); }
+void vpermt2pd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7F); }
+void vpermt2ps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x7F); }
+void vpermt2q(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x7E); }
+void vpermt2w(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7D); }
+void vpermw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x8D); }
+void vpexpandb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N1 | T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); }
+void vpexpandd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x89); }
+void vpexpandq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x89); }
+void vpexpandw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_N2 | T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x62); }
+void vpgatherdd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 0); }
+void vpgatherdq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x90, 1); }
+void vpgatherqd(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 2); }
+void vpgatherqq(const Xmm& x, const Address& addr) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_VSIB, 0x91, 0); }
+void vplzcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x44); }
+void vplzcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x44); }
+void vpmadd52huq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB5); }
+void vpmadd52luq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xB4); }
+void vpmaxsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3D); }
+void vpmaxuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3F); }
+void vpminsq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x39); }
+void vpminuq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x3B); }
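+// vpmov* family: vector <-> opmask moves (vpmov{b,d,q,w}2m / vpmovm2{b,d,q,w}) and the
+// truncating (vpmov*), signed-saturating (vpmovs*), and unsigned-saturating (vpmovus*)
+// integer down-conversions routed through opVmov.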
+void vpmovb2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x29); }
+void vpmovd2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x39); }
+void vpmovdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x31, false); }
+void vpmovdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x33, true); }
+void vpmovm2b(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x28); }
+void vpmovm2d(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0, 0x38); }
+void vpmovm2q(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x38); }
+void vpmovm2w(const Xmm& x, const Opmask& k) { opVex(x, 0, k, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x28); }
+void vpmovq2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x39); }
+void vpmovqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x32, false); }
+void vpmovqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x35, true); }
+void vpmovqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x34, false); }
+void vpmovsdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x21, false); }
+void vpmovsdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x23, true); }
+void vpmovsqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x22, false); }
+void vpmovsqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x25, true); }
+void vpmovsqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x24, false); }
+void vpmovswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x20, true); }
+void vpmovusdb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x11, false); }
+void vpmovusdw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x13, true); }
+void vpmovusqb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N2 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x12, false); }
+void vpmovusqd(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x15, true); }
+void vpmovusqw(const Operand& op, const Xmm& x) { opVmov(op, x, T_N4 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x14, false); }
+void vpmovuswb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x10, true); }
+void vpmovw2m(const Opmask& k, const Xmm& x) { opVex(k, 0, x, T_F3 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1, 0x29); }
+void vpmovwb(const Operand& op, const Xmm& x) { opVmov(op, x, T_N8 | T_N_VL | T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x30, true); }
+void vpmullq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x40); }
+void vpmultishiftqb(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x83); }
+void vpopcntb(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); }
+void vpopcntd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x55); }
+void vpopcntq(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x55); }
+void vpopcntw(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x54); }
+void vpord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEB); }
+void vporq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEB); }
+void vprold(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); }
+void vprolq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 1), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); }
+void vprolvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x15); }
+void vprolvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x15); }
+void vprord(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x72, imm); }
+void vprorq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 0), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); }
+void vprorvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x14); }
+void vprorvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x14); }
+void vpscatterdd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 0); }
+void vpscatterdq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA0, 1); }
+void vpscatterqd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 2); }
+void vpscatterqq(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA1, 0); }
+void vpshldd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71, imm); }
+void vpshldq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71, imm); }
+void vpshldvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x71); }
+void vpshldvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x71); }
+void vpshldvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70); }
+void vpshldw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x70, imm); }
+void vpshrdd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73, imm); }
+void vpshrdq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73, imm); }
+void vpshrdvd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x73); }
+void vpshrdvq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x73); }
+void vpshrdvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72); }
+void vpshrdw(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX, 0x72, imm); }
+void vpshufbitqmb(const Opmask& k, const Xmm& x, const Operand& op) { opVex(k, &x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x8F); }
+void vpsllvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x12); }
+void vpsraq(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_X_XM(Xmm(x.getKind(), 4), x, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x72, imm); }
+void vpsraq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N16 | T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX, 0xE2); }
+void vpsravq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x46); }
+void vpsravw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x11); }
+void vpsrlvw(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x10); }
+void vpternlogd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x25, imm); }
+void vpternlogq(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x25, imm); }
+void vptestmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); }
+void vptestmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); }
+void vptestmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); }
+void vptestmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); }
+void vptestnmb(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x26); }
+void vptestnmd(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x27); }
+void vptestnmq(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x27); }
+void vptestnmw(const Opmask& k, const Xmm& x, const Operand& op) { opAVX_K_X_XM(k, x, op, T_F3 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x26); }
+void vpxord(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0xEF); }
+void vpxorq(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0xEF); }
+void vrangepd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x50, imm); }
+void vrangeps(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x50, imm); }
+void vrangesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x51, imm); }
+void vrangess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x51, imm); }
+void vrcp14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4C); }
+void vrcp14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4C); }
+void vrcp14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX, 0x4D); }
+void vrcp14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX, 0x4D); }
+void vrcp28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCA); }
+void vrcp28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCA); }
+void vrcp28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCB); }
+void vrcp28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCB); }
+void vreducepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B64, 0x56, imm); }
+void vreduceps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_SAE_Z | T_MUST_EVEX | T_B32, 0x56, imm); }
+void vreducesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_SAE_X | T_MUST_EVEX, 0x57, imm); }
+void vreducess(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_SAE_X | T_MUST_EVEX, 0x57, imm); }
+void vrndscalepd(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x09, imm); }
+void vrndscaleps(const Xmm& x, const Operand& op, uint8 imm) { opAVX_X_XM_IMM(x, op, T_66 | T_0F3A | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x08, imm); }
+void vrndscalesd(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F3A | T_EW1 | T_MUST_EVEX, 0x0B, imm); }
+void vrndscaless(const Xmm& x1, const Xmm& x2, const Operand& op, uint8 imm) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F3A | T_EW0 | T_MUST_EVEX, 0x0A, imm); }
+void vrsqrt14pd(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_B64, 0x4E); }
+void vrsqrt14ps(const Xmm& x, const Operand& op) { opAVX_X_XM_IMM(x, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_B32, 0x4E); }
+void vrsqrt14sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x4F); }
+void vrsqrt14ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX, 0x4F); }
+void vrsqrt28pd(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW1 | T_B64 | T_SAE_Z, 0xCC); }
+void vrsqrt28ps(const Zmm& z, const Operand& op) { opAVX_X_XM_IMM(z, op, T_66 | T_0F38 | T_MUST_EVEX | T_YMM | T_EW0 | T_B32 | T_SAE_Z, 0xCC); }
+void vrsqrt28sd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_SAE_X | T_MUST_EVEX, 0xCD); }
+void vrsqrt28ss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_SAE_X | T_MUST_EVEX, 0xCD); }
+void vscalefpd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW1 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B64, 0x2C); }
+void vscalefps(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_66 | T_0F38 | T_EW0 | T_YMM | T_ER_Z | T_MUST_EVEX | T_B32, 0x2C); }
+void vscalefsd(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N8 | T_66 | T_0F38 | T_EW1 | T_ER_X | T_MUST_EVEX, 0x2D); }
+void vscalefss(const Xmm& x1, const Xmm& x2, const Operand& op) { opAVX_X_X_XM(x1, x2, op, T_N4 | T_66 | T_0F38 | T_EW0 | T_ER_X | T_MUST_EVEX, 0x2D); }
+void vscatterdpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 1); }
+void vscatterdps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA2, 0); }
+void vscatterpf0dpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); }
+void vscatterpf0dps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); }
+void vscatterpf0qpd(const Address& addr) { opGatherFetch(addr, zm5, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
+void vscatterpf0qps(const Address& addr) { opGatherFetch(addr, zm5, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
+void vscatterpf1dpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::YMM); }
+void vscatterpf1dps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC6, Operand::ZMM); }
+void vscatterpf1qpd(const Address& addr) { opGatherFetch(addr, zm6, T_N8 | T_66 | T_0F38 | T_EW1 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
+void vscatterpf1qps(const Address& addr) { opGatherFetch(addr, zm6, T_N4 | T_66 | T_0F38 | T_EW0 | T_MUST_EVEX | T_M_K | T_VSIB, 0xC7, Operand::ZMM); }
+void vscatterqpd(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N8 | T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 0); }
+void vscatterqps(const Address& addr, const Xmm& x) { opGather2(x, addr, T_N4 | T_66 | T_0F38 | T_EW0 | T_YMM | T_MUST_EVEX | T_M_K | T_VSIB, 0xA3, 2); }
+void vshuff32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x23, imm); }
+void vshuff64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x23, imm); }
+void vshufi32x4(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW0 | T_B32, 0x43, imm); }
+void vshufi64x2(const Ymm& y1, const Ymm& y2, const Operand& op, uint8 imm) { opAVX_X_X_XM(y1, y2, op, T_66 | T_0F3A | T_YMM | T_MUST_EVEX | T_EW1 | T_B64, 0x43, imm); }
+#ifdef XBYAK64
+void kmovq(const Opmask& k, const Reg64& r) { opVex(k, 0, r, T_L0 | T_0F | T_F2 | T_W1, 0x92); }
+void kmovq(const Reg64& r, const Opmask& k) { opVex(r, 0, k, T_L0 | T_0F | T_F2 | T_W1, 0x93); }
+void vpbroadcastq(const Xmm& x, const Reg64& r) { opVex(x, 0, r, T_66 | T_0F38 | T_EW1 | T_YMM | T_MUST_EVEX, 0x7C); }
+#endif
+#endif
diff --git a/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h
new file mode 100644
index 0000000000..8ef076e680
--- /dev/null
+++ b/thirdparty/oidn/mkl-dnn/src/cpu/xbyak/xbyak_util.h
@@ -0,0 +1,772 @@
+/*******************************************************************************
+* Copyright 2016-2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*******************************************************************************
+* Copyright (c) 2007 MITSUNARI Shigeo
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following conditions are met:
+*
+* Redistributions of source code must retain the above copyright notice, this
+* list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice,
+* this list of conditions and the following disclaimer in the documentation
+* and/or other materials provided with the distribution.
+* Neither the name of the copyright owner nor the names of its contributors may
+* be used to endorse or promote products derived from this software without
+* specific prior written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
+* THE POSSIBILITY OF SUCH DAMAGE.
+*******************************************************************************/
+
+#ifndef XBYAK_XBYAK_UTIL_H_
+#define XBYAK_XBYAK_UTIL_H_
+
+/**
+ utility class and functions for Xbyak
+ Xbyak::util::Clock ; rdtsc timer
+ Xbyak::util::Cpu ; detect CPU
+ @note this header is UNDER CONSTRUCTION!
+*/
+#include "xbyak.h"
+
+#if defined(__i386__) || defined(__x86_64__) || defined(_M_IX86) || defined(_M_X64)
+ #define XBYAK_INTEL_CPU_SPECIFIC
+#endif
+
+#ifdef XBYAK_INTEL_CPU_SPECIFIC
+#ifdef _MSC_VER
+ #if (_MSC_VER < 1400) && defined(XBYAK32)
+ static inline __declspec(naked) void __cpuid(int[4], int)
+ {
+ __asm {
+ push ebx
+ push esi
+ mov eax, dword ptr [esp + 4 * 2 + 8] // eaxIn
+ cpuid
+ mov esi, dword ptr [esp + 4 * 2 + 4] // data
+ mov dword ptr [esi], eax
+ mov dword ptr [esi + 4], ebx
+ mov dword ptr [esi + 8], ecx
+ mov dword ptr [esi + 12], edx
+ pop esi
+ pop ebx
+ ret
+ }
+ }
+ #else
+ #include <intrin.h> // for __cpuid
+ #endif
+#else
+ #ifndef __GNUC_PREREQ
+ #define __GNUC_PREREQ(major, minor) ((((__GNUC__) << 16) + (__GNUC_MINOR__)) >= (((major) << 16) + (minor)))
+ #endif
+ #if __GNUC_PREREQ(4, 3) && !defined(__APPLE__)
+ #include <cpuid.h>
+ #else
+ #if defined(__APPLE__) && defined(XBYAK32) // avoid err : can't find a register in class `BREG' while reloading `asm'
+		#define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebx, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn))
+		#define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("pushl %%ebx\ncpuid\nmovl %%ebx, %%esi\npopl %%ebx" : "=a"(a), "=S"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn))
+ #else
+ #define __cpuid(eaxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn))
+ #define __cpuid_count(eaxIn, ecxIn, a, b, c, d) __asm__ __volatile__("cpuid\n" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "0"(eaxIn), "2"(ecxIn))
+ #endif
+ #endif
+#endif
+#endif
+
+namespace Xbyak { namespace util {
+
+typedef enum {
+ SmtLevel = 1,
+ CoreLevel = 2
+} IntelCpuTopologyLevel;
+
+/**
+ CPU detection class
+*/
+class Cpu {
+ uint64 type_;
+ //system topology
+ bool x2APIC_supported_;
+ static const size_t maxTopologyLevels = 2;
+ unsigned int numCores_[maxTopologyLevels];
+
+ static const unsigned int maxNumberCacheLevels = 10;
+ unsigned int dataCacheSize_[maxNumberCacheLevels];
+	unsigned int coresSharingDataCache_[maxNumberCacheLevels];
+ unsigned int dataCacheLevels_;
+
+ unsigned int get32bitAsBE(const char *x) const
+ {
+ return x[0] | (x[1] << 8) | (x[2] << 16) | (x[3] << 24);
+ }
+ unsigned int mask(int n) const
+ {
+ return (1U << n) - 1;
+ }
+ void setFamily()
+ {
+ unsigned int data[4] = {};
+ getCpuid(1, data);
+ stepping = data[0] & mask(4);
+ model = (data[0] >> 4) & mask(4);
+ family = (data[0] >> 8) & mask(4);
+ // type = (data[0] >> 12) & mask(2);
+ extModel = (data[0] >> 16) & mask(4);
+ extFamily = (data[0] >> 20) & mask(8);
+ if (family == 0x0f) {
+ displayFamily = family + extFamily;
+ } else {
+ displayFamily = family;
+ }
+ if (family == 6 || family == 0x0f) {
+ displayModel = (extModel << 4) + model;
+ } else {
+ displayModel = model;
+ }
+ }
+ unsigned int extractBit(unsigned int val, unsigned int base, unsigned int end)
+ {
+ return (val >> base) & ((1u << (end - base)) - 1);
+ }
+ void setNumCores()
+ {
+ if ((type_ & tINTEL) == 0) return;
+
+ unsigned int data[4] = {};
+
+		/* CAUTION: These numbers are the configuration as shipped by Intel. */
+ getCpuidEx(0x0, 0, data);
+ if (data[0] >= 0xB) {
+ /*
+			If leaf 11 exists (x2APIC is supported),
+			we use it to get the number of SMT threads and the number of cores per socket.
+
+			Note that leaf 0xB can be zeroed out by a hypervisor.
+ */
+ x2APIC_supported_ = true;
+ for (unsigned int i = 0; i < maxTopologyLevels; i++) {
+ getCpuidEx(0xB, i, data);
+ IntelCpuTopologyLevel level = (IntelCpuTopologyLevel)extractBit(data[2], 8, 15);
+ if (level == SmtLevel || level == CoreLevel) {
+ numCores_[level - 1] = extractBit(data[1], 0, 15);
+ }
+ }
+ } else {
+ /*
+			Failed to determine the number of cores without x2APIC support.
+			TODO: use the initial APIC ID to determine the number of cores.
+ */
+ numCores_[SmtLevel - 1] = 0;
+ numCores_[CoreLevel - 1] = 0;
+ }
+
+ }
+ void setCacheHierarchy()
+ {
+ if ((type_ & tINTEL) == 0) return;
+ const unsigned int NO_CACHE = 0;
+ const unsigned int DATA_CACHE = 1;
+// const unsigned int INSTRUCTION_CACHE = 2;
+ const unsigned int UNIFIED_CACHE = 3;
+ unsigned int smt_width = 0;
+ unsigned int logical_cores = 0;
+ unsigned int data[4] = {};
+
+ if (x2APIC_supported_) {
+ smt_width = numCores_[0];
+ logical_cores = numCores_[1];
+ }
+
+ /*
+ Assumptions:
+			the first level of data cache is not shared (which is the
+			case for every existing architecture), so we use it to
+			determine the SMT width for architectures that do not support leaf 11.
+			When leaf 4 reports fewer cores than the per-socket numCores_
+			reported by leaf 11, that smaller value is the correct number
+			of cores, not an upper bound.
+ */
+ for (int i = 0; dataCacheLevels_ < maxNumberCacheLevels; i++) {
+ getCpuidEx(0x4, i, data);
+ unsigned int cacheType = extractBit(data[0], 0, 4);
+ if (cacheType == NO_CACHE) break;
+ if (cacheType == DATA_CACHE || cacheType == UNIFIED_CACHE) {
+ unsigned int actual_logical_cores = extractBit(data[0], 14, 25) + 1;
+ if (logical_cores != 0) { // true only if leaf 0xB is supported and valid
+ actual_logical_cores = (std::min)(actual_logical_cores, logical_cores);
+ }
+ assert(actual_logical_cores != 0);
+ dataCacheSize_[dataCacheLevels_] =
+ (extractBit(data[1], 22, 31) + 1)
+ * (extractBit(data[1], 12, 21) + 1)
+ * (extractBit(data[1], 0, 11) + 1)
+ * (data[2] + 1);
+ if (cacheType == DATA_CACHE && smt_width == 0) smt_width = actual_logical_cores;
+ assert(smt_width != 0);
+ // FIXME: check and fix number of cores sharing L3 cache for different configurations
+ // (HT-, 2 sockets), (HT-, 1 socket), (HT+, 2 sockets), (HT+, 1 socket)
+				coresSharingDataCache_[dataCacheLevels_] = (std::max)(actual_logical_cores / smt_width, 1u);
+ dataCacheLevels_++;
+ }
+ }
+ }
+
+public:
+ int model;
+ int family;
+ int stepping;
+ int extModel;
+ int extFamily;
+ int displayFamily; // family + extFamily
+ int displayModel; // model + extModel
+
+ unsigned int getNumCores(IntelCpuTopologyLevel level) {
+ if (level != SmtLevel && level != CoreLevel) throw Error(ERR_BAD_PARAMETER);
+ if (!x2APIC_supported_) throw Error(ERR_X2APIC_IS_NOT_SUPPORTED);
+ return (level == CoreLevel)
+ ? numCores_[level - 1] / numCores_[SmtLevel - 1]
+ : numCores_[level - 1];
+ }
+
+ unsigned int getDataCacheLevels() const { return dataCacheLevels_; }
+ unsigned int getCoresSharingDataCache(unsigned int i) const
+ {
+ if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER);
+		return coresSharingDataCache_[i];
+ }
+ unsigned int getDataCacheSize(unsigned int i) const
+ {
+ if (i >= dataCacheLevels_) throw Error(ERR_BAD_PARAMETER);
+ return dataCacheSize_[i];
+ }
+
+ /*
+ data[] = { eax, ebx, ecx, edx }
+ */
+ static inline void getCpuid(unsigned int eaxIn, unsigned int data[4])
+ {
+#ifdef XBYAK_INTEL_CPU_SPECIFIC
+ #ifdef _MSC_VER
+ __cpuid(reinterpret_cast<int*>(data), eaxIn);
+ #else
+ __cpuid(eaxIn, data[0], data[1], data[2], data[3]);
+ #endif
+#else
+ (void)eaxIn;
+ (void)data;
+#endif
+ }
+ static inline void getCpuidEx(unsigned int eaxIn, unsigned int ecxIn, unsigned int data[4])
+ {
+#ifdef XBYAK_INTEL_CPU_SPECIFIC
+ #ifdef _MSC_VER
+ __cpuidex(reinterpret_cast<int*>(data), eaxIn, ecxIn);
+ #else
+ __cpuid_count(eaxIn, ecxIn, data[0], data[1], data[2], data[3]);
+ #endif
+#else
+ (void)eaxIn;
+ (void)ecxIn;
+ (void)data;
+#endif
+ }
+ static inline uint64 getXfeature()
+ {
+#ifdef XBYAK_INTEL_CPU_SPECIFIC
+ #ifdef _MSC_VER
+ return _xgetbv(0);
+ #else
+ unsigned int eax, edx;
+		// xgetbv is not supported on gcc 4.2
+// __asm__ volatile("xgetbv" : "=a"(eax), "=d"(edx) : "c"(0));
+ __asm__ volatile(".byte 0x0f, 0x01, 0xd0" : "=a"(eax), "=d"(edx) : "c"(0));
+ return ((uint64)edx << 32) | eax;
+ #endif
+#else
+ return 0;
+#endif
+ }
+ typedef uint64 Type;
+
+ static const Type NONE = 0;
+ static const Type tMMX = 1 << 0;
+ static const Type tMMX2 = 1 << 1;
+ static const Type tCMOV = 1 << 2;
+ static const Type tSSE = 1 << 3;
+ static const Type tSSE2 = 1 << 4;
+ static const Type tSSE3 = 1 << 5;
+ static const Type tSSSE3 = 1 << 6;
+ static const Type tSSE41 = 1 << 7;
+ static const Type tSSE42 = 1 << 8;
+ static const Type tPOPCNT = 1 << 9;
+ static const Type tAESNI = 1 << 10;
+ static const Type tSSE5 = 1 << 11;
+ static const Type tOSXSAVE = 1 << 12;
+ static const Type tPCLMULQDQ = 1 << 13;
+ static const Type tAVX = 1 << 14;
+ static const Type tFMA = 1 << 15;
+
+ static const Type t3DN = 1 << 16;
+ static const Type tE3DN = 1 << 17;
+ static const Type tSSE4a = 1 << 18;
+ static const Type tRDTSCP = 1 << 19;
+ static const Type tAVX2 = 1 << 20;
+ static const Type tBMI1 = 1 << 21; // andn, bextr, blsi, blsmsk, blsr, tzcnt
+ static const Type tBMI2 = 1 << 22; // bzhi, mulx, pdep, pext, rorx, sarx, shlx, shrx
+ static const Type tLZCNT = 1 << 23;
+
+ static const Type tINTEL = 1 << 24;
+ static const Type tAMD = 1 << 25;
+
+ static const Type tENHANCED_REP = 1 << 26; // enhanced rep movsb/stosb
+ static const Type tRDRAND = 1 << 27;
+ static const Type tADX = 1 << 28; // adcx, adox
+ static const Type tRDSEED = 1 << 29; // rdseed
+ static const Type tSMAP = 1 << 30; // stac
+ static const Type tHLE = uint64(1) << 31; // xacquire, xrelease, xtest
+ static const Type tRTM = uint64(1) << 32; // xbegin, xend, xabort
+ static const Type tF16C = uint64(1) << 33; // vcvtph2ps, vcvtps2ph
+	static const Type tMOVBE = uint64(1) << 34; // movbe
+ static const Type tAVX512F = uint64(1) << 35;
+ static const Type tAVX512DQ = uint64(1) << 36;
+ static const Type tAVX512_IFMA = uint64(1) << 37;
+ static const Type tAVX512IFMA = tAVX512_IFMA;
+ static const Type tAVX512PF = uint64(1) << 38;
+ static const Type tAVX512ER = uint64(1) << 39;
+ static const Type tAVX512CD = uint64(1) << 40;
+ static const Type tAVX512BW = uint64(1) << 41;
+ static const Type tAVX512VL = uint64(1) << 42;
+ static const Type tAVX512_VBMI = uint64(1) << 43;
+	static const Type tAVX512VBMI = tAVX512_VBMI; // name changed to match Intel's manual
+ static const Type tAVX512_4VNNIW = uint64(1) << 44;
+ static const Type tAVX512_4FMAPS = uint64(1) << 45;
+ static const Type tPREFETCHWT1 = uint64(1) << 46;
+ static const Type tPREFETCHW = uint64(1) << 47;
+ static const Type tSHA = uint64(1) << 48;
+ static const Type tMPX = uint64(1) << 49;
+ static const Type tAVX512_VBMI2 = uint64(1) << 50;
+ static const Type tGFNI = uint64(1) << 51;
+ static const Type tVAES = uint64(1) << 52;
+ static const Type tVPCLMULQDQ = uint64(1) << 53;
+ static const Type tAVX512_VNNI = uint64(1) << 54;
+ static const Type tAVX512_BITALG = uint64(1) << 55;
+ static const Type tAVX512_VPOPCNTDQ = uint64(1) << 56;
+
+ Cpu()
+ : type_(NONE)
+ , x2APIC_supported_(false)
+ , numCores_()
+ , dataCacheSize_()
+		, coresSharingDataCache_()
+ , dataCacheLevels_(0)
+ {
+ unsigned int data[4] = {};
+ const unsigned int& EAX = data[0];
+ const unsigned int& EBX = data[1];
+ const unsigned int& ECX = data[2];
+ const unsigned int& EDX = data[3];
+ getCpuid(0, data);
+ const unsigned int maxNum = EAX;
+ static const char intel[] = "ntel";
+ static const char amd[] = "cAMD";
+ if (ECX == get32bitAsBE(amd)) {
+ type_ |= tAMD;
+ getCpuid(0x80000001, data);
+ if (EDX & (1U << 31)) type_ |= t3DN;
+ if (EDX & (1U << 15)) type_ |= tCMOV;
+ if (EDX & (1U << 30)) type_ |= tE3DN;
+ if (EDX & (1U << 22)) type_ |= tMMX2;
+ if (EDX & (1U << 27)) type_ |= tRDTSCP;
+ }
+ if (ECX == get32bitAsBE(intel)) {
+ type_ |= tINTEL;
+ getCpuid(0x80000001, data);
+ if (EDX & (1U << 27)) type_ |= tRDTSCP;
+ if (ECX & (1U << 5)) type_ |= tLZCNT;
+ if (ECX & (1U << 8)) type_ |= tPREFETCHW;
+ }
+ getCpuid(1, data);
+ if (ECX & (1U << 0)) type_ |= tSSE3;
+ if (ECX & (1U << 9)) type_ |= tSSSE3;
+ if (ECX & (1U << 19)) type_ |= tSSE41;
+ if (ECX & (1U << 20)) type_ |= tSSE42;
+ if (ECX & (1U << 22)) type_ |= tMOVBE;
+ if (ECX & (1U << 23)) type_ |= tPOPCNT;
+ if (ECX & (1U << 25)) type_ |= tAESNI;
+ if (ECX & (1U << 1)) type_ |= tPCLMULQDQ;
+ if (ECX & (1U << 27)) type_ |= tOSXSAVE;
+ if (ECX & (1U << 30)) type_ |= tRDRAND;
+ if (ECX & (1U << 29)) type_ |= tF16C;
+
+ if (EDX & (1U << 15)) type_ |= tCMOV;
+ if (EDX & (1U << 23)) type_ |= tMMX;
+ if (EDX & (1U << 25)) type_ |= tMMX2 | tSSE;
+ if (EDX & (1U << 26)) type_ |= tSSE2;
+
+ if (type_ & tOSXSAVE) {
+ // check XFEATURE_ENABLED_MASK[2:1] = '11b'
+ uint64 bv = getXfeature();
+ if ((bv & 6) == 6) {
+ if (ECX & (1U << 28)) type_ |= tAVX;
+ if (ECX & (1U << 12)) type_ |= tFMA;
+ if (((bv >> 5) & 7) == 7) {
+ getCpuidEx(7, 0, data);
+ if (EBX & (1U << 16)) type_ |= tAVX512F;
+ if (type_ & tAVX512F) {
+ if (EBX & (1U << 17)) type_ |= tAVX512DQ;
+ if (EBX & (1U << 21)) type_ |= tAVX512_IFMA;
+ if (EBX & (1U << 26)) type_ |= tAVX512PF;
+ if (EBX & (1U << 27)) type_ |= tAVX512ER;
+ if (EBX & (1U << 28)) type_ |= tAVX512CD;
+ if (EBX & (1U << 30)) type_ |= tAVX512BW;
+ if (EBX & (1U << 31)) type_ |= tAVX512VL;
+ if (ECX & (1U << 1)) type_ |= tAVX512_VBMI;
+ if (ECX & (1U << 6)) type_ |= tAVX512_VBMI2;
+ if (ECX & (1U << 8)) type_ |= tGFNI;
+ if (ECX & (1U << 9)) type_ |= tVAES;
+ if (ECX & (1U << 10)) type_ |= tVPCLMULQDQ;
+ if (ECX & (1U << 11)) type_ |= tAVX512_VNNI;
+ if (ECX & (1U << 12)) type_ |= tAVX512_BITALG;
+ if (ECX & (1U << 14)) type_ |= tAVX512_VPOPCNTDQ;
+ if (EDX & (1U << 2)) type_ |= tAVX512_4VNNIW;
+ if (EDX & (1U << 3)) type_ |= tAVX512_4FMAPS;
+ }
+ }
+ }
+ }
+ if (maxNum >= 7) {
+ getCpuidEx(7, 0, data);
+ if (type_ & tAVX && (EBX & (1U << 5))) type_ |= tAVX2;
+ if (EBX & (1U << 3)) type_ |= tBMI1;
+ if (EBX & (1U << 8)) type_ |= tBMI2;
+ if (EBX & (1U << 9)) type_ |= tENHANCED_REP;
+ if (EBX & (1U << 18)) type_ |= tRDSEED;
+ if (EBX & (1U << 19)) type_ |= tADX;
+ if (EBX & (1U << 20)) type_ |= tSMAP;
+ if (EBX & (1U << 4)) type_ |= tHLE;
+ if (EBX & (1U << 11)) type_ |= tRTM;
+ if (EBX & (1U << 14)) type_ |= tMPX;
+ if (EBX & (1U << 29)) type_ |= tSHA;
+ if (ECX & (1U << 0)) type_ |= tPREFETCHWT1;
+ }
+ setFamily();
+ setNumCores();
+ setCacheHierarchy();
+ }
+ void putFamily() const
+ {
+ printf("family=%d, model=%X, stepping=%d, extFamily=%d, extModel=%X\n",
+ family, model, stepping, extFamily, extModel);
+ printf("display:family=%X, model=%X\n", displayFamily, displayModel);
+ }
+ bool has(Type type) const
+ {
+ return (type & type_) != 0;
+ }
+};
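+
+/*
+	usage sketch: query the detected features and the cache hierarchy.
+	Cpu::has(), the feature constants and the cache accessors are defined
+	above; the dispatch branches below are only an example.
+
+	Xbyak::util::Cpu cpu;
+	if (cpu.has(Xbyak::util::Cpu::tAVX512F)) {
+		; // e.g. pick a 512-bit wide kernel
+	} else if (cpu.has(Xbyak::util::Cpu::tAVX2)) {
+		; // e.g. pick a 256-bit wide kernel
+	}
+	for (unsigned int i = 0; i < cpu.getDataCacheLevels(); i++) {
+		printf("L%u data cache: %u bytes, shared by %u core(s)\n",
+			i + 1, cpu.getDataCacheSize(i), cpu.getCoresSharingDataCache(i));
+	}
+*/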
+
+class Clock {
+public:
+ static inline uint64 getRdtsc()
+ {
+#ifdef XBYAK_INTEL_CPU_SPECIFIC
+ #ifdef _MSC_VER
+ return __rdtsc();
+ #else
+ unsigned int eax, edx;
+ __asm__ volatile("rdtsc" : "=a"(eax), "=d"(edx));
+ return ((uint64)edx << 32) | eax;
+ #endif
+#else
+ // TODO: Need another impl of Clock or rdtsc-equivalent for non-x86 cpu
+ return 0;
+#endif
+ }
+ Clock()
+ : clock_(0)
+ , count_(0)
+ {
+ }
+ void begin()
+ {
+ clock_ -= getRdtsc();
+ }
+ void end()
+ {
+ clock_ += getRdtsc();
+ count_++;
+ }
+ int getCount() const { return count_; }
+ uint64 getClock() const { return clock_; }
+ void clear() { count_ = 0; clock_ = 0; }
+private:
+ uint64 clock_;
+ int count_;
+};
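+
+/*
+	usage sketch: accumulate rdtsc cycles over repeated measurements and
+	report the average; the measured region is left as a placeholder.
+
+	Xbyak::util::Clock clk;
+	for (int i = 0; i < 100; i++) {
+		clk.begin();
+		// ... code under measurement ...
+		clk.end();
+	}
+	if (clk.getCount() > 0) {
+		printf("avg cycles = %.1f\n", double(clk.getClock()) / clk.getCount());
+	}
+*/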
+
+#ifdef XBYAK64
+const int UseRCX = 1 << 6;
+const int UseRDX = 1 << 7;
+
+class Pack {
+ static const size_t maxTblNum = 15;
+ const Xbyak::Reg64 *tbl_[maxTblNum];
+ size_t n_;
+public:
+ Pack() : tbl_(), n_(0) {}
+ Pack(const Xbyak::Reg64 *tbl, size_t n) { init(tbl, n); }
+ Pack(const Pack& rhs)
+ : n_(rhs.n_)
+ {
+ for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i];
+ }
+ Pack& operator=(const Pack& rhs)
+ {
+ n_ = rhs.n_;
+ for (size_t i = 0; i < n_; i++) tbl_[i] = rhs.tbl_[i];
+ return *this;
+ }
+ Pack(const Xbyak::Reg64& t0)
+ { n_ = 1; tbl_[0] = &t0; }
+ Pack(const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
+ { n_ = 2; tbl_[0] = &t0; tbl_[1] = &t1; }
+ Pack(const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
+ { n_ = 3; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; }
+ Pack(const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
+ { n_ = 4; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; }
+ Pack(const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
+ { n_ = 5; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; }
+ Pack(const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
+ { n_ = 6; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; }
+ Pack(const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
+ { n_ = 7; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; }
+ Pack(const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
+ { n_ = 8; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; }
+ Pack(const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
+ { n_ = 9; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; }
+ Pack(const Xbyak::Reg64& t9, const Xbyak::Reg64& t8, const Xbyak::Reg64& t7, const Xbyak::Reg64& t6, const Xbyak::Reg64& t5, const Xbyak::Reg64& t4, const Xbyak::Reg64& t3, const Xbyak::Reg64& t2, const Xbyak::Reg64& t1, const Xbyak::Reg64& t0)
+ { n_ = 10; tbl_[0] = &t0; tbl_[1] = &t1; tbl_[2] = &t2; tbl_[3] = &t3; tbl_[4] = &t4; tbl_[5] = &t5; tbl_[6] = &t6; tbl_[7] = &t7; tbl_[8] = &t8; tbl_[9] = &t9; }
+ Pack& append(const Xbyak::Reg64& t)
+ {
+ if (n_ == maxTblNum) {
+ fprintf(stderr, "ERR Pack::can't append\n");
+ throw Error(ERR_BAD_PARAMETER);
+ }
+ tbl_[n_++] = &t;
+ return *this;
+ }
+ void init(const Xbyak::Reg64 *tbl, size_t n)
+ {
+ if (n > maxTblNum) {
+ fprintf(stderr, "ERR Pack::init bad n=%d\n", (int)n);
+ throw Error(ERR_BAD_PARAMETER);
+ }
+ n_ = n;
+ for (size_t i = 0; i < n; i++) {
+ tbl_[i] = &tbl[i];
+ }
+ }
+ const Xbyak::Reg64& operator[](size_t n) const
+ {
+ if (n >= n_) {
+ fprintf(stderr, "ERR Pack bad n=%d(%d)\n", (int)n, (int)n_);
+ throw Error(ERR_BAD_PARAMETER);
+ }
+ return *tbl_[n];
+ }
+ size_t size() const { return n_; }
+ /*
+ get tbl[pos, pos + num)
+ */
+ Pack sub(size_t pos, size_t num = size_t(-1)) const
+ {
+ if (num == size_t(-1)) num = n_ - pos;
+ if (pos + num > n_) {
+ fprintf(stderr, "ERR Pack::sub bad pos=%d, num=%d\n", (int)pos, (int)num);
+ throw Error(ERR_BAD_PARAMETER);
+ }
+ Pack pack;
+ pack.n_ = num;
+ for (size_t i = 0; i < num; i++) {
+ pack.tbl_[i] = tbl_[pos + i];
+ }
+ return pack;
+ }
+ void put() const
+ {
+ for (size_t i = 0; i < n_; i++) {
+ printf("%s ", tbl_[i]->toString());
+ }
+ printf("\n");
+ }
+};
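+
+/*
+	usage sketch: Pack bundles Reg64 references; note that the multi-register
+	constructors take their arguments in reverse order, so t0 ends up at
+	index 0. sub(pos, num) returns the half-open range [pos, pos + num).
+
+	Xbyak::Reg64 r0(Xbyak::Operand::RDI), r1(Xbyak::Operand::RSI), r2(Xbyak::Operand::RDX);
+	Xbyak::util::Pack regs(r2, r1, r0);   // regs[0] == rdi, regs[1] == rsi, regs[2] == rdx
+	Xbyak::util::Pack tail = regs.sub(1); // { rsi, rdx }
+	tail.put();                           // prints the register names
+*/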
+
+class StackFrame {
+#ifdef XBYAK64_WIN
+ static const int noSaveNum = 6;
+ static const int rcxPos = 0;
+ static const int rdxPos = 1;
+#else
+ static const int noSaveNum = 8;
+ static const int rcxPos = 3;
+ static const int rdxPos = 2;
+#endif
+ static const int maxRegNum = 14; // maxRegNum = 16 - rsp - rax
+ Xbyak::CodeGenerator *code_;
+ int pNum_;
+ int tNum_;
+ bool useRcx_;
+ bool useRdx_;
+ int saveNum_;
+ int P_;
+ bool makeEpilog_;
+ Xbyak::Reg64 pTbl_[4];
+ Xbyak::Reg64 tTbl_[maxRegNum];
+ Pack p_;
+ Pack t_;
+ StackFrame(const StackFrame&);
+ void operator=(const StackFrame&);
+public:
+ const Pack& p;
+ const Pack& t;
+ /*
+ make stack frame
+ @param sf [in] this
+	@param pNum [in] number of function parameters (0 <= pNum <= 4)
+	@param tNum [in] number of temporary registers (0 <= tNum, may be or'ed with UseRCX, UseRDX); #{pNum + tNum [+rcx] + [rdx]} <= 14
+ @param stackSizeByte [in] local stack size
+ @param makeEpilog [in] automatically call close() if true
+
+ you can use
+ rax
+ gp0, ..., gp(pNum - 1)
+ gt0, ..., gt(tNum-1)
+ rcx if tNum & UseRCX
+ rdx if tNum & UseRDX
+ rsp[0..stackSizeByte - 1]
+ */
+ StackFrame(Xbyak::CodeGenerator *code, int pNum, int tNum = 0, int stackSizeByte = 0, bool makeEpilog = true)
+ : code_(code)
+ , pNum_(pNum)
+ , tNum_(tNum & ~(UseRCX | UseRDX))
+ , useRcx_((tNum & UseRCX) != 0)
+ , useRdx_((tNum & UseRDX) != 0)
+ , saveNum_(0)
+ , P_(0)
+ , makeEpilog_(makeEpilog)
+ , p(p_)
+ , t(t_)
+ {
+ using namespace Xbyak;
+ if (pNum < 0 || pNum > 4) throw Error(ERR_BAD_PNUM);
+ const int allRegNum = pNum + tNum_ + (useRcx_ ? 1 : 0) + (useRdx_ ? 1 : 0);
+ if (tNum_ < 0 || allRegNum > maxRegNum) throw Error(ERR_BAD_TNUM);
+ const Reg64& _rsp = code->rsp;
+ saveNum_ = (std::max)(0, allRegNum - noSaveNum);
+ const int *tbl = getOrderTbl() + noSaveNum;
+ for (int i = 0; i < saveNum_; i++) {
+ code->push(Reg64(tbl[i]));
+ }
+ P_ = (stackSizeByte + 7) / 8;
+		if (P_ > 0 && (P_ & 1) == (saveNum_ & 1)) P_++; // if (rsp % 16) == 8 here, increment P_ to keep 16-byte alignment
+ P_ *= 8;
+ if (P_ > 0) code->sub(_rsp, P_);
+ int pos = 0;
+ for (int i = 0; i < pNum; i++) {
+ pTbl_[i] = Xbyak::Reg64(getRegIdx(pos));
+ }
+ for (int i = 0; i < tNum_; i++) {
+ tTbl_[i] = Xbyak::Reg64(getRegIdx(pos));
+ }
+ if (useRcx_ && rcxPos < pNum) code_->mov(code_->r10, code_->rcx);
+ if (useRdx_ && rdxPos < pNum) code_->mov(code_->r11, code_->rdx);
+ p_.init(pTbl_, pNum);
+ t_.init(tTbl_, tNum_);
+ }
+ /*
+ make epilog manually
+ @param callRet [in] call ret() if true
+ */
+ void close(bool callRet = true)
+ {
+ using namespace Xbyak;
+ const Reg64& _rsp = code_->rsp;
+ const int *tbl = getOrderTbl() + noSaveNum;
+ if (P_ > 0) code_->add(_rsp, P_);
+ for (int i = 0; i < saveNum_; i++) {
+ code_->pop(Reg64(tbl[saveNum_ - 1 - i]));
+ }
+
+ if (callRet) code_->ret();
+ }
+ ~StackFrame()
+ {
+ if (!makeEpilog_) return;
+ try {
+ close();
+ } catch (std::exception& e) {
+ printf("ERR:StackFrame %s\n", e.what());
+ //exit(1);
+ }
+ }
+private:
+ const int *getOrderTbl() const
+ {
+ using namespace Xbyak;
+ static const int tbl[] = {
+#ifdef XBYAK64_WIN
+ Operand::RCX, Operand::RDX, Operand::R8, Operand::R9, Operand::R10, Operand::R11, Operand::RDI, Operand::RSI,
+#else
+ Operand::RDI, Operand::RSI, Operand::RDX, Operand::RCX, Operand::R8, Operand::R9, Operand::R10, Operand::R11,
+#endif
+ Operand::RBX, Operand::RBP, Operand::R12, Operand::R13, Operand::R14, Operand::R15
+ };
+ return &tbl[0];
+ }
+ int getRegIdx(int& pos) const
+ {
+ assert(pos < maxRegNum);
+ using namespace Xbyak;
+ const int *tbl = getOrderTbl();
+ int r = tbl[pos++];
+ if (useRcx_) {
+ if (r == Operand::RCX) { return Operand::R10; }
+ if (r == Operand::R10) { r = tbl[pos++]; }
+ }
+ if (useRdx_) {
+ if (r == Operand::RDX) { return Operand::R11; }
+ if (r == Operand::R11) { return tbl[pos++]; }
+ }
+ return r;
+ }
+};
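+
+/*
+	usage sketch: generate int64_t add(int64_t, int64_t) with two parameters
+	and one temporary register; the prologue is emitted by the StackFrame
+	constructor and the epilogue (pops + ret) by its destructor. The struct
+	name AddCode is only an example.
+
+	struct AddCode : Xbyak::CodeGenerator {
+		AddCode() {
+			Xbyak::util::StackFrame sf(this, 2, 1);
+			mov(sf.t[0], sf.p[0]);
+			add(sf.t[0], sf.p[1]);
+			mov(rax, sf.t[0]); // the return value goes in rax
+		}
+	};
+*/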
+#endif
+
+} } // end of util
+#endif